From 4a61d5bfadc55651e921be8bb1ac40affbacd929 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Mon, 2 Mar 2026 12:04:51 -0800 Subject: [PATCH] Add broadcast_interruption() to FrameProcessor Replace the round-trip push_interruption_task_frame_and_wait() mechanism with broadcast_interruption(), which pushes an InterruptionFrame both upstream and downstream directly from the calling processor. This eliminates race conditions (transcription arriving before the InterruptionFrame comes back), swallowed-event timeouts (frame blocked before reaching the sink), and the complexity of _wait_for_interruption flag / queue bypass / frame.complete() obligations. - Add broadcast_interruption() to FrameProcessor - Deprecate push_interruption_task_frame_and_wait() (delegates to new method) - Remove event field and complete() from InterruptionFrame/InterruptionTaskFrame - Remove _wait_for_interruption flag and all special-case logic - Remove frame.complete() calls in stt_mute_filter and llm_response_universal - Update all 17 call sites to use broadcast_interruption() - Update tests --- changelog/3900.added.md | 1 + changelog/3900.changed.md | 1 + changelog/3900.deprecated.md | 1 + .../voicemail/voicemail_detector.py | 2 +- src/pipecat/frames/frames.py | 29 +---- src/pipecat/pipeline/task.py | 4 +- .../processors/aggregators/dtmf_aggregator.py | 2 +- .../processors/aggregators/llm_response.py | 2 +- .../aggregators/llm_response_universal.py | 8 +- .../processors/filters/stt_mute_filter.py | 6 - src/pipecat/processors/frame_processor.py | 77 ++++-------- src/pipecat/processors/frameworks/rtvi.py | 2 +- src/pipecat/services/deepgram/flux/stt.py | 2 +- src/pipecat/services/deepgram/stt.py | 2 +- src/pipecat/services/gladia/stt.py | 2 +- .../services/google/gemini_live/llm.py | 2 +- src/pipecat/services/grok/realtime/llm.py | 2 +- src/pipecat/services/openai/realtime/llm.py | 2 +- src/pipecat/services/openai/stt.py | 2 +- .../services/openai_realtime_beta/openai.py | 2 +- src/pipecat/services/sarvam/stt.py | 2 +- src/pipecat/services/speechmatics/stt.py | 2 +- src/pipecat/transports/base_input.py | 2 +- src/pipecat/turns/user_turn_processor.py | 2 +- tests/test_context_aggregators.py | 3 +- tests/test_frame_processor.py | 117 +++--------------- tests/test_stt_mute_filter.py | 13 +- 27 files changed, 68 insertions(+), 224 deletions(-) create mode 100644 changelog/3900.added.md create mode 100644 changelog/3900.changed.md create mode 100644 changelog/3900.deprecated.md diff --git a/changelog/3900.added.md b/changelog/3900.added.md new file mode 100644 index 000000000..08921c004 --- /dev/null +++ b/changelog/3900.added.md @@ -0,0 +1 @@ +- Added `broadcast_interruption()` to `FrameProcessor`. This method pushes an `InterruptionFrame` both upstream and downstream directly from the calling processor, avoiding the round-trip through the pipeline task that `push_interruption_task_frame_and_wait()` required. diff --git a/changelog/3900.changed.md b/changelog/3900.changed.md new file mode 100644 index 000000000..59b4cdb95 --- /dev/null +++ b/changelog/3900.changed.md @@ -0,0 +1 @@ +- Removed `event` field and `complete()` method from `InterruptionFrame`. Removed `event` field from `InterruptionTaskFrame`. These are no longer needed since `broadcast_interruption()` does not require a round-trip completion signal. diff --git a/changelog/3900.deprecated.md b/changelog/3900.deprecated.md new file mode 100644 index 000000000..421e10e92 --- /dev/null +++ b/changelog/3900.deprecated.md @@ -0,0 +1 @@ +- Deprecated `push_interruption_task_frame_and_wait()` in `FrameProcessor`. Use `broadcast_interruption()` instead. The old method now delegates to `broadcast_interruption()` and logs a deprecation warning. diff --git a/src/pipecat/extensions/voicemail/voicemail_detector.py b/src/pipecat/extensions/voicemail/voicemail_detector.py index 7e22e535a..470f5dd54 100644 --- a/src/pipecat/extensions/voicemail/voicemail_detector.py +++ b/src/pipecat/extensions/voicemail/voicemail_detector.py @@ -368,7 +368,7 @@ class ClassificationProcessor(FrameProcessor): await self._voicemail_notifier.notify() # Clear buffered TTS frames # Interrupt the current pipeline to stop any ongoing processing - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() # Set the voicemail event to trigger the voicemail handler self._voicemail_event.clear() diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 126f3c001..9d6f78d6c 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -11,7 +11,6 @@ including data frames, system frames, and control frames for audio, video, text, and LLM processing. """ -import asyncio import time from dataclasses import dataclass, field from typing import ( @@ -1141,24 +1140,9 @@ class InterruptionFrame(SystemFrame): This frame is used to interrupt the pipeline. For example, when a user starts speaking to cancel any in-progress bot output. It can also be pushed by any processor. - - Parameters: - event: Optional event set when the frame has fully traversed the - pipeline. - """ - event: Optional[asyncio.Event] = None - - def complete(self): - """Signal that this interruption has been fully processed. - - Called automatically when the frame reaches the pipeline sink, or - manually when the frame is consumed before reaching it (e.g. when - the user is muted). - """ - if self.event: - self.event.set() + pass @dataclass @@ -1825,16 +1809,11 @@ class InterruptionTaskFrame(TaskFrame): """Frame indicating the pipeline should be interrupted. This frame should be pushed upstream to indicate the pipeline should be - interrupted. The pipeline task converts this into an `InterruptionFrame` and - sends it downstream. The `event` is passed to the `InterruptionFrame` so it - can signal when the interruption has fully traversed the pipeline. - - Parameters: - event: Optional event passed to the corresponding `InterruptionFrame`. - + interrupted. The pipeline task converts this into an `InterruptionFrame` + and sends it downstream. """ - event: Optional[asyncio.Event] = None + pass @dataclass diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index deae6290c..291ed5506 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -892,7 +892,7 @@ class PipelineTask(BasePipelineTask): # pipeline. This is in case the push task is blocked waiting for a # pipeline-ending frame to finish traversing the pipeline. logger.debug(f"{self}: received interruption task frame {frame}") - await self._pipeline.queue_frame(InterruptionFrame(event=frame.event)) + await self._pipeline.queue_frame(InterruptionFrame()) elif isinstance(frame, ErrorFrame): await self._call_event_handler("on_pipeline_error", frame) if frame.fatal: @@ -931,8 +931,6 @@ class PipelineTask(BasePipelineTask): self._pipeline_end_event.set() elif isinstance(frame, CancelFrame): self._pipeline_end_event.set() - elif isinstance(frame, InterruptionFrame): - frame.complete() elif isinstance(frame, HeartbeatFrame): await self._heartbeat_queue.put(frame) diff --git a/src/pipecat/processors/aggregators/dtmf_aggregator.py b/src/pipecat/processors/aggregators/dtmf_aggregator.py index 1b9c59158..ea56ba6fc 100644 --- a/src/pipecat/processors/aggregators/dtmf_aggregator.py +++ b/src/pipecat/processors/aggregators/dtmf_aggregator.py @@ -104,7 +104,7 @@ class DTMFAggregator(FrameProcessor): # For first digit, schedule interruption. if is_first_digit: - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() # Check for immediate flush conditions if frame.button == self._termination_digit: diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 44e5ce252..7c246b209 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -581,7 +581,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): logger.debug( "Interruption conditions met - pushing interruption and aggregation" ) - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() await self._process_aggregation() else: logger.debug("Interruption conditions not met - not pushing aggregation") diff --git a/src/pipecat/processors/aggregators/llm_response_universal.py b/src/pipecat/processors/aggregators/llm_response_universal.py index 96f3702be..cf6c81e5f 100644 --- a/src/pipecat/processors/aggregators/llm_response_universal.py +++ b/src/pipecat/processors/aggregators/llm_response_universal.py @@ -608,12 +608,6 @@ class LLMUserAggregator(LLMContextAggregator): if should_mute_frame: logger.trace(f"{frame.name} suppressed - user currently muted") - # When muted, the InterruptionFrame won't propagate further and - # will never reach the pipeline sink. Complete it here so - # push_interruption_task_frame_and_wait() doesn't hang. - if should_mute_frame and isinstance(frame, InterruptionFrame): - frame.complete() - should_mute_next_time = False for s in self._params.user_mute_strategies: should_mute_next_time |= await s.process_frame(frame) @@ -737,7 +731,7 @@ class LLMUserAggregator(LLMContextAggregator): await self._user_idle_controller.process_frame(UserStartedSpeakingFrame()) if params.enable_interruptions and self._allow_interruptions: - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() await self._call_event_handler("on_user_turn_started", strategy) diff --git a/src/pipecat/processors/filters/stt_mute_filter.py b/src/pipecat/processors/filters/stt_mute_filter.py index f5d008e28..9f522a20d 100644 --- a/src/pipecat/processors/filters/stt_mute_filter.py +++ b/src/pipecat/processors/filters/stt_mute_filter.py @@ -234,12 +234,6 @@ class STTMuteFilter(FrameProcessor): await self.push_frame(frame, direction) else: logger.trace(f"{frame.__class__.__name__} suppressed - STT currently muted") - - # When muted, the InterruptionFrame won't propagate further - # and will never reach the pipeline sink. Complete it here so - # push_interruption_task_frame_and_wait() doesn't hang. - if isinstance(frame, InterruptionFrame): - frame.complete() else: # Pass all other frames through await self.push_frame(frame, direction) diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index 3e7b48442..69c503e71 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -41,7 +41,6 @@ from pipecat.frames.frames import ( FrameProcessorResumeFrame, FrameProcessorResumeUrgentFrame, InterruptionFrame, - InterruptionTaskFrame, StartFrame, SystemFrame, UninterruptibleFrame, @@ -240,10 +239,6 @@ class FrameProcessor(BaseObject): self.__process_frame_task: Optional[asyncio.Task] = None self.__process_current_frame: Optional[Frame] = None - # Set while awaiting push_interruption_task_frame_and_wait() so that - # _start_interruption() knows not to cancel the process task. - self._wait_for_interruption = False - # Frame processor events. self._register_event_handler("on_before_process_frame", sync=True) self._register_event_handler("on_after_process_frame", sync=True) @@ -329,7 +324,7 @@ class FrameProcessor(BaseObject): warnings.simplefilter("always") warnings.warn( "`FrameProcessor.interruptions_allowed` is deprecated. " - "Use `LLMUserAggregator`'s new `user_mute_strategies` parameter instead.", + "Use `LLMUserAggregator`'s new `user_mute_strategies` parameter instead.", DeprecationWarning, stacklevel=2, ) @@ -647,15 +642,6 @@ class FrameProcessor(BaseObject): if self._cancelling: return - # If we are waiting for an interruption, bypass all queued system frames - # and process the frame right away. This is because a previous system - # frame might be waiting for the interruption frame blocking the input - # task, so this InterruptionFrame would never be dequeued and we'd - # deadlock. - if self._wait_for_interruption and isinstance(frame, InterruptionFrame): - await self.__process_frame(frame, direction, callback) - return - if self._enable_direct_mode: await self.__process_frame(frame, direction, callback) else: @@ -790,43 +776,32 @@ class FrameProcessor(BaseObject): await self._call_event_handler("on_after_push_frame", frame) + async def broadcast_interruption(self): + """Broadcast an `InterruptionFrame` both upstream and downstream.""" + logger.debug(f"{self}: broadcasting interruption") + self.__reset_process_task() + await self.stop_all_metrics() + await self.broadcast_frame(InterruptionFrame) + async def push_interruption_task_frame_and_wait(self, *, timeout: float = 5.0): """Push an interruption task frame upstream and wait for the interruption. - This function sends an `InterruptionTaskFrame` upstream to the - pipeline task. The task creates a corresponding `InterruptionFrame` - and sends it downstream through the pipeline. An `asyncio.Event` is - attached to both frames so the caller can wait until the interruption - has fully traversed the pipeline. The event is set when the - `InterruptionFrame` reaches the pipeline sink. If the frame does - not complete within the given timeout, a warning is logged and the - event is forcibly set so the caller is unblocked. - - Args: - timeout: Maximum seconds to wait for the interruption to complete. + .. deprecated:: 0.0.104 + Use :meth:`broadcast_interruption` instead. This method now + delegates to ``broadcast_interruption()`` and ignores *timeout*. """ - self._wait_for_interruption = True + import warnings - event = asyncio.Event() + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`FrameProcessor.push_interruption_task_frame_and_wait()` is deprecated. " + "Use `FrameProcessor.broadcast_interruption()` instead.", + DeprecationWarning, + stacklevel=2, + ) - await self.push_frame(InterruptionTaskFrame(event=event), FrameDirection.UPSTREAM) - - # Wait for the `InterruptionFrame` to complete and log a warning if it - # takes too long. If it does take too long make sure we unblock it, - # otherwise we will hang here forever. - while not event.is_set(): - try: - await asyncio.wait_for(event.wait(), timeout=timeout) - except asyncio.TimeoutError: - logger.warning( - f"{self}: InterruptionFrame has not completed after" - f" {timeout}s. Make sure InterruptionFrame.complete()" - " is being called (e.g. if the frame is being blocked" - " or consumed before reaching the pipeline sink)." - ) - event.set() - - self._wait_for_interruption = False + await self.broadcast_interruption() async def broadcast_frame(self, frame_cls: Type[Frame], **kwargs): """Broadcasts a frame of the specified class upstream and downstream. @@ -933,15 +908,7 @@ class FrameProcessor(BaseObject): async def _start_interruption(self): """Start handling an interruption by cancelling current tasks.""" try: - if self._wait_for_interruption: - # If we get here we know the process task was just waiting for - # an interruption (push_interruption_task_frame_and_wait()), so - # we can't cancel the task because it might still need to do - # more things (e.g. pushing a frame after the - # interruption). Instead we just drain the queue because this is - # an interruption. - self.__reset_process_task() - elif isinstance(self.__process_current_frame, UninterruptibleFrame): + if isinstance(self.__process_current_frame, UninterruptibleFrame): # We don't want to cancel UninterruptibleFrame, so we simply # cleanup the queue. self.__reset_process_queue() diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index e01e95714..eb1e79f3e 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -1702,7 +1702,7 @@ class RTVIProcessor(FrameProcessor): async def interrupt_bot(self): """Send a bot interruption frame upstream.""" - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() async def send_server_message(self, data: Any): """Send a server message to the client.""" diff --git a/src/pipecat/services/deepgram/flux/stt.py b/src/pipecat/services/deepgram/flux/stt.py index d509b267e..984906c6c 100644 --- a/src/pipecat/services/deepgram/flux/stt.py +++ b/src/pipecat/services/deepgram/flux/stt.py @@ -675,7 +675,7 @@ class DeepgramFluxSTTService(WebsocketSTTService): self._user_is_speaking = True await self.broadcast_frame(UserStartedSpeakingFrame) if self._should_interrupt: - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() await self.start_metrics() await self._call_event_handler("on_start_of_turn", transcript) if transcript: diff --git a/src/pipecat/services/deepgram/stt.py b/src/pipecat/services/deepgram/stt.py index 497d6aae1..8eb246cf2 100644 --- a/src/pipecat/services/deepgram/stt.py +++ b/src/pipecat/services/deepgram/stt.py @@ -471,7 +471,7 @@ class DeepgramSTTService(STTService): await self._call_event_handler("on_speech_started", *args, **kwargs) await self.broadcast_frame(UserStartedSpeakingFrame) if self._should_interrupt: - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() async def _on_utterance_end(self, *args, **kwargs): await self._call_event_handler("on_utterance_end", *args, **kwargs) diff --git a/src/pipecat/services/gladia/stt.py b/src/pipecat/services/gladia/stt.py index 045a56613..bba554b4a 100644 --- a/src/pipecat/services/gladia/stt.py +++ b/src/pipecat/services/gladia/stt.py @@ -613,7 +613,7 @@ class GladiaSTTService(WebsocketSTTService): await self.broadcast_frame(UserStartedSpeakingFrame) if self._should_interrupt: - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() async def _on_speech_ended(self): """Handle speech end event from Gladia. diff --git a/src/pipecat/services/google/gemini_live/llm.py b/src/pipecat/services/google/gemini_live/llm.py index d06f941c7..2ed11c739 100644 --- a/src/pipecat/services/google/gemini_live/llm.py +++ b/src/pipecat/services/google/gemini_live/llm.py @@ -1265,7 +1265,7 @@ class GeminiLiveLLMService(LLMService): # combination with the context aggregator default # turn strategies. logger.debug("Gemini VAD: interrupted signal received") - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() elif message.server_content and message.server_content.model_turn: await self._handle_msg_model_turn(message) elif ( diff --git a/src/pipecat/services/grok/realtime/llm.py b/src/pipecat/services/grok/realtime/llm.py index 6d148f6d7..7a4e73806 100644 --- a/src/pipecat/services/grok/realtime/llm.py +++ b/src/pipecat/services/grok/realtime/llm.py @@ -734,7 +734,7 @@ class GrokRealtimeLLMService(LLMService): """Handle speech started event from VAD.""" await self._truncate_current_audio_response() await self.broadcast_frame(UserStartedSpeakingFrame) - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() async def _handle_evt_speech_stopped(self, evt): """Handle speech stopped event from VAD.""" diff --git a/src/pipecat/services/openai/realtime/llm.py b/src/pipecat/services/openai/realtime/llm.py index a6667c7c8..07b6aa82b 100644 --- a/src/pipecat/services/openai/realtime/llm.py +++ b/src/pipecat/services/openai/realtime/llm.py @@ -839,7 +839,7 @@ class OpenAIRealtimeLLMService(LLMService): async def _handle_evt_speech_started(self, evt): await self._truncate_current_audio_response() await self.broadcast_frame(UserStartedSpeakingFrame) - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() async def _handle_evt_speech_stopped(self, evt): await self.start_ttfb_metrics() diff --git a/src/pipecat/services/openai/stt.py b/src/pipecat/services/openai/stt.py index 9a52be114..32895f8b5 100644 --- a/src/pipecat/services/openai/stt.py +++ b/src/pipecat/services/openai/stt.py @@ -639,7 +639,7 @@ class OpenAIRealtimeSTTService(WebsocketSTTService): logger.debug("Server VAD: speech started") await self.broadcast_frame(UserStartedSpeakingFrame) if self._should_interrupt: - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() await self.start_processing_metrics() async def _handle_speech_stopped(self, evt: dict): diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py index 8614713ff..c912ed45c 100644 --- a/src/pipecat/services/openai_realtime_beta/openai.py +++ b/src/pipecat/services/openai_realtime_beta/openai.py @@ -709,7 +709,7 @@ class OpenAIRealtimeBetaLLMService(LLMService): async def _handle_evt_speech_started(self, evt): await self._truncate_current_audio_response() await self.broadcast_frame(UserStartedSpeakingFrame) - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() async def _handle_evt_speech_stopped(self, evt): await self.start_ttfb_metrics() diff --git a/src/pipecat/services/sarvam/stt.py b/src/pipecat/services/sarvam/stt.py index 9e245aece..e368ceb02 100644 --- a/src/pipecat/services/sarvam/stt.py +++ b/src/pipecat/services/sarvam/stt.py @@ -644,7 +644,7 @@ class SarvamSTTService(STTService): logger.debug("User started speaking") await self._call_event_handler("on_speech_started") await self.broadcast_frame(UserStartedSpeakingFrame) - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() elif signal == "END_SPEECH": logger.debug("User stopped speaking") diff --git a/src/pipecat/services/speechmatics/stt.py b/src/pipecat/services/speechmatics/stt.py index ac18a36e3..bdeb3b249 100644 --- a/src/pipecat/services/speechmatics/stt.py +++ b/src/pipecat/services/speechmatics/stt.py @@ -836,7 +836,7 @@ class SpeechmaticsSTTService(STTService): # await self.start_processing_metrics() await self.broadcast_frame(UserStartedSpeakingFrame) if self._should_interrupt: - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() async def _handle_end_of_turn(self, message: dict[str, Any]) -> None: """Handle EndOfTurn events. diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 49c28149a..1da672ab7 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -558,7 +558,7 @@ class BaseInputTransport(FrameProcessor): # Make sure we notify about interruptions quickly out-of-band. if should_push_immediate_interruption and self._allow_interruptions: - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() elif self.interruption_strategies and self._bot_speaking: logger.debug( "User started speaking while bot is speaking with interruption config - " diff --git a/src/pipecat/turns/user_turn_processor.py b/src/pipecat/turns/user_turn_processor.py index 7f8995202..85bc658dd 100644 --- a/src/pipecat/turns/user_turn_processor.py +++ b/src/pipecat/turns/user_turn_processor.py @@ -182,7 +182,7 @@ class UserTurnProcessor(FrameProcessor): await self._user_idle_controller.process_frame(UserStartedSpeakingFrame()) if params.enable_interruptions and self._allow_interruptions: - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() await self._call_event_handler("on_user_turn_started", strategy) diff --git a/tests/test_context_aggregators.py b/tests/test_context_aggregators.py index 24dae0b4c..37d36bfef 100644 --- a/tests/test_context_aggregators.py +++ b/tests/test_context_aggregators.py @@ -21,7 +21,6 @@ from pipecat.frames.frames import ( FunctionCallResultProperties, InterimTranscriptionFrame, InterruptionFrame, - InterruptionTaskFrame, LLMContextAssistantTimestampFrame, LLMContextFrame, LLMFullResponseEndFrame, @@ -567,7 +566,7 @@ class BaseTestUserContextAggregator: SleepFrame(), UserStoppedSpeakingFrame(), ] - expected_up_frames = [InterruptionTaskFrame] + expected_up_frames = [InterruptionFrame] expected_down_frames = [ BotStartedSpeakingFrame, UserStartedSpeakingFrame, diff --git a/tests/test_frame_processor.py b/tests/test_frame_processor.py index 138c8e6d8..a875741e3 100644 --- a/tests/test_frame_processor.py +++ b/tests/test_frame_processor.py @@ -9,8 +9,6 @@ import unittest from dataclasses import dataclass, field from typing import List -from loguru import logger - from pipecat.frames.frames import ( DataFrame, EndFrame, @@ -85,50 +83,38 @@ class TestFrameProcessor(unittest.IsolatedAsyncioTestCase): assert before_push_called assert after_push_called - async def test_interruption_and_wait(self): - class DelayFrameProcessor(FrameProcessor): - """This processors just gives time to the event loop to change - between tasks. Otherwise things happen to fast.""" - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - await asyncio.sleep(0.1) - await self.push_frame(frame, direction) + async def test_broadcast_interruption(self): + """Test that broadcast_interruption() pushes InterruptionFrame both + directions and allows subsequent code to run.""" class InterruptFrameProcessor(FrameProcessor): async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) if isinstance(frame, TextFrame): - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() await self.push_frame(OutputTransportMessageUrgentFrame(message=frame.text)) else: await self.push_frame(frame, direction) - pipeline = Pipeline([DelayFrameProcessor(), InterruptFrameProcessor()]) + pipeline = Pipeline([InterruptFrameProcessor()]) frames_to_send = [ - # Just a random interruption to make sure we don't clear anything - # before the actual `InterruptionTaskFrame` interruption. - InterruptionFrame(), - # This will generate an `InterruptionTaskFrame` and will wait for an - # `InterruptionFrame`. TextFrame(text="Hello from Pipecat!"), - # Just give time for everything to complete. SleepFrame(sleep=0.5), - EndFrame(), ] expected_down_frames = [ - InterruptionFrame, InterruptionFrame, OutputTransportMessageUrgentFrame, - EndFrame, + ] + expected_up_frames = [ + InterruptionFrame, ] await run_test( pipeline, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, - send_end_frame=False, + expected_up_frames=expected_up_frames, ) async def test_interruptible_frames(self): @@ -454,33 +440,20 @@ class TestFrameProcessor(unittest.IsolatedAsyncioTestCase): stop_frames = [f for f in received_frames if isinstance(f, StopFrame)] self.assertEqual(len(stop_frames), 1, "StopFrame should survive interruption") - async def test_interruption_frame_complete_sets_event(self): - """Test that InterruptionFrame.complete() sets the event.""" - event = asyncio.Event() - frame = InterruptionFrame(event=event) - self.assertFalse(event.is_set()) - frame.complete() - self.assertTrue(event.is_set()) - - async def test_interruption_frame_complete_without_event(self): - """Test that InterruptionFrame.complete() is safe without an event.""" - frame = InterruptionFrame() - frame.complete() # Should not raise - - async def test_interruption_event_set_at_pipeline_sink(self): - """Test that the event from push_interruption_task_frame_and_wait() - is set when the InterruptionFrame reaches the pipeline sink.""" - event_was_set = False + async def test_broadcast_interruption_allows_subsequent_code(self): + """Test that broadcast_interruption() returns immediately, allowing the + caller to run code afterwards (e.g. push an urgent frame).""" + code_after_ran = False class InterruptOnTextProcessor(FrameProcessor): async def process_frame(self, frame: Frame, direction: FrameDirection): - nonlocal event_was_set + nonlocal code_after_ran await super().process_frame(frame, direction) if isinstance(frame, TextFrame): - await self.push_interruption_task_frame_and_wait() + await self.broadcast_interruption() - event_was_set = True + code_after_ran = True await self.push_frame(OutputTransportMessageUrgentFrame(message="done")) else: await self.push_frame(frame, direction) @@ -499,63 +472,7 @@ class TestFrameProcessor(unittest.IsolatedAsyncioTestCase): frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - self.assertTrue(event_was_set, "Event should be set after InterruptionFrame completes") - - async def test_interruption_completion_timeout_warning(self): - """Test that a warning is logged when an InterruptionFrame is blocked - and never reaches the pipeline sink.""" - warnings = [] - handler_id = logger.add( - lambda msg: warnings.append(str(msg)), level="WARNING", format="{message}" - ) - - try: - - class BlockInterruptionProcessor(FrameProcessor): - """Blocks InterruptionFrames, completing them after a delay.""" - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - if isinstance(frame, InterruptionFrame): - # Complete after the timeout so the warning fires - # but the test doesn't hang. - async def delayed_complete(): - await asyncio.sleep(1.0) - frame.complete() - - asyncio.create_task(delayed_complete()) - return - await self.push_frame(frame, direction) - - class InterruptOnTextProcessor(FrameProcessor): - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - if isinstance(frame, TextFrame): - await self.push_interruption_task_frame_and_wait(timeout=0.5) - await self.push_frame(OutputTransportMessageUrgentFrame(message="done")) - else: - await self.push_frame(frame, direction) - - pipeline = Pipeline([BlockInterruptionProcessor(), InterruptOnTextProcessor()]) - - frames_to_send = [ - TextFrame(text="trigger"), - ] - expected_down_frames = [ - OutputTransportMessageUrgentFrame, - ] - await run_test( - pipeline, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - finally: - logger.remove(handler_id) - - self.assertTrue( - any("InterruptionFrame has not completed" in w for w in warnings), - "Expected a timeout warning about InterruptionFrame not completing", - ) + self.assertTrue(code_after_ran, "Code after broadcast_interruption() should execute") if __name__ == "__main__": diff --git a/tests/test_stt_mute_filter.py b/tests/test_stt_mute_filter.py index adf4611df..8f55bdecb 100644 --- a/tests/test_stt_mute_filter.py +++ b/tests/test_stt_mute_filter.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import asyncio import unittest from pipecat.frames.frames import ( @@ -329,17 +328,13 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase): expected_down_frames=expected_returned_frames, ) - async def test_interruption_frame_completed_when_muted(self): - """Test that InterruptionFrame.complete() is called when the frame is - suppressed due to muting, so push_interruption_task_frame_and_wait() - doesn't hang.""" + async def test_interruption_frame_suppressed_when_muted(self): + """Test that InterruptionFrame is suppressed when the filter is muted.""" filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS})) - event = asyncio.Event() - frames_to_send = [ BotStartedSpeakingFrame(), - InterruptionFrame(event=event), + InterruptionFrame(), BotStoppedSpeakingFrame(), ] @@ -354,8 +349,6 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase): expected_down_frames=expected_returned_frames, ) - self.assertTrue(event.is_set(), "InterruptionFrame.complete() should be called when muted") - if __name__ == "__main__": unittest.main()