diff --git a/CHANGELOG.md b/CHANGELOG.md index 506a9305e..7584ca409 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -74,6 +74,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed an issue where the `TranscriptionUpdateFrame` pushed because of an + interruption could be discarded. + - Fixed an issue that would cause `SegmentedSTTService` based services (e.g. `OpenAISTTService`) to try to transcribe non-spoken audio, causing invalid transcriptions. diff --git a/src/pipecat/processors/transcript_processor.py b/src/pipecat/processors/transcript_processor.py index 3eaff66ca..a2ad22223 100644 --- a/src/pipecat/processors/transcript_processor.py +++ b/src/pipecat/processors/transcript_processor.py @@ -175,22 +175,28 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor): """ await super().process_frame(frame, direction) - if isinstance(frame, TTSTextFrame): + if isinstance(frame, (StartInterruptionFrame, CancelFrame)): + # Push frame first otherwise our emitted transcription update frame + # might get cleaned up. + await self.push_frame(frame, direction) + # Emit accumulated text with interruptions + await self._emit_aggregated_text() + elif isinstance(frame, TTSTextFrame): # Start timestamp on first text part if not self._aggregation_start_time: self._aggregation_start_time = time_now_iso8601() self._current_text_parts.append(frame.text) - elif isinstance(frame, (BotStoppedSpeakingFrame, StartInterruptionFrame, CancelFrame)): - # Emit accumulated text when bot finishes speaking or is interrupted + # Push frame. + await self.push_frame(frame, direction) + elif isinstance(frame, (BotStoppedSpeakingFrame, EndFrame)): + # Emit accumulated text when bot finishes speaking or pipeline ends. await self._emit_aggregated_text() - - elif isinstance(frame, EndFrame): - # Emit any remaining text when pipeline ends - await self._emit_aggregated_text() - - await self.push_frame(frame, direction) + # Push frame. 
+ await self.push_frame(frame, direction) + else: + await self.push_frame(frame, direction) class TranscriptProcessor: diff --git a/tests/test_context_aggregators.py b/tests/test_context_aggregators.py index b735ed08e..05734a64e 100644 --- a/tests/test_context_aggregators.py +++ b/tests/test_context_aggregators.py @@ -41,7 +41,10 @@ from pipecat.services.google.llm import ( GoogleLLMContext, GoogleUserContextAggregator, ) -from pipecat.services.openai import OpenAIAssistantContextAggregator, OpenAIUserContextAggregator +from pipecat.services.openai.llm import ( + OpenAIAssistantContextAggregator, + OpenAIUserContextAggregator, +) from pipecat.tests.utils import SleepFrame, run_test AGGREGATION_TIMEOUT = 0.1 diff --git a/tests/test_transcript_processor.py b/tests/test_transcript_processor.py index d13246b2c..601631a8e 100644 --- a/tests/test_transcript_processor.py +++ b/tests/test_transcript_processor.py @@ -238,8 +238,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): TTSTextFrame(text="world!"), SleepFrame(sleep=0.1), StartInterruptionFrame(), # User interrupts here - BotStartedSpeakingFrame(), SleepFrame(sleep=0.1), + BotStartedSpeakingFrame(), TTSTextFrame(text="New"), TTSTextFrame(text="response"), SleepFrame(sleep=0.1), @@ -251,8 +251,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): BotStartedSpeakingFrame, TTSTextFrame, # "Hello" TTSTextFrame, # "world!" + StartInterruptionFrame, TranscriptionUpdateFrame, # First message (emitted due to interruption) - StartInterruptionFrame, # Interruption frame comes after the update BotStartedSpeakingFrame, TTSTextFrame, # "New" TTSTextFrame, # "response"