Add broadcast_interruption() to FrameProcessor
Replace the round-trip push_interruption_task_frame_and_wait() mechanism with broadcast_interruption(), which pushes an InterruptionFrame both upstream and downstream directly from the calling processor. This eliminates race conditions (transcription arriving before the InterruptionFrame comes back), swallowed-event timeouts (frame blocked before reaching the sink), and the complexity of _wait_for_interruption flag / queue bypass / frame.complete() obligations. - Add broadcast_interruption() to FrameProcessor - Deprecate push_interruption_task_frame_and_wait() (delegates to new method) - Remove event field and complete() from InterruptionFrame/InterruptionTaskFrame - Remove _wait_for_interruption flag and all special-case logic - Remove frame.complete() calls in stt_mute_filter and llm_response_universal - Update all 17 call sites to use broadcast_interruption() - Update tests
This commit is contained in:
1
changelog/3900.added.md
Normal file
1
changelog/3900.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `broadcast_interruption()` to `FrameProcessor`. This method pushes an `InterruptionFrame` both upstream and downstream directly from the calling processor, avoiding the round-trip through the pipeline task that `push_interruption_task_frame_and_wait()` required.
|
||||
1
changelog/3900.changed.md
Normal file
1
changelog/3900.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Removed `event` field and `complete()` method from `InterruptionFrame`. Removed `event` field from `InterruptionTaskFrame`. These are no longer needed since `broadcast_interruption()` does not require a round-trip completion signal.
|
||||
1
changelog/3900.deprecated.md
Normal file
1
changelog/3900.deprecated.md
Normal file
@@ -0,0 +1 @@
|
||||
- Deprecated `push_interruption_task_frame_and_wait()` in `FrameProcessor`. Use `broadcast_interruption()` instead. The old method now delegates to `broadcast_interruption()` and logs a deprecation warning.
|
||||
@@ -368,7 +368,7 @@ class ClassificationProcessor(FrameProcessor):
|
||||
await self._voicemail_notifier.notify() # Clear buffered TTS frames
|
||||
|
||||
# Interrupt the current pipeline to stop any ongoing processing
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
# Set the voicemail event to trigger the voicemail handler
|
||||
self._voicemail_event.clear()
|
||||
|
||||
@@ -11,7 +11,6 @@ including data frames, system frames, and control frames for audio, video, text,
|
||||
and LLM processing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
@@ -1141,24 +1140,9 @@ class InterruptionFrame(SystemFrame):
|
||||
This frame is used to interrupt the pipeline. For example, when a user
|
||||
starts speaking to cancel any in-progress bot output. It can also be pushed
|
||||
by any processor.
|
||||
|
||||
Parameters:
|
||||
event: Optional event set when the frame has fully traversed the
|
||||
pipeline.
|
||||
|
||||
"""
|
||||
|
||||
event: Optional[asyncio.Event] = None
|
||||
|
||||
def complete(self):
|
||||
"""Signal that this interruption has been fully processed.
|
||||
|
||||
Called automatically when the frame reaches the pipeline sink, or
|
||||
manually when the frame is consumed before reaching it (e.g. when
|
||||
the user is muted).
|
||||
"""
|
||||
if self.event:
|
||||
self.event.set()
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1825,16 +1809,11 @@ class InterruptionTaskFrame(TaskFrame):
|
||||
"""Frame indicating the pipeline should be interrupted.
|
||||
|
||||
This frame should be pushed upstream to indicate the pipeline should be
|
||||
interrupted. The pipeline task converts this into an `InterruptionFrame` and
|
||||
sends it downstream. The `event` is passed to the `InterruptionFrame` so it
|
||||
can signal when the interruption has fully traversed the pipeline.
|
||||
|
||||
Parameters:
|
||||
event: Optional event passed to the corresponding `InterruptionFrame`.
|
||||
|
||||
interrupted. The pipeline task converts this into an `InterruptionFrame`
|
||||
and sends it downstream.
|
||||
"""
|
||||
|
||||
event: Optional[asyncio.Event] = None
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -892,7 +892,7 @@ class PipelineTask(BasePipelineTask):
|
||||
# pipeline. This is in case the push task is blocked waiting for a
|
||||
# pipeline-ending frame to finish traversing the pipeline.
|
||||
logger.debug(f"{self}: received interruption task frame {frame}")
|
||||
await self._pipeline.queue_frame(InterruptionFrame(event=frame.event))
|
||||
await self._pipeline.queue_frame(InterruptionFrame())
|
||||
elif isinstance(frame, ErrorFrame):
|
||||
await self._call_event_handler("on_pipeline_error", frame)
|
||||
if frame.fatal:
|
||||
@@ -931,8 +931,6 @@ class PipelineTask(BasePipelineTask):
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, CancelFrame):
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
frame.complete()
|
||||
elif isinstance(frame, HeartbeatFrame):
|
||||
await self._heartbeat_queue.put(frame)
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ class DTMFAggregator(FrameProcessor):
|
||||
|
||||
# For first digit, schedule interruption.
|
||||
if is_first_digit:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
# Check for immediate flush conditions
|
||||
if frame.button == self._termination_digit:
|
||||
|
||||
@@ -581,7 +581,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
logger.debug(
|
||||
"Interruption conditions met - pushing interruption and aggregation"
|
||||
)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
await self._process_aggregation()
|
||||
else:
|
||||
logger.debug("Interruption conditions not met - not pushing aggregation")
|
||||
|
||||
@@ -608,12 +608,6 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
if should_mute_frame:
|
||||
logger.trace(f"{frame.name} suppressed - user currently muted")
|
||||
|
||||
# When muted, the InterruptionFrame won't propagate further and
|
||||
# will never reach the pipeline sink. Complete it here so
|
||||
# push_interruption_task_frame_and_wait() doesn't hang.
|
||||
if should_mute_frame and isinstance(frame, InterruptionFrame):
|
||||
frame.complete()
|
||||
|
||||
should_mute_next_time = False
|
||||
for s in self._params.user_mute_strategies:
|
||||
should_mute_next_time |= await s.process_frame(frame)
|
||||
@@ -737,7 +731,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
await self._user_idle_controller.process_frame(UserStartedSpeakingFrame())
|
||||
|
||||
if params.enable_interruptions and self._allow_interruptions:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
await self._call_event_handler("on_user_turn_started", strategy)
|
||||
|
||||
|
||||
@@ -234,12 +234,6 @@ class STTMuteFilter(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
logger.trace(f"{frame.__class__.__name__} suppressed - STT currently muted")
|
||||
|
||||
# When muted, the InterruptionFrame won't propagate further
|
||||
# and will never reach the pipeline sink. Complete it here so
|
||||
# push_interruption_task_frame_and_wait() doesn't hang.
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
frame.complete()
|
||||
else:
|
||||
# Pass all other frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -41,7 +41,6 @@ from pipecat.frames.frames import (
|
||||
FrameProcessorResumeFrame,
|
||||
FrameProcessorResumeUrgentFrame,
|
||||
InterruptionFrame,
|
||||
InterruptionTaskFrame,
|
||||
StartFrame,
|
||||
SystemFrame,
|
||||
UninterruptibleFrame,
|
||||
@@ -240,10 +239,6 @@ class FrameProcessor(BaseObject):
|
||||
self.__process_frame_task: Optional[asyncio.Task] = None
|
||||
self.__process_current_frame: Optional[Frame] = None
|
||||
|
||||
# Set while awaiting push_interruption_task_frame_and_wait() so that
|
||||
# _start_interruption() knows not to cancel the process task.
|
||||
self._wait_for_interruption = False
|
||||
|
||||
# Frame processor events.
|
||||
self._register_event_handler("on_before_process_frame", sync=True)
|
||||
self._register_event_handler("on_after_process_frame", sync=True)
|
||||
@@ -329,7 +324,7 @@ class FrameProcessor(BaseObject):
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`FrameProcessor.interruptions_allowed` is deprecated. "
|
||||
"Use `LLMUserAggregator`'s new `user_mute_strategies` parameter instead.",
|
||||
"Use `LLMUserAggregator`'s new `user_mute_strategies` parameter instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -647,15 +642,6 @@ class FrameProcessor(BaseObject):
|
||||
if self._cancelling:
|
||||
return
|
||||
|
||||
# If we are waiting for an interruption, bypass all queued system frames
|
||||
# and process the frame right away. This is because a previous system
|
||||
# frame might be waiting for the interruption frame blocking the input
|
||||
# task, so this InterruptionFrame would never be dequeued and we'd
|
||||
# deadlock.
|
||||
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
return
|
||||
|
||||
if self._enable_direct_mode:
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
else:
|
||||
@@ -790,43 +776,32 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
await self._call_event_handler("on_after_push_frame", frame)
|
||||
|
||||
async def broadcast_interruption(self):
|
||||
"""Broadcast an `InterruptionFrame` both upstream and downstream."""
|
||||
logger.debug(f"{self}: broadcasting interruption")
|
||||
self.__reset_process_task()
|
||||
await self.stop_all_metrics()
|
||||
await self.broadcast_frame(InterruptionFrame)
|
||||
|
||||
async def push_interruption_task_frame_and_wait(self, *, timeout: float = 5.0):
|
||||
"""Push an interruption task frame upstream and wait for the interruption.
|
||||
|
||||
This function sends an `InterruptionTaskFrame` upstream to the
|
||||
pipeline task. The task creates a corresponding `InterruptionFrame`
|
||||
and sends it downstream through the pipeline. An `asyncio.Event` is
|
||||
attached to both frames so the caller can wait until the interruption
|
||||
has fully traversed the pipeline. The event is set when the
|
||||
`InterruptionFrame` reaches the pipeline sink. If the frame does
|
||||
not complete within the given timeout, a warning is logged and the
|
||||
event is forcibly set so the caller is unblocked.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait for the interruption to complete.
|
||||
.. deprecated:: 0.0.104
|
||||
Use :meth:`broadcast_interruption` instead. This method now
|
||||
delegates to ``broadcast_interruption()`` and ignores *timeout*.
|
||||
"""
|
||||
self._wait_for_interruption = True
|
||||
import warnings
|
||||
|
||||
event = asyncio.Event()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`FrameProcessor.push_interruption_task_frame_and_wait()` is deprecated. "
|
||||
"Use `FrameProcessor.broadcast_interruption()` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
await self.push_frame(InterruptionTaskFrame(event=event), FrameDirection.UPSTREAM)
|
||||
|
||||
# Wait for the `InterruptionFrame` to complete and log a warning if it
|
||||
# takes too long. If it does take too long make sure we unblock it,
|
||||
# otherwise we will hang here forever.
|
||||
while not event.is_set():
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"{self}: InterruptionFrame has not completed after"
|
||||
f" {timeout}s. Make sure InterruptionFrame.complete()"
|
||||
" is being called (e.g. if the frame is being blocked"
|
||||
" or consumed before reaching the pipeline sink)."
|
||||
)
|
||||
event.set()
|
||||
|
||||
self._wait_for_interruption = False
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def broadcast_frame(self, frame_cls: Type[Frame], **kwargs):
|
||||
"""Broadcasts a frame of the specified class upstream and downstream.
|
||||
@@ -933,15 +908,7 @@ class FrameProcessor(BaseObject):
|
||||
async def _start_interruption(self):
|
||||
"""Start handling an interruption by cancelling current tasks."""
|
||||
try:
|
||||
if self._wait_for_interruption:
|
||||
# If we get here we know the process task was just waiting for
|
||||
# an interruption (push_interruption_task_frame_and_wait()), so
|
||||
# we can't cancel the task because it might still need to do
|
||||
# more things (e.g. pushing a frame after the
|
||||
# interruption). Instead we just drain the queue because this is
|
||||
# an interruption.
|
||||
self.__reset_process_task()
|
||||
elif isinstance(self.__process_current_frame, UninterruptibleFrame):
|
||||
if isinstance(self.__process_current_frame, UninterruptibleFrame):
|
||||
# We don't want to cancel UninterruptibleFrame, so we simply
|
||||
# cleanup the queue.
|
||||
self.__reset_process_queue()
|
||||
|
||||
@@ -1702,7 +1702,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
|
||||
async def interrupt_bot(self):
|
||||
"""Send a bot interruption frame upstream."""
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def send_server_message(self, data: Any):
|
||||
"""Send a server message to the client."""
|
||||
|
||||
@@ -675,7 +675,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
self._user_is_speaking = True
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
if self._should_interrupt:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
await self.start_metrics()
|
||||
await self._call_event_handler("on_start_of_turn", transcript)
|
||||
if transcript:
|
||||
|
||||
@@ -471,7 +471,7 @@ class DeepgramSTTService(STTService):
|
||||
await self._call_event_handler("on_speech_started", *args, **kwargs)
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
if self._should_interrupt:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def _on_utterance_end(self, *args, **kwargs):
|
||||
await self._call_event_handler("on_utterance_end", *args, **kwargs)
|
||||
|
||||
@@ -613,7 +613,7 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
if self._should_interrupt:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def _on_speech_ended(self):
|
||||
"""Handle speech end event from Gladia.
|
||||
|
||||
@@ -1265,7 +1265,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
# combination with the context aggregator default
|
||||
# turn strategies.
|
||||
logger.debug("Gemini VAD: interrupted signal received")
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
elif message.server_content and message.server_content.model_turn:
|
||||
await self._handle_msg_model_turn(message)
|
||||
elif (
|
||||
|
||||
@@ -734,7 +734,7 @@ class GrokRealtimeLLMService(LLMService):
|
||||
"""Handle speech started event from VAD."""
|
||||
await self._truncate_current_audio_response()
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def _handle_evt_speech_stopped(self, evt):
|
||||
"""Handle speech stopped event from VAD."""
|
||||
|
||||
@@ -839,7 +839,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
async def _handle_evt_speech_started(self, evt):
|
||||
await self._truncate_current_audio_response()
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def _handle_evt_speech_stopped(self, evt):
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
@@ -639,7 +639,7 @@ class OpenAIRealtimeSTTService(WebsocketSTTService):
|
||||
logger.debug("Server VAD: speech started")
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
if self._should_interrupt:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def _handle_speech_stopped(self, evt: dict):
|
||||
|
||||
@@ -709,7 +709,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
async def _handle_evt_speech_started(self, evt):
|
||||
await self._truncate_current_audio_response()
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def _handle_evt_speech_stopped(self, evt):
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
@@ -644,7 +644,7 @@ class SarvamSTTService(STTService):
|
||||
logger.debug("User started speaking")
|
||||
await self._call_event_handler("on_speech_started")
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
elif signal == "END_SPEECH":
|
||||
logger.debug("User stopped speaking")
|
||||
|
||||
@@ -836,7 +836,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
# await self.start_processing_metrics()
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
if self._should_interrupt:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
async def _handle_end_of_turn(self, message: dict[str, Any]) -> None:
|
||||
"""Handle EndOfTurn events.
|
||||
|
||||
@@ -558,7 +558,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
|
||||
# Make sure we notify about interruptions quickly out-of-band.
|
||||
if should_push_immediate_interruption and self._allow_interruptions:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
elif self.interruption_strategies and self._bot_speaking:
|
||||
logger.debug(
|
||||
"User started speaking while bot is speaking with interruption config - "
|
||||
|
||||
@@ -182,7 +182,7 @@ class UserTurnProcessor(FrameProcessor):
|
||||
await self._user_idle_controller.process_frame(UserStartedSpeakingFrame())
|
||||
|
||||
if params.enable_interruptions and self._allow_interruptions:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
await self._call_event_handler("on_user_turn_started", strategy)
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ from pipecat.frames.frames import (
|
||||
FunctionCallResultProperties,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
InterruptionTaskFrame,
|
||||
LLMContextAssistantTimestampFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -567,7 +566,7 @@ class BaseTestUserContextAggregator:
|
||||
SleepFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_up_frames = [InterruptionTaskFrame]
|
||||
expected_up_frames = [InterruptionFrame]
|
||||
expected_down_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
|
||||
@@ -9,8 +9,6 @@ import unittest
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
DataFrame,
|
||||
EndFrame,
|
||||
@@ -85,50 +83,38 @@ class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
assert before_push_called
|
||||
assert after_push_called
|
||||
|
||||
async def test_interruption_and_wait(self):
|
||||
class DelayFrameProcessor(FrameProcessor):
|
||||
"""This processors just gives time to the event loop to change
|
||||
between tasks. Otherwise things happen to fast."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
await asyncio.sleep(0.1)
|
||||
await self.push_frame(frame, direction)
|
||||
async def test_broadcast_interruption(self):
|
||||
"""Test that broadcast_interruption() pushes InterruptionFrame both
|
||||
directions and allows subsequent code to run."""
|
||||
|
||||
class InterruptFrameProcessor(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
await self.push_frame(OutputTransportMessageUrgentFrame(message=frame.text))
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
pipeline = Pipeline([DelayFrameProcessor(), InterruptFrameProcessor()])
|
||||
pipeline = Pipeline([InterruptFrameProcessor()])
|
||||
|
||||
frames_to_send = [
|
||||
# Just a random interruption to make sure we don't clear anything
|
||||
# before the actual `InterruptionTaskFrame` interruption.
|
||||
InterruptionFrame(),
|
||||
# This will generate an `InterruptionTaskFrame` and will wait for an
|
||||
# `InterruptionFrame`.
|
||||
TextFrame(text="Hello from Pipecat!"),
|
||||
# Just give time for everything to complete.
|
||||
SleepFrame(sleep=0.5),
|
||||
EndFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
InterruptionFrame,
|
||||
InterruptionFrame,
|
||||
OutputTransportMessageUrgentFrame,
|
||||
EndFrame,
|
||||
]
|
||||
expected_up_frames = [
|
||||
InterruptionFrame,
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
send_end_frame=False,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
|
||||
async def test_interruptible_frames(self):
|
||||
@@ -454,33 +440,20 @@ class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
stop_frames = [f for f in received_frames if isinstance(f, StopFrame)]
|
||||
self.assertEqual(len(stop_frames), 1, "StopFrame should survive interruption")
|
||||
|
||||
async def test_interruption_frame_complete_sets_event(self):
|
||||
"""Test that InterruptionFrame.complete() sets the event."""
|
||||
event = asyncio.Event()
|
||||
frame = InterruptionFrame(event=event)
|
||||
self.assertFalse(event.is_set())
|
||||
frame.complete()
|
||||
self.assertTrue(event.is_set())
|
||||
|
||||
async def test_interruption_frame_complete_without_event(self):
|
||||
"""Test that InterruptionFrame.complete() is safe without an event."""
|
||||
frame = InterruptionFrame()
|
||||
frame.complete() # Should not raise
|
||||
|
||||
async def test_interruption_event_set_at_pipeline_sink(self):
|
||||
"""Test that the event from push_interruption_task_frame_and_wait()
|
||||
is set when the InterruptionFrame reaches the pipeline sink."""
|
||||
event_was_set = False
|
||||
async def test_broadcast_interruption_allows_subsequent_code(self):
|
||||
"""Test that broadcast_interruption() returns immediately, allowing the
|
||||
caller to run code afterwards (e.g. push an urgent frame)."""
|
||||
code_after_ran = False
|
||||
|
||||
class InterruptOnTextProcessor(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
nonlocal event_was_set
|
||||
nonlocal code_after_ran
|
||||
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_interruption()
|
||||
|
||||
event_was_set = True
|
||||
code_after_ran = True
|
||||
await self.push_frame(OutputTransportMessageUrgentFrame(message="done"))
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -499,63 +472,7 @@ class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.assertTrue(event_was_set, "Event should be set after InterruptionFrame completes")
|
||||
|
||||
async def test_interruption_completion_timeout_warning(self):
|
||||
"""Test that a warning is logged when an InterruptionFrame is blocked
|
||||
and never reaches the pipeline sink."""
|
||||
warnings = []
|
||||
handler_id = logger.add(
|
||||
lambda msg: warnings.append(str(msg)), level="WARNING", format="{message}"
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
class BlockInterruptionProcessor(FrameProcessor):
|
||||
"""Blocks InterruptionFrames, completing them after a delay."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
# Complete after the timeout so the warning fires
|
||||
# but the test doesn't hang.
|
||||
async def delayed_complete():
|
||||
await asyncio.sleep(1.0)
|
||||
frame.complete()
|
||||
|
||||
asyncio.create_task(delayed_complete())
|
||||
return
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
class InterruptOnTextProcessor(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_interruption_task_frame_and_wait(timeout=0.5)
|
||||
await self.push_frame(OutputTransportMessageUrgentFrame(message="done"))
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
pipeline = Pipeline([BlockInterruptionProcessor(), InterruptOnTextProcessor()])
|
||||
|
||||
frames_to_send = [
|
||||
TextFrame(text="trigger"),
|
||||
]
|
||||
expected_down_frames = [
|
||||
OutputTransportMessageUrgentFrame,
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
finally:
|
||||
logger.remove(handler_id)
|
||||
|
||||
self.assertTrue(
|
||||
any("InterruptionFrame has not completed" in w for w in warnings),
|
||||
"Expected a timeout warning about InterruptionFrame not completing",
|
||||
)
|
||||
self.assertTrue(code_after_ran, "Code after broadcast_interruption() should execute")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -329,17 +328,13 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
async def test_interruption_frame_completed_when_muted(self):
|
||||
"""Test that InterruptionFrame.complete() is called when the frame is
|
||||
suppressed due to muting, so push_interruption_task_frame_and_wait()
|
||||
doesn't hang."""
|
||||
async def test_interruption_frame_suppressed_when_muted(self):
|
||||
"""Test that InterruptionFrame is suppressed when the filter is muted."""
|
||||
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}))
|
||||
|
||||
event = asyncio.Event()
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
InterruptionFrame(event=event),
|
||||
InterruptionFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
@@ -354,8 +349,6 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
self.assertTrue(event.is_set(), "InterruptionFrame.complete() should be called when muted")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user