Merge pull request #3660 from pipecat-ai/aleix/interruption-frame-completion-event
Attach asyncio.Event to InterruptionFrame for completion signaling
This commit is contained in:
@@ -68,7 +68,7 @@ Transport Input → Pipeline Source → [Processor1] → [Processor2] → ...
|
||||
|
||||
- **User turn strategies**: Detection of when the user starts and stops speaking is done via user turn start/stop strategies. They push `UserStartedSpeakingFrame` and `UserStoppedSpeakingFrame` respectively.
|
||||
|
||||
- **Interruptions**: Interruptions are usually triggered by a user turn start strategy (e.g. `VADUserTurnStartStrategy`) but they can be triggered by other processors as well, in which case the user turn start strategies don't need to.
|
||||
- **Interruptions**: Interruptions are usually triggered by a user turn start strategy (e.g. `VADUserTurnStartStrategy`) but they can be triggered by other processors as well, in which case the user turn start strategies don't need to. An `InterruptionFrame` carries an optional `asyncio.Event` that is set when the frame reaches the pipeline sink. If a processor stops an `InterruptionFrame` from propagating downstream (i.e., doesn't push it), it **must** call `frame.complete()` to avoid stalling `push_interruption_task_frame_and_wait()` callers.
|
||||
|
||||
- **Uninterruptible Frames**: These are frames that will not be removed from internal queues even if there's an interruption. For example, `EndFrame` and `StopFrame`.
|
||||
|
||||
|
||||
1
changelog/3660.changed.md
Normal file
1
changelog/3660.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Moved interruption wait event from per-processor instance state to `InterruptionFrame` itself. Added `InterruptionFrame.complete()` to signal when the interruption has fully traversed the pipeline. Custom processors that block or consume an `InterruptionFrame` before it reaches the pipeline sink must call `frame.complete()` to avoid stalling `push_interruption_task_frame_and_wait()`. A warning is logged if completion does not happen within 2 seconds.
|
||||
@@ -11,6 +11,7 @@ including data frames, system frames, and control frames for audio, video, text,
|
||||
and LLM processing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
@@ -1118,15 +1119,29 @@ class FrameProcessorResumeUrgentFrame(SystemFrame):
|
||||
|
||||
@dataclass
|
||||
class InterruptionFrame(SystemFrame):
|
||||
"""Frame indicating user started speaking (interruption detected).
|
||||
"""Frame pushed to interrupt the pipeline.
|
||||
|
||||
This frame is used to interrupt the pipeline. For example, when a user
|
||||
starts speaking to cancel any in-progress bot output. It can also be pushed
|
||||
by any processor.
|
||||
|
||||
Parameters:
|
||||
event: Optional event set when the frame has fully traversed the
|
||||
pipeline.
|
||||
|
||||
Emitted by the BaseInputTransport to indicate that a user has started
|
||||
speaking (i.e. is interrupting). This is similar to
|
||||
UserStartedSpeakingFrame except that it should be pushed concurrently
|
||||
with other frames (so the order is not guaranteed).
|
||||
"""
|
||||
|
||||
pass
|
||||
event: Optional[asyncio.Event] = None
|
||||
|
||||
def complete(self):
|
||||
"""Signal that this interruption has been fully processed.
|
||||
|
||||
Called automatically when the frame reaches the pipeline sink, or
|
||||
manually when the frame is consumed before reaching it (e.g. when
|
||||
the user is muted).
|
||||
"""
|
||||
if self.event:
|
||||
self.event.set()
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1706,15 +1721,19 @@ class StopTaskFrame(TaskFrame):
|
||||
|
||||
@dataclass
|
||||
class InterruptionTaskFrame(TaskFrame):
|
||||
"""Frame indicating the bot should be interrupted.
|
||||
"""Frame indicating the pipeline should be interrupted.
|
||||
|
||||
This frame should be pushed upstream to indicate the pipeline should be
|
||||
interrupted. The pipeline task converts this into an `InterruptionFrame` and
|
||||
sends it downstream. The `event` is passed to the `InterruptionFrame` so it
|
||||
can signal when the interruption has fully traversed the pipeline.
|
||||
|
||||
Parameters:
|
||||
event: Optional event passed to the corresponding `InterruptionFrame`.
|
||||
|
||||
Emitted when the bot should be interrupted. This will mainly cause the
|
||||
same actions as if the user interrupted except that the
|
||||
UserStartedSpeakingFrame and UserStoppedSpeakingFrame won't be generated.
|
||||
This frame should be pushed upstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
event: Optional[asyncio.Event] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -864,7 +864,7 @@ class PipelineTask(BasePipelineTask):
|
||||
# pipeline. This is in case the push task is blocked waiting for a
|
||||
# pipeline-ending frame to finish traversing the pipeline.
|
||||
logger.debug(f"{self}: received interruption task frame {frame}")
|
||||
await self._pipeline.queue_frame(InterruptionFrame())
|
||||
await self._pipeline.queue_frame(InterruptionFrame(event=frame.event))
|
||||
elif isinstance(frame, ErrorFrame):
|
||||
await self._call_event_handler("on_pipeline_error", frame)
|
||||
if frame.fatal:
|
||||
@@ -903,6 +903,8 @@ class PipelineTask(BasePipelineTask):
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, CancelFrame):
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
frame.complete()
|
||||
elif isinstance(frame, HeartbeatFrame):
|
||||
await self._heartbeat_queue.put(frame)
|
||||
|
||||
|
||||
@@ -552,6 +552,12 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
if should_mute_frame:
|
||||
logger.trace(f"{frame.name} suppressed - user currently muted")
|
||||
|
||||
# When muted, the InterruptionFrame won't propagate further and
|
||||
# will never reach the pipeline sink. Complete it here so
|
||||
# push_interruption_task_frame_and_wait() doesn't hang.
|
||||
if should_mute_frame and isinstance(frame, InterruptionFrame):
|
||||
frame.complete()
|
||||
|
||||
should_mute_next_time = False
|
||||
for s in self._params.user_mute_strategies:
|
||||
should_mute_next_time |= await s.process_frame(frame)
|
||||
|
||||
@@ -234,6 +234,12 @@ class STTMuteFilter(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
logger.trace(f"{frame.__class__.__name__} suppressed - STT currently muted")
|
||||
|
||||
# When muted, the InterruptionFrame won't propagate further
|
||||
# and will never reach the pipeline sink. Complete it here so
|
||||
# push_interruption_task_frame_and_wait() doesn't hang.
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
frame.complete()
|
||||
else:
|
||||
# Pass all other frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -52,6 +52,8 @@ from pipecat.processors.metrics.frame_processor_metrics import FrameProcessorMet
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
INTERRUPTION_COMPLETION_TIMEOUT = 2.0
|
||||
|
||||
|
||||
class FrameDirection(Enum):
|
||||
"""Direction of frame flow in the processing pipeline.
|
||||
@@ -240,13 +242,9 @@ class FrameProcessor(BaseObject):
|
||||
self.__process_frame_task: Optional[asyncio.Task] = None
|
||||
self.__process_current_frame: Optional[Frame] = None
|
||||
|
||||
# To interrupt a pipeline, we push an `InterruptionTaskFrame` upstream.
|
||||
# Then we wait for the corresponding `InterruptionFrame` to travel from
|
||||
# the start of the pipeline back to the processor that sent the
|
||||
# `InterruptionTaskFrame`. This wait is handled using the following
|
||||
# event.
|
||||
# Set while awaiting push_interruption_task_frame_and_wait() so that
|
||||
# _start_interruption() knows not to cancel the process task.
|
||||
self._wait_for_interruption = False
|
||||
self._wait_interruption_event = asyncio.Event()
|
||||
|
||||
# Frame processor events.
|
||||
self._register_event_handler("on_before_process_frame", sync=True)
|
||||
@@ -600,10 +598,11 @@ class FrameProcessor(BaseObject):
|
||||
if self._cancelling:
|
||||
return
|
||||
|
||||
# If we are waiting for an interruption we will bypass all queued system
|
||||
# frames and we will process the frame right away. This is because a
|
||||
# previous system frame might be waiting for the interruption frame and
|
||||
# it's blocking the input task.
|
||||
# If we are waiting for an interruption, bypass all queued system frames
|
||||
# and process the frame right away. This is because a previous system
|
||||
# frame might be waiting for the interruption frame blocking the input
|
||||
# task, so this InterruptionFrame would never be dequeued and we'd
|
||||
# deadlock.
|
||||
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
return
|
||||
@@ -742,31 +741,38 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
await self._call_event_handler("on_after_push_frame", frame)
|
||||
|
||||
# If we are waiting for an interruption and we get an interruption, then
|
||||
# we can unblock `push_interruption_task_frame_and_wait()`.
|
||||
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
|
||||
self._wait_interruption_event.set()
|
||||
|
||||
async def push_interruption_task_frame_and_wait(self):
|
||||
"""Push an interruption task frame upstream and wait for the interruption.
|
||||
|
||||
This function sends an `InterruptionTaskFrame` upstream to the pipeline
|
||||
task and waits to receive the corresponding `InterruptionFrame`. When
|
||||
the function finishes it is guaranteed that the `InterruptionFrame` has
|
||||
been pushed downstream.
|
||||
This function sends an `InterruptionTaskFrame` upstream to the
|
||||
pipeline task. The task creates a corresponding `InterruptionFrame`
|
||||
and sends it downstream through the pipeline. An `asyncio.Event` is
|
||||
attached to both frames so the caller can wait until the interruption
|
||||
has fully traversed the pipeline. The event is set when the
|
||||
`InterruptionFrame` reaches the pipeline sink. If the frame does
|
||||
not complete within `INTERRUPTION_COMPLETION_TIMEOUT` seconds, a
|
||||
warning is logged periodically until it completes.
|
||||
|
||||
"""
|
||||
self._wait_for_interruption = True
|
||||
|
||||
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
|
||||
event = asyncio.Event()
|
||||
|
||||
# Wait for an `InterruptionFrame` to come to this processor and be
|
||||
# pushed. Take a look at `push_frame()` to see how we first push the
|
||||
# `InterruptionFrame` and then we set the event in order to maintain
|
||||
# frame ordering.
|
||||
await self._wait_interruption_event.wait()
|
||||
await self.push_frame(InterruptionTaskFrame(event=event), FrameDirection.UPSTREAM)
|
||||
|
||||
# Clean the event.
|
||||
self._wait_interruption_event.clear()
|
||||
# Wait for the `InterruptionFrame` to complete and log a warning
|
||||
# periodically if it takes too long.
|
||||
while not event.is_set():
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout=INTERRUPTION_COMPLETION_TIMEOUT)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"{self}: InterruptionFrame has not completed after"
|
||||
f" {INTERRUPTION_COMPLETION_TIMEOUT}s. Make sure"
|
||||
" InterruptionFrame.complete() is being called (e.g. if the"
|
||||
" frame is being blocked or consumed before reaching the"
|
||||
" pipeline sink)."
|
||||
)
|
||||
|
||||
self._wait_for_interruption = False
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ import unittest
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
DataFrame,
|
||||
EndFrame,
|
||||
@@ -22,7 +24,11 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.filters.identity_filter import IdentityFilter
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.frame_processor import (
|
||||
INTERRUPTION_COMPLETION_TIMEOUT,
|
||||
FrameDirection,
|
||||
FrameProcessor,
|
||||
)
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
|
||||
|
||||
@@ -449,6 +455,109 @@ class TestFrameProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
stop_frames = [f for f in received_frames if isinstance(f, StopFrame)]
|
||||
self.assertEqual(len(stop_frames), 1, "StopFrame should survive interruption")
|
||||
|
||||
async def test_interruption_frame_complete_sets_event(self):
|
||||
"""Test that InterruptionFrame.complete() sets the event."""
|
||||
event = asyncio.Event()
|
||||
frame = InterruptionFrame(event=event)
|
||||
self.assertFalse(event.is_set())
|
||||
frame.complete()
|
||||
self.assertTrue(event.is_set())
|
||||
|
||||
async def test_interruption_frame_complete_without_event(self):
|
||||
"""Test that InterruptionFrame.complete() is safe without an event."""
|
||||
frame = InterruptionFrame()
|
||||
frame.complete() # Should not raise
|
||||
|
||||
async def test_interruption_event_set_at_pipeline_sink(self):
|
||||
"""Test that the event from push_interruption_task_frame_and_wait()
|
||||
is set when the InterruptionFrame reaches the pipeline sink."""
|
||||
event_was_set = False
|
||||
|
||||
class InterruptOnTextProcessor(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
nonlocal event_was_set
|
||||
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
event_was_set = True
|
||||
await self.push_frame(OutputTransportMessageUrgentFrame(message="done"))
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
pipeline = Pipeline([InterruptOnTextProcessor()])
|
||||
|
||||
frames_to_send = [
|
||||
TextFrame(text="trigger"),
|
||||
]
|
||||
expected_down_frames = [
|
||||
InterruptionFrame,
|
||||
OutputTransportMessageUrgentFrame,
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.assertTrue(event_was_set, "Event should be set after InterruptionFrame completes")
|
||||
|
||||
async def test_interruption_completion_timeout_warning(self):
|
||||
"""Test that a warning is logged when an InterruptionFrame is blocked
|
||||
and never reaches the pipeline sink."""
|
||||
warnings = []
|
||||
handler_id = logger.add(
|
||||
lambda msg: warnings.append(str(msg)), level="WARNING", format="{message}"
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
class BlockInterruptionProcessor(FrameProcessor):
|
||||
"""Blocks InterruptionFrames, completing them after a delay."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
# Complete after the timeout so the warning fires
|
||||
# but the test doesn't hang.
|
||||
async def delayed_complete():
|
||||
await asyncio.sleep(INTERRUPTION_COMPLETION_TIMEOUT + 1.0)
|
||||
frame.complete()
|
||||
|
||||
asyncio.create_task(delayed_complete())
|
||||
return
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
class InterruptOnTextProcessor(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.push_frame(OutputTransportMessageUrgentFrame(message="done"))
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
pipeline = Pipeline([BlockInterruptionProcessor(), InterruptOnTextProcessor()])
|
||||
|
||||
frames_to_send = [
|
||||
TextFrame(text="trigger"),
|
||||
]
|
||||
expected_down_frames = [
|
||||
OutputTransportMessageUrgentFrame,
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
finally:
|
||||
logger.remove(handler_id)
|
||||
|
||||
self.assertTrue(
|
||||
any("InterruptionFrame has not completed" in w for w in warnings),
|
||||
"Expected a timeout warning about InterruptionFrame not completing",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -14,6 +15,7 @@ from pipecat.frames.frames import (
|
||||
FunctionCallsStartedFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
@@ -327,6 +329,33 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
async def test_interruption_frame_completed_when_muted(self):
|
||||
"""Test that InterruptionFrame.complete() is called when the frame is
|
||||
suppressed due to muting, so push_interruption_task_frame_and_wait()
|
||||
doesn't hang."""
|
||||
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}))
|
||||
|
||||
event = asyncio.Event()
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
InterruptionFrame(event=event),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
self.assertTrue(event.is_set(), "InterruptionFrame.complete() should be called when muted")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user