From 0b21f8a6bdc717f123bd040e99e07696e0d7cb00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Mon, 8 Sep 2025 19:57:44 -0700 Subject: [PATCH] FrameProcessor: add push_interruption_task_frame_and_wait() --- CHANGELOG.md | 11 +++ .../voicemail/voicemail_detector.py | 3 +- src/pipecat/pipeline/task.py | 7 +- .../processors/aggregators/dtmf_aggregator.py | 3 +- .../processors/aggregators/llm_response.py | 6 +- .../aggregators/llm_response_universal.py | 5 +- src/pipecat/processors/frame_processor.py | 71 ++++++++++++++++--- src/pipecat/processors/frameworks/rtvi.py | 3 +- .../services/openai_realtime/openai.py | 4 +- .../services/openai_realtime_beta/openai.py | 4 +- src/pipecat/services/speechmatics/stt.py | 8 +-- src/pipecat/transports/base_input.py | 11 +-- tests/test_dtmf_aggregator.py | 15 +++- 13 files changed, 103 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f17c11ab0..7524b0d53 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `FrameProcessor.push_interruption_task_frame_and_wait()`. Use this + method to programatically interrupt the bot from any part of the + pipeline. This guarantees that all the processors in the pipeline are + interrupted in order (from upstream to downstream). Internally, this works by + first pushing an `InterruptionTaskFrame` upstream until it reaches the + pipeline task. The pipeline task then generates an `InterruptionFrame`, which + flows downstream through all processors. Once the `InterruptionFrame` has + reaches the processor waiting for the interruption, the function returns and + execution continues after the call. Think of it as sending an upstream request + for interruption and waiting until the acknowledgment flows back downstream. + - Added new base `TaskFrame` (which is a system frame). This is the base class for all task frames (`EndTaskFrame`, `CancelTaskFrame`, etc.) that are meant to be pushed upstream to reach the pipeline task. diff --git a/src/pipecat/extensions/voicemail/voicemail_detector.py b/src/pipecat/extensions/voicemail/voicemail_detector.py index 1b5e5ac62..1b460404b 100644 --- a/src/pipecat/extensions/voicemail/voicemail_detector.py +++ b/src/pipecat/extensions/voicemail/voicemail_detector.py @@ -23,7 +23,6 @@ from loguru import logger from pipecat.frames.frames import ( EndFrame, Frame, - InterruptionTaskFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, LLMTextFrame, @@ -360,7 +359,7 @@ class ClassificationProcessor(FrameProcessor): await self._voicemail_notifier.notify() # Clear buffered TTS frames # Interrupt the current pipeline to stop any ongoing processing - await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM) + await self.push_interruption_task_frame_and_wait() # Set the voicemail event to trigger the voicemail handler self._voicemail_event.clear() diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index 975c70cb9..8eec80ce4 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -640,9 +640,12 @@ class PipelineTask(BasePipelineTask): logger.debug(f"{self}: received stop task frame {frame}") await self.queue_frame(StopFrame()) elif isinstance(frame, InterruptionTaskFrame): - # Tell the task we should interrupt the pipeline. + # Tell the task we should interrupt the pipeline. Note that we are + # bypassing the push queue and directly queue into the + # pipeline. This is in case the push task is blocked waiting for a + # pipeline-ending frame to finish traversing the pipeline. logger.debug(f"{self}: received interruption task frame {frame}") - await self.queue_frame(InterruptionFrame()) + await self._pipeline.queue_frame(InterruptionFrame()) elif isinstance(frame, ErrorFrame): if frame.fatal: logger.error(f"A fatal error occurred: {frame}") diff --git a/src/pipecat/processors/aggregators/dtmf_aggregator.py b/src/pipecat/processors/aggregators/dtmf_aggregator.py index b8b947272..1aa0760b4 100644 --- a/src/pipecat/processors/aggregators/dtmf_aggregator.py +++ b/src/pipecat/processors/aggregators/dtmf_aggregator.py @@ -20,7 +20,6 @@ from pipecat.frames.frames import ( EndFrame, Frame, InputDTMFFrame, - InterruptionTaskFrame, StartFrame, TranscriptionFrame, ) @@ -105,7 +104,7 @@ class DTMFAggregator(FrameProcessor): # For first digit, schedule interruption. if is_first_digit: - await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM) + await self.push_interruption_task_frame_and_wait() # Check for immediate flush conditions if frame.button == self._termination_digit: diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index d38233275..ace7b94fd 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -35,7 +35,7 @@ from pipecat.frames.frames import ( FunctionCallsStartedFrame, InputAudioRawFrame, InterimTranscriptionFrame, - InterruptionTaskFrame, + InterruptionFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, LLMMessagesAppendFrame, @@ -531,9 +531,9 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): if should_interrupt: logger.debug( - "Interruption conditions met - pushing InterruptionTaskFrame and aggregation" + "Interruption conditions met - pushing interruption and aggregation" ) - await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM) + await self.push_interruption_task_frame_and_wait() await self._process_aggregation() else: logger.debug("Interruption conditions not met - not pushing aggregation") diff --git a/src/pipecat/processors/aggregators/llm_response_universal.py b/src/pipecat/processors/aggregators/llm_response_universal.py index 70e4b46d0..7cb101fa1 100644 --- a/src/pipecat/processors/aggregators/llm_response_universal.py +++ b/src/pipecat/processors/aggregators/llm_response_universal.py @@ -35,7 +35,6 @@ from pipecat.frames.frames import ( FunctionCallsStartedFrame, InputAudioRawFrame, InterimTranscriptionFrame, - InterruptionTaskFrame, LLMContextAssistantTimestampFrame, LLMContextFrame, LLMFullResponseEndFrame, @@ -309,9 +308,9 @@ class LLMUserAggregator(LLMContextAggregator): if should_interrupt: logger.debug( - "Interruption conditions met - pushing InterruptionTaskFrame and aggregation" + "Interruption conditions met - pushing interruption and aggregation" ) - await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM) + await self.push_interruption_task_frame_and_wait() await self._process_aggregation() else: logger.debug("Interruption conditions not met - not pushing aggregation") diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index 0b30204d4..76400fc22 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -29,6 +29,7 @@ from pipecat.frames.frames import ( FrameProcessorResumeFrame, FrameProcessorResumeUrgentFrame, InterruptionFrame, + InterruptionTaskFrame, StartFrame, SystemFrame, ) @@ -219,6 +220,9 @@ class FrameProcessor(BaseObject): self.__process_event: Optional[asyncio.Event] = None self.__process_frame_task: Optional[asyncio.Task] = None + self._wait_for_interruption = False + self._wait_interruption_event = asyncio.Event() + @property def id(self) -> int: """Get the unique identifier for this processor. @@ -542,6 +546,14 @@ class FrameProcessor(BaseObject): if self._cancelling: return + # If we are waiting for an interruption we will bypass all queued system + # frames and we will process the frame right away. This is because a + # previous system frame might be waiting for the interruption frame and + # it's blocking the input task. + if self._wait_for_interruption and isinstance(frame, InterruptionFrame): + await self.__process_frame(frame, direction, callback) + return + if self._enable_direct_mode: await self.__process_frame(frame, direction, callback) else: @@ -620,6 +632,32 @@ class FrameProcessor(BaseObject): await self.__internal_push_frame(frame, direction) + if isinstance(frame, InterruptionFrame): + self._wait_interruption_event.set() + + async def push_interruption_task_frame_and_wait(self): + """Push an interruption task frame upstream and wait for the interruption. + + This function sends an `InterruptionTaskFrame` upstream to the pipeline + task and waits to receive the corresponding `InterruptionFrame`. When + the function finishes it is guaranteed that the `InterruptionFrame` has + been pushed downstream. + """ + self._wait_for_interruption = True + + await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM) + + # Wait for an `InterruptionFrame` to come to this processor and be + # pushed. Take a look at `push_frame()` to see how we first push the + # `InterruptionFrame` and then we set the event in order to maintain + # frame ordering. + await self._wait_interruption_event.wait() + + # Clean the event. + self._wait_interruption_event.clear() + + self._wait_for_interruption = False + async def __start(self, frame: StartFrame): """Handle the start frame to initialize processor state. @@ -669,20 +707,22 @@ class FrameProcessor(BaseObject): async def _start_interruption(self): """Start handling an interruption by cancelling current tasks.""" try: - # Cancel the process task. This will stop processing queued frames. - await self.__cancel_process_task() + if self._wait_for_interruption: + # If we get here we know the process task was just waiting for + # an interruption (push_interruption_task_frame_and_wait()), so + # we can't cancel the task because it might still need to do + # more things (e.g. pushing a frame after the + # interruption). Instead we just drain the queue because this is + # an interruption. + self.__reset_process_task() + else: + # Cancel and re-create the process task including the queue. + await self.__cancel_process_task() + self.__create_process_task() except Exception as e: logger.exception(f"Uncaught exception in {self} when handling _start_interruption: {e}") await self.push_error(ErrorFrame(str(e))) - # Create a new process queue and task. - self.__create_process_task() - - async def _stop_interruption(self): - """Stop handling an interruption.""" - # Nothing to do right now. - pass - async def __internal_push_frame(self, frame: Frame, direction: FrameDirection): """Internal method to push frames to adjacent processors. @@ -764,6 +804,17 @@ class FrameProcessor(BaseObject): self.__process_queue = asyncio.Queue() self.__process_frame_task = self.create_task(self.__process_frame_task_handler()) + def __reset_process_task(self): + """Reset non-system frame processing task.""" + if self._enable_direct_mode: + return + + self.__should_block_frames = False + self.__process_event = asyncio.Event() + while not self.__process_queue.empty(): + self.__process_queue.get_nowait() + self.__process_queue.task_done() + async def __cancel_process_task(self): """Cancel the non-system frame processing task.""" if self.__process_frame_task: diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index bd63cd0c2..f469c3431 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -41,7 +41,6 @@ from pipecat.frames.frames import ( FunctionCallResultFrame, InputAudioRawFrame, InterimTranscriptionFrame, - InterruptionTaskFrame, LLMContextFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, @@ -1206,7 +1205,7 @@ class RTVIProcessor(FrameProcessor): async def interrupt_bot(self): """Send a bot interruption frame upstream.""" - await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM) + await self.push_interruption_task_frame_and_wait() async def send_server_message(self, data: Any): """Send a server message to the client.""" diff --git a/src/pipecat/services/openai_realtime/openai.py b/src/pipecat/services/openai_realtime/openai.py index 67fd18308..8b3d500eb 100644 --- a/src/pipecat/services/openai_realtime/openai.py +++ b/src/pipecat/services/openai_realtime/openai.py @@ -716,14 +716,12 @@ class OpenAIRealtimeLLMService(LLMService): async def _handle_evt_speech_started(self, evt): await self._truncate_current_audio_response() - await self._start_interruption() # cancels this processor task - await self.push_frame(InterruptionFrame()) # cancels downstream tasks + await self.push_interruption_task_frame_and_wait() await self.push_frame(UserStartedSpeakingFrame()) async def _handle_evt_speech_stopped(self, evt): await self.start_ttfb_metrics() await self.start_processing_metrics() - await self._stop_interruption() await self.push_frame(UserStoppedSpeakingFrame()) async def _maybe_handle_evt_retrieve_conversation_item_error(self, evt: events.ErrorEvent): diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py index 84bcd039b..922f9a572 100644 --- a/src/pipecat/services/openai_realtime_beta/openai.py +++ b/src/pipecat/services/openai_realtime_beta/openai.py @@ -658,14 +658,12 @@ class OpenAIRealtimeBetaLLMService(LLMService): async def _handle_evt_speech_started(self, evt): await self._truncate_current_audio_response() - await self._start_interruption() # cancels this processor task - await self.push_frame(InterruptionFrame()) # cancels downstream tasks + await self.push_interruption_task_frame_and_wait() await self.push_frame(UserStartedSpeakingFrame()) async def _handle_evt_speech_stopped(self, evt): await self.start_ttfb_metrics() await self.start_processing_metrics() - await self._stop_interruption() await self.push_frame(UserStoppedSpeakingFrame()) async def _maybe_handle_evt_retrieve_conversation_item_error(self, evt: events.ErrorEvent): diff --git a/src/pipecat/services/speechmatics/stt.py b/src/pipecat/services/speechmatics/stt.py index 3cd912093..4028dd248 100644 --- a/src/pipecat/services/speechmatics/stt.py +++ b/src/pipecat/services/speechmatics/stt.py @@ -24,7 +24,6 @@ from pipecat.frames.frames import ( ErrorFrame, Frame, InterimTranscriptionFrame, - InterruptionTaskFrame, StartFrame, TranscriptionFrame, UserStartedSpeakingFrame, @@ -749,14 +748,13 @@ class SpeechmaticsSTTService(STTService): return # Frames to send - upstream_frames: list[Frame] = [] downstream_frames: list[Frame] = [] # If VAD is enabled, then send a speaking frame if self._params.enable_vad and not self._is_speaking: logger.debug("User started speaking") self._is_speaking = True - upstream_frames += [InterruptionTaskFrame()] + await self.push_interruption_task_frame_and_wait() downstream_frames += [UserStartedSpeakingFrame()] # If final, then re-parse into TranscriptionFrame @@ -794,10 +792,6 @@ class SpeechmaticsSTTService(STTService): self._is_speaking = False downstream_frames += [UserStoppedSpeakingFrame()] - # Send UPSTREAM frames - for frame in upstream_frames: - await self.push_frame(frame, FrameDirection.UPSTREAM) - # Send the DOWNSTREAM frames for frame in downstream_frames: await self.push_frame(frame, FrameDirection.DOWNSTREAM) diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 04f74e4f8..f2ccd1d96 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -32,8 +32,6 @@ from pipecat.frames.frames import ( Frame, InputAudioRawFrame, InputImageRawFrame, - InterruptionFrame, - InterruptionTaskFrame, MetricsFrame, SpeechControlParamsFrame, StartFrame, @@ -353,11 +351,7 @@ class BaseInputTransport(FrameProcessor): # Make sure we notify about interruptions quickly out-of-band. if should_push_immediate_interruption and self.interruptions_allowed: - await self._start_interruption() - # Push an out-of-band frame (i.e. not using the ordered push - # frame task) to stop everything, specially at the output - # transport. - await self.push_frame(InterruptionFrame()) + await self.push_interruption_task_frame_and_wait() elif self.interruption_strategies and self._bot_speaking: logger.debug( "User started speaking while bot is speaking with interruption config - " @@ -372,9 +366,6 @@ class BaseInputTransport(FrameProcessor): await self.push_frame(downstream_frame) await self.push_frame(upstream_frame, FrameDirection.UPSTREAM) - if self.interruptions_allowed: - await self._stop_interruption() - # # Handle bot speaking state # diff --git a/tests/test_dtmf_aggregator.py b/tests/test_dtmf_aggregator.py index 5d9d1346a..c7590ae47 100644 --- a/tests/test_dtmf_aggregator.py +++ b/tests/test_dtmf_aggregator.py @@ -10,6 +10,7 @@ from pipecat.audio.dtmf.types import KeypadEntry from pipecat.frames.frames import ( EndFrame, InputDTMFFrame, + InterruptionFrame, TranscriptionFrame, ) from pipecat.processors.aggregators.dtmf_aggregator import DTMFAggregator @@ -28,6 +29,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase): ] expected_down_frames = [ InputDTMFFrame, + InterruptionFrame, InputDTMFFrame, InputDTMFFrame, InputDTMFFrame, @@ -59,9 +61,11 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase): ] expected_down_frames = [ InputDTMFFrame, + InterruptionFrame, InputDTMFFrame, TranscriptionFrame, # First aggregation "12" InputDTMFFrame, + InterruptionFrame, TranscriptionFrame, # Second aggregation "3" ] @@ -93,10 +97,12 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase): ] expected_down_frames = [ InputDTMFFrame, + InterruptionFrame, InputDTMFFrame, InputDTMFFrame, TranscriptionFrame, # "12#" InputDTMFFrame, + InterruptionFrame, InputDTMFFrame, TranscriptionFrame, # "45" ] @@ -125,6 +131,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase): ] expected_down_frames = [ InputDTMFFrame, + InterruptionFrame, InputDTMFFrame, TranscriptionFrame, # Should flush before EndFrame EndFrame, @@ -152,6 +159,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase): ] expected_down_frames = [ InputDTMFFrame, + InterruptionFrame, InputDTMFFrame, TranscriptionFrame, ] @@ -178,6 +186,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase): ] expected_down_frames = [ InputDTMFFrame, + InterruptionFrame, InputDTMFFrame, InputDTMFFrame, TranscriptionFrame, @@ -214,7 +223,11 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase): ] # All the InputDTMFFrames plus one TranscriptionFrame - expected_down_frames = [InputDTMFFrame] * len(frames_to_send) + [TranscriptionFrame] + expected_down_frames = ( + [InputDTMFFrame, InterruptionFrame] + + [InputDTMFFrame] * (len(frames_to_send) - 1) + + [TranscriptionFrame] + ) received_down_frames, _ = await run_test( aggregator,