FrameProcessor: add push_interruption_task_frame_and_wait()

This commit is contained in:
Aleix Conchillo Flaqué
2025-09-08 19:57:44 -07:00
parent 8249b014f0
commit 0b21f8a6bd
13 changed files with 103 additions and 48 deletions

View File

@@ -9,6 +9,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added `FrameProcessor.push_interruption_task_frame_and_wait()`. Use this
method to programmatically interrupt the bot from any part of the
pipeline. This guarantees that all the processors in the pipeline are
interrupted in order (from upstream to downstream). Internally, this works by
first pushing an `InterruptionTaskFrame` upstream until it reaches the
pipeline task. The pipeline task then generates an `InterruptionFrame`, which
flows downstream through all processors. Once the `InterruptionFrame` has
reached the processor waiting for the interruption, the function returns and
execution continues after the call. Think of it as sending an upstream request
for interruption and waiting until the acknowledgment flows back downstream.
- Added new base `TaskFrame` (which is a system frame). This is the base class
for all task frames (`EndTaskFrame`, `CancelTaskFrame`, etc.) that are meant
to be pushed upstream to reach the pipeline task.

View File

@@ -23,7 +23,6 @@ from loguru import logger
from pipecat.frames.frames import (
EndFrame,
Frame,
InterruptionTaskFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
@@ -360,7 +359,7 @@ class ClassificationProcessor(FrameProcessor):
await self._voicemail_notifier.notify() # Clear buffered TTS frames
# Interrupt the current pipeline to stop any ongoing processing
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
await self.push_interruption_task_frame_and_wait()
# Set the voicemail event to trigger the voicemail handler
self._voicemail_event.clear()

View File

@@ -640,9 +640,12 @@ class PipelineTask(BasePipelineTask):
logger.debug(f"{self}: received stop task frame {frame}")
await self.queue_frame(StopFrame())
elif isinstance(frame, InterruptionTaskFrame):
# Tell the task we should interrupt the pipeline.
# Tell the task we should interrupt the pipeline. Note that we are
# bypassing the push queue and directly queue into the
# pipeline. This is in case the push task is blocked waiting for a
# pipeline-ending frame to finish traversing the pipeline.
logger.debug(f"{self}: received interruption task frame {frame}")
await self.queue_frame(InterruptionFrame())
await self._pipeline.queue_frame(InterruptionFrame())
elif isinstance(frame, ErrorFrame):
if frame.fatal:
logger.error(f"A fatal error occurred: {frame}")

View File

@@ -20,7 +20,6 @@ from pipecat.frames.frames import (
EndFrame,
Frame,
InputDTMFFrame,
InterruptionTaskFrame,
StartFrame,
TranscriptionFrame,
)
@@ -105,7 +104,7 @@ class DTMFAggregator(FrameProcessor):
# For first digit, schedule interruption.
if is_first_digit:
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
await self.push_interruption_task_frame_and_wait()
# Check for immediate flush conditions
if frame.button == self._termination_digit:

View File

@@ -35,7 +35,7 @@ from pipecat.frames.frames import (
FunctionCallsStartedFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
InterruptionTaskFrame,
InterruptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
@@ -531,9 +531,9 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
if should_interrupt:
logger.debug(
"Interruption conditions met - pushing InterruptionTaskFrame and aggregation"
"Interruption conditions met - pushing interruption and aggregation"
)
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
await self.push_interruption_task_frame_and_wait()
await self._process_aggregation()
else:
logger.debug("Interruption conditions not met - not pushing aggregation")

View File

@@ -35,7 +35,6 @@ from pipecat.frames.frames import (
FunctionCallsStartedFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
InterruptionTaskFrame,
LLMContextAssistantTimestampFrame,
LLMContextFrame,
LLMFullResponseEndFrame,
@@ -309,9 +308,9 @@ class LLMUserAggregator(LLMContextAggregator):
if should_interrupt:
logger.debug(
"Interruption conditions met - pushing InterruptionTaskFrame and aggregation"
"Interruption conditions met - pushing interruption and aggregation"
)
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
await self.push_interruption_task_frame_and_wait()
await self._process_aggregation()
else:
logger.debug("Interruption conditions not met - not pushing aggregation")

View File

@@ -29,6 +29,7 @@ from pipecat.frames.frames import (
FrameProcessorResumeFrame,
FrameProcessorResumeUrgentFrame,
InterruptionFrame,
InterruptionTaskFrame,
StartFrame,
SystemFrame,
)
@@ -219,6 +220,9 @@ class FrameProcessor(BaseObject):
self.__process_event: Optional[asyncio.Event] = None
self.__process_frame_task: Optional[asyncio.Task] = None
self._wait_for_interruption = False
self._wait_interruption_event = asyncio.Event()
@property
def id(self) -> int:
"""Get the unique identifier for this processor.
@@ -542,6 +546,14 @@ class FrameProcessor(BaseObject):
if self._cancelling:
return
# If we are waiting for an interruption we will bypass all queued system
# frames and we will process the frame right away. This is because a
# previous system frame might be waiting for the interruption frame and
# it's blocking the input task.
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
await self.__process_frame(frame, direction, callback)
return
if self._enable_direct_mode:
await self.__process_frame(frame, direction, callback)
else:
@@ -620,6 +632,32 @@ class FrameProcessor(BaseObject):
await self.__internal_push_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
self._wait_interruption_event.set()
async def push_interruption_task_frame_and_wait(self):
    """Push an interruption task frame upstream and wait for the interruption.

    This function sends an `InterruptionTaskFrame` upstream to the pipeline
    task and waits to receive the corresponding `InterruptionFrame`. When
    the function finishes it is guaranteed that the `InterruptionFrame` has
    been pushed downstream.

    The waiting state is always reset on exit, including when pushing the
    frame raises or when the caller is cancelled while waiting.
    """
    # `push_frame()` sets the event for every `InterruptionFrame` it pushes,
    # even when nobody is waiting. Drop any stale signal so we only wake up
    # for an interruption that arrives after this request.
    self._wait_interruption_event.clear()

    self._wait_for_interruption = True
    try:
        await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)

        # Wait for an `InterruptionFrame` to come to this processor and be
        # pushed. Take a look at `push_frame()` to see how we first push the
        # `InterruptionFrame` and then we set the event in order to maintain
        # frame ordering.
        await self._wait_interruption_event.wait()
    finally:
        # Clean the event and leave waiting mode even on error/cancellation,
        # otherwise later interruptions would keep bypassing the input queue
        # and would no longer cancel the process task.
        self._wait_interruption_event.clear()
        self._wait_for_interruption = False
async def __start(self, frame: StartFrame):
"""Handle the start frame to initialize processor state.
@@ -669,20 +707,22 @@ class FrameProcessor(BaseObject):
async def _start_interruption(self):
"""Start handling an interruption by cancelling current tasks."""
try:
# Cancel the process task. This will stop processing queued frames.
await self.__cancel_process_task()
if self._wait_for_interruption:
# If we get here we know the process task was just waiting for
# an interruption (push_interruption_task_frame_and_wait()), so
# we can't cancel the task because it might still need to do
# more things (e.g. pushing a frame after the
# interruption). Instead we just drain the queue because this is
# an interruption.
self.__reset_process_task()
else:
# Cancel and re-create the process task including the queue.
await self.__cancel_process_task()
self.__create_process_task()
except Exception as e:
logger.exception(f"Uncaught exception in {self} when handling _start_interruption: {e}")
await self.push_error(ErrorFrame(str(e)))
# Create a new process queue and task.
self.__create_process_task()
async def _stop_interruption(self):
"""Stop handling an interruption."""
# Nothing to do right now.
pass
async def __internal_push_frame(self, frame: Frame, direction: FrameDirection):
"""Internal method to push frames to adjacent processors.
@@ -764,6 +804,17 @@ class FrameProcessor(BaseObject):
self.__process_queue = asyncio.Queue()
self.__process_frame_task = self.create_task(self.__process_frame_task_handler())
def __reset_process_task(self):
    """Reset the non-system frame processing state, keeping the task alive.

    Does nothing when direct mode is enabled. Otherwise it clears the
    frame-blocking flag, installs a fresh process event and discards every
    frame that is still pending in the process queue.
    """
    if self._enable_direct_mode:
        return

    self.__should_block_frames = False
    self.__process_event = asyncio.Event()

    # Throw away whatever is still queued, marking each item as done so the
    # queue's unfinished-task accounting stays balanced.
    pending = self.__process_queue
    while True:
        try:
            pending.get_nowait()
        except asyncio.QueueEmpty:
            break
        pending.task_done()
async def __cancel_process_task(self):
"""Cancel the non-system frame processing task."""
if self.__process_frame_task:

View File

@@ -41,7 +41,6 @@ from pipecat.frames.frames import (
FunctionCallResultFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
InterruptionTaskFrame,
LLMContextFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
@@ -1206,7 +1205,7 @@ class RTVIProcessor(FrameProcessor):
async def interrupt_bot(self):
"""Send a bot interruption frame upstream."""
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
await self.push_interruption_task_frame_and_wait()
async def send_server_message(self, data: Any):
"""Send a server message to the client."""

View File

@@ -716,14 +716,12 @@ class OpenAIRealtimeLLMService(LLMService):
async def _handle_evt_speech_started(self, evt):
await self._truncate_current_audio_response()
await self._start_interruption() # cancels this processor task
await self.push_frame(InterruptionFrame()) # cancels downstream tasks
await self.push_interruption_task_frame_and_wait()
await self.push_frame(UserStartedSpeakingFrame())
async def _handle_evt_speech_stopped(self, evt):
await self.start_ttfb_metrics()
await self.start_processing_metrics()
await self._stop_interruption()
await self.push_frame(UserStoppedSpeakingFrame())
async def _maybe_handle_evt_retrieve_conversation_item_error(self, evt: events.ErrorEvent):

View File

@@ -658,14 +658,12 @@ class OpenAIRealtimeBetaLLMService(LLMService):
async def _handle_evt_speech_started(self, evt):
await self._truncate_current_audio_response()
await self._start_interruption() # cancels this processor task
await self.push_frame(InterruptionFrame()) # cancels downstream tasks
await self.push_interruption_task_frame_and_wait()
await self.push_frame(UserStartedSpeakingFrame())
async def _handle_evt_speech_stopped(self, evt):
await self.start_ttfb_metrics()
await self.start_processing_metrics()
await self._stop_interruption()
await self.push_frame(UserStoppedSpeakingFrame())
async def _maybe_handle_evt_retrieve_conversation_item_error(self, evt: events.ErrorEvent):

View File

@@ -24,7 +24,6 @@ from pipecat.frames.frames import (
ErrorFrame,
Frame,
InterimTranscriptionFrame,
InterruptionTaskFrame,
StartFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
@@ -749,14 +748,13 @@ class SpeechmaticsSTTService(STTService):
return
# Frames to send
upstream_frames: list[Frame] = []
downstream_frames: list[Frame] = []
# If VAD is enabled, then send a speaking frame
if self._params.enable_vad and not self._is_speaking:
logger.debug("User started speaking")
self._is_speaking = True
upstream_frames += [InterruptionTaskFrame()]
await self.push_interruption_task_frame_and_wait()
downstream_frames += [UserStartedSpeakingFrame()]
# If final, then re-parse into TranscriptionFrame
@@ -794,10 +792,6 @@ class SpeechmaticsSTTService(STTService):
self._is_speaking = False
downstream_frames += [UserStoppedSpeakingFrame()]
# Send UPSTREAM frames
for frame in upstream_frames:
await self.push_frame(frame, FrameDirection.UPSTREAM)
# Send the DOWNSTREAM frames
for frame in downstream_frames:
await self.push_frame(frame, FrameDirection.DOWNSTREAM)

View File

@@ -32,8 +32,6 @@ from pipecat.frames.frames import (
Frame,
InputAudioRawFrame,
InputImageRawFrame,
InterruptionFrame,
InterruptionTaskFrame,
MetricsFrame,
SpeechControlParamsFrame,
StartFrame,
@@ -353,11 +351,7 @@ class BaseInputTransport(FrameProcessor):
# Make sure we notify about interruptions quickly out-of-band.
if should_push_immediate_interruption and self.interruptions_allowed:
await self._start_interruption()
# Push an out-of-band frame (i.e. not using the ordered push
# frame task) to stop everything, specially at the output
# transport.
await self.push_frame(InterruptionFrame())
await self.push_interruption_task_frame_and_wait()
elif self.interruption_strategies and self._bot_speaking:
logger.debug(
"User started speaking while bot is speaking with interruption config - "
@@ -372,9 +366,6 @@ class BaseInputTransport(FrameProcessor):
await self.push_frame(downstream_frame)
await self.push_frame(upstream_frame, FrameDirection.UPSTREAM)
if self.interruptions_allowed:
await self._stop_interruption()
#
# Handle bot speaking state
#

View File

@@ -10,6 +10,7 @@ from pipecat.audio.dtmf.types import KeypadEntry
from pipecat.frames.frames import (
EndFrame,
InputDTMFFrame,
InterruptionFrame,
TranscriptionFrame,
)
from pipecat.processors.aggregators.dtmf_aggregator import DTMFAggregator
@@ -28,6 +29,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
]
expected_down_frames = [
InputDTMFFrame,
InterruptionFrame,
InputDTMFFrame,
InputDTMFFrame,
InputDTMFFrame,
@@ -59,9 +61,11 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
]
expected_down_frames = [
InputDTMFFrame,
InterruptionFrame,
InputDTMFFrame,
TranscriptionFrame, # First aggregation "12"
InputDTMFFrame,
InterruptionFrame,
TranscriptionFrame, # Second aggregation "3"
]
@@ -93,10 +97,12 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
]
expected_down_frames = [
InputDTMFFrame,
InterruptionFrame,
InputDTMFFrame,
InputDTMFFrame,
TranscriptionFrame, # "12#"
InputDTMFFrame,
InterruptionFrame,
InputDTMFFrame,
TranscriptionFrame, # "45"
]
@@ -125,6 +131,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
]
expected_down_frames = [
InputDTMFFrame,
InterruptionFrame,
InputDTMFFrame,
TranscriptionFrame, # Should flush before EndFrame
EndFrame,
@@ -152,6 +159,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
]
expected_down_frames = [
InputDTMFFrame,
InterruptionFrame,
InputDTMFFrame,
TranscriptionFrame,
]
@@ -178,6 +186,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
]
expected_down_frames = [
InputDTMFFrame,
InterruptionFrame,
InputDTMFFrame,
InputDTMFFrame,
TranscriptionFrame,
@@ -214,7 +223,11 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
]
# All the InputDTMFFrames plus one TranscriptionFrame
expected_down_frames = [InputDTMFFrame] * len(frames_to_send) + [TranscriptionFrame]
expected_down_frames = (
[InputDTMFFrame, InterruptionFrame]
+ [InputDTMFFrame] * (len(frames_to_send) - 1)
+ [TranscriptionFrame]
)
received_down_frames, _ = await run_test(
aggregator,