FrameProcessor: add push_interruption_task_frame_and_wait()
This commit is contained in:
11
CHANGELOG.md
11
CHANGELOG.md
@@ -9,6 +9,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added `FrameProcessor.push_interruption_task_frame_and_wait()`. Use this
|
||||
method to programatically interrupt the bot from any part of the
|
||||
pipeline. This guarantees that all the processors in the pipeline are
|
||||
interrupted in order (from upstream to downstream). Internally, this works by
|
||||
first pushing an `InterruptionTaskFrame` upstream until it reaches the
|
||||
pipeline task. The pipeline task then generates an `InterruptionFrame`, which
|
||||
flows downstream through all processors. Once the `InterruptionFrame` has
|
||||
reaches the processor waiting for the interruption, the function returns and
|
||||
execution continues after the call. Think of it as sending an upstream request
|
||||
for interruption and waiting until the acknowledgment flows back downstream.
|
||||
|
||||
- Added new base `TaskFrame` (which is a system frame). This is the base class
|
||||
for all task frames (`EndTaskFrame`, `CancelTaskFrame`, etc.) that are meant
|
||||
to be pushed upstream to reach the pipeline task.
|
||||
|
||||
@@ -23,7 +23,6 @@ from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterruptionTaskFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
@@ -360,7 +359,7 @@ class ClassificationProcessor(FrameProcessor):
|
||||
await self._voicemail_notifier.notify() # Clear buffered TTS frames
|
||||
|
||||
# Interrupt the current pipeline to stop any ongoing processing
|
||||
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
# Set the voicemail event to trigger the voicemail handler
|
||||
self._voicemail_event.clear()
|
||||
|
||||
@@ -640,9 +640,12 @@ class PipelineTask(BasePipelineTask):
|
||||
logger.debug(f"{self}: received stop task frame {frame}")
|
||||
await self.queue_frame(StopFrame())
|
||||
elif isinstance(frame, InterruptionTaskFrame):
|
||||
# Tell the task we should interrupt the pipeline.
|
||||
# Tell the task we should interrupt the pipeline. Note that we are
|
||||
# bypassing the push queue and directly queue into the
|
||||
# pipeline. This is in case the push task is blocked waiting for a
|
||||
# pipeline-ending frame to finish traversing the pipeline.
|
||||
logger.debug(f"{self}: received interruption task frame {frame}")
|
||||
await self.queue_frame(InterruptionFrame())
|
||||
await self._pipeline.queue_frame(InterruptionFrame())
|
||||
elif isinstance(frame, ErrorFrame):
|
||||
if frame.fatal:
|
||||
logger.error(f"A fatal error occurred: {frame}")
|
||||
|
||||
@@ -20,7 +20,6 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputDTMFFrame,
|
||||
InterruptionTaskFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
@@ -105,7 +104,7 @@ class DTMFAggregator(FrameProcessor):
|
||||
|
||||
# For first digit, schedule interruption.
|
||||
if is_first_digit:
|
||||
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
# Check for immediate flush conditions
|
||||
if frame.button == self._termination_digit:
|
||||
|
||||
@@ -35,7 +35,7 @@ from pipecat.frames.frames import (
|
||||
FunctionCallsStartedFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionTaskFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
@@ -531,9 +531,9 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
|
||||
if should_interrupt:
|
||||
logger.debug(
|
||||
"Interruption conditions met - pushing InterruptionTaskFrame and aggregation"
|
||||
"Interruption conditions met - pushing interruption and aggregation"
|
||||
)
|
||||
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self._process_aggregation()
|
||||
else:
|
||||
logger.debug("Interruption conditions not met - not pushing aggregation")
|
||||
|
||||
@@ -35,7 +35,6 @@ from pipecat.frames.frames import (
|
||||
FunctionCallsStartedFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionTaskFrame,
|
||||
LLMContextAssistantTimestampFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -309,9 +308,9 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
|
||||
if should_interrupt:
|
||||
logger.debug(
|
||||
"Interruption conditions met - pushing InterruptionTaskFrame and aggregation"
|
||||
"Interruption conditions met - pushing interruption and aggregation"
|
||||
)
|
||||
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self._process_aggregation()
|
||||
else:
|
||||
logger.debug("Interruption conditions not met - not pushing aggregation")
|
||||
|
||||
@@ -29,6 +29,7 @@ from pipecat.frames.frames import (
|
||||
FrameProcessorResumeFrame,
|
||||
FrameProcessorResumeUrgentFrame,
|
||||
InterruptionFrame,
|
||||
InterruptionTaskFrame,
|
||||
StartFrame,
|
||||
SystemFrame,
|
||||
)
|
||||
@@ -219,6 +220,9 @@ class FrameProcessor(BaseObject):
|
||||
self.__process_event: Optional[asyncio.Event] = None
|
||||
self.__process_frame_task: Optional[asyncio.Task] = None
|
||||
|
||||
self._wait_for_interruption = False
|
||||
self._wait_interruption_event = asyncio.Event()
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
"""Get the unique identifier for this processor.
|
||||
@@ -542,6 +546,14 @@ class FrameProcessor(BaseObject):
|
||||
if self._cancelling:
|
||||
return
|
||||
|
||||
# If we are waiting for an interruption we will bypass all queued system
|
||||
# frames and we will process the frame right away. This is because a
|
||||
# previous system frame might be waiting for the interruption frame and
|
||||
# it's blocking the input task.
|
||||
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
return
|
||||
|
||||
if self._enable_direct_mode:
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
else:
|
||||
@@ -620,6 +632,32 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
await self.__internal_push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
self._wait_interruption_event.set()
|
||||
|
||||
async def push_interruption_task_frame_and_wait(self):
|
||||
"""Push an interruption task frame upstream and wait for the interruption.
|
||||
|
||||
This function sends an `InterruptionTaskFrame` upstream to the pipeline
|
||||
task and waits to receive the corresponding `InterruptionFrame`. When
|
||||
the function finishes it is guaranteed that the `InterruptionFrame` has
|
||||
been pushed downstream.
|
||||
"""
|
||||
self._wait_for_interruption = True
|
||||
|
||||
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
# Wait for an `InterruptionFrame` to come to this processor and be
|
||||
# pushed. Take a look at `push_frame()` to see how we first push the
|
||||
# `InterruptionFrame` and then we set the event in order to maintain
|
||||
# frame ordering.
|
||||
await self._wait_interruption_event.wait()
|
||||
|
||||
# Clean the event.
|
||||
self._wait_interruption_event.clear()
|
||||
|
||||
self._wait_for_interruption = False
|
||||
|
||||
async def __start(self, frame: StartFrame):
|
||||
"""Handle the start frame to initialize processor state.
|
||||
|
||||
@@ -669,20 +707,22 @@ class FrameProcessor(BaseObject):
|
||||
async def _start_interruption(self):
|
||||
"""Start handling an interruption by cancelling current tasks."""
|
||||
try:
|
||||
# Cancel the process task. This will stop processing queued frames.
|
||||
await self.__cancel_process_task()
|
||||
if self._wait_for_interruption:
|
||||
# If we get here we know the process task was just waiting for
|
||||
# an interruption (push_interruption_task_frame_and_wait()), so
|
||||
# we can't cancel the task because it might still need to do
|
||||
# more things (e.g. pushing a frame after the
|
||||
# interruption). Instead we just drain the queue because this is
|
||||
# an interruption.
|
||||
self.__reset_process_task()
|
||||
else:
|
||||
# Cancel and re-create the process task including the queue.
|
||||
await self.__cancel_process_task()
|
||||
self.__create_process_task()
|
||||
except Exception as e:
|
||||
logger.exception(f"Uncaught exception in {self} when handling _start_interruption: {e}")
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
|
||||
# Create a new process queue and task.
|
||||
self.__create_process_task()
|
||||
|
||||
async def _stop_interruption(self):
|
||||
"""Stop handling an interruption."""
|
||||
# Nothing to do right now.
|
||||
pass
|
||||
|
||||
async def __internal_push_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Internal method to push frames to adjacent processors.
|
||||
|
||||
@@ -764,6 +804,17 @@ class FrameProcessor(BaseObject):
|
||||
self.__process_queue = asyncio.Queue()
|
||||
self.__process_frame_task = self.create_task(self.__process_frame_task_handler())
|
||||
|
||||
def __reset_process_task(self):
|
||||
"""Reset non-system frame processing task."""
|
||||
if self._enable_direct_mode:
|
||||
return
|
||||
|
||||
self.__should_block_frames = False
|
||||
self.__process_event = asyncio.Event()
|
||||
while not self.__process_queue.empty():
|
||||
self.__process_queue.get_nowait()
|
||||
self.__process_queue.task_done()
|
||||
|
||||
async def __cancel_process_task(self):
|
||||
"""Cancel the non-system frame processing task."""
|
||||
if self.__process_frame_task:
|
||||
|
||||
@@ -41,7 +41,6 @@ from pipecat.frames.frames import (
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionTaskFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
@@ -1206,7 +1205,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
|
||||
async def interrupt_bot(self):
|
||||
"""Send a bot interruption frame upstream."""
|
||||
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
async def send_server_message(self, data: Any):
|
||||
"""Send a server message to the client."""
|
||||
|
||||
@@ -716,14 +716,12 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
|
||||
async def _handle_evt_speech_started(self, evt):
|
||||
await self._truncate_current_audio_response()
|
||||
await self._start_interruption() # cancels this processor task
|
||||
await self.push_frame(InterruptionFrame()) # cancels downstream tasks
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.push_frame(UserStartedSpeakingFrame())
|
||||
|
||||
async def _handle_evt_speech_stopped(self, evt):
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
await self._stop_interruption()
|
||||
await self.push_frame(UserStoppedSpeakingFrame())
|
||||
|
||||
async def _maybe_handle_evt_retrieve_conversation_item_error(self, evt: events.ErrorEvent):
|
||||
|
||||
@@ -658,14 +658,12 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
|
||||
async def _handle_evt_speech_started(self, evt):
|
||||
await self._truncate_current_audio_response()
|
||||
await self._start_interruption() # cancels this processor task
|
||||
await self.push_frame(InterruptionFrame()) # cancels downstream tasks
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.push_frame(UserStartedSpeakingFrame())
|
||||
|
||||
async def _handle_evt_speech_stopped(self, evt):
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
await self._stop_interruption()
|
||||
await self.push_frame(UserStoppedSpeakingFrame())
|
||||
|
||||
async def _maybe_handle_evt_retrieve_conversation_item_error(self, evt: events.ErrorEvent):
|
||||
|
||||
@@ -24,7 +24,6 @@ from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionTaskFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
@@ -749,14 +748,13 @@ class SpeechmaticsSTTService(STTService):
|
||||
return
|
||||
|
||||
# Frames to send
|
||||
upstream_frames: list[Frame] = []
|
||||
downstream_frames: list[Frame] = []
|
||||
|
||||
# If VAD is enabled, then send a speaking frame
|
||||
if self._params.enable_vad and not self._is_speaking:
|
||||
logger.debug("User started speaking")
|
||||
self._is_speaking = True
|
||||
upstream_frames += [InterruptionTaskFrame()]
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
downstream_frames += [UserStartedSpeakingFrame()]
|
||||
|
||||
# If final, then re-parse into TranscriptionFrame
|
||||
@@ -794,10 +792,6 @@ class SpeechmaticsSTTService(STTService):
|
||||
self._is_speaking = False
|
||||
downstream_frames += [UserStoppedSpeakingFrame()]
|
||||
|
||||
# Send UPSTREAM frames
|
||||
for frame in upstream_frames:
|
||||
await self.push_frame(frame, FrameDirection.UPSTREAM)
|
||||
|
||||
# Send the DOWNSTREAM frames
|
||||
for frame in downstream_frames:
|
||||
await self.push_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
@@ -32,8 +32,6 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputImageRawFrame,
|
||||
InterruptionFrame,
|
||||
InterruptionTaskFrame,
|
||||
MetricsFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
@@ -353,11 +351,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
|
||||
# Make sure we notify about interruptions quickly out-of-band.
|
||||
if should_push_immediate_interruption and self.interruptions_allowed:
|
||||
await self._start_interruption()
|
||||
# Push an out-of-band frame (i.e. not using the ordered push
|
||||
# frame task) to stop everything, specially at the output
|
||||
# transport.
|
||||
await self.push_frame(InterruptionFrame())
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
elif self.interruption_strategies and self._bot_speaking:
|
||||
logger.debug(
|
||||
"User started speaking while bot is speaking with interruption config - "
|
||||
@@ -372,9 +366,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
await self.push_frame(downstream_frame)
|
||||
await self.push_frame(upstream_frame, FrameDirection.UPSTREAM)
|
||||
|
||||
if self.interruptions_allowed:
|
||||
await self._stop_interruption()
|
||||
|
||||
#
|
||||
# Handle bot speaking state
|
||||
#
|
||||
|
||||
@@ -10,6 +10,7 @@ from pipecat.audio.dtmf.types import KeypadEntry
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.dtmf_aggregator import DTMFAggregator
|
||||
@@ -28,6 +29,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
expected_down_frames = [
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
InputDTMFFrame,
|
||||
InputDTMFFrame,
|
||||
InputDTMFFrame,
|
||||
@@ -59,9 +61,11 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
expected_down_frames = [
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
InputDTMFFrame,
|
||||
TranscriptionFrame, # First aggregation "12"
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
TranscriptionFrame, # Second aggregation "3"
|
||||
]
|
||||
|
||||
@@ -93,10 +97,12 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
expected_down_frames = [
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
InputDTMFFrame,
|
||||
InputDTMFFrame,
|
||||
TranscriptionFrame, # "12#"
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
InputDTMFFrame,
|
||||
TranscriptionFrame, # "45"
|
||||
]
|
||||
@@ -125,6 +131,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
expected_down_frames = [
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
InputDTMFFrame,
|
||||
TranscriptionFrame, # Should flush before EndFrame
|
||||
EndFrame,
|
||||
@@ -152,6 +159,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
expected_down_frames = [
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
InputDTMFFrame,
|
||||
TranscriptionFrame,
|
||||
]
|
||||
@@ -178,6 +186,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
expected_down_frames = [
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
InputDTMFFrame,
|
||||
InputDTMFFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -214,7 +223,11 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
|
||||
# All the InputDTMFFrames plus one TranscriptionFrame
|
||||
expected_down_frames = [InputDTMFFrame] * len(frames_to_send) + [TranscriptionFrame]
|
||||
expected_down_frames = (
|
||||
[InputDTMFFrame, InterruptionFrame]
|
||||
+ [InputDTMFFrame] * (len(frames_to_send) - 1)
|
||||
+ [TranscriptionFrame]
|
||||
)
|
||||
|
||||
received_down_frames, _ = await run_test(
|
||||
aggregator,
|
||||
|
||||
Reference in New Issue
Block a user