Compare commits

...

2 Commits

Author SHA1 Message Date
Aleix Conchillo Flaqué
901899aa19 PipelineTask: handle BotInterruptionFrame in PipelineTask
This allows having processors before the input transport or even not having an
input transport.
2025-08-27 19:41:40 -07:00
Aleix Conchillo Flaqué
947faf8a39 tests: allow PipelineParams instead of individual arguments 2025-08-27 19:04:53 -07:00
6 changed files with 26 additions and 17 deletions

View File

@@ -97,6 +97,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- `BotInterruptionFrame` frame handling has been moved from `BaseInputTransport`
to `PipelineTask`. This allows any type of pipeline to handle
`BotInterruptionFrame` without the need of an input transport.
- `pipeline.tests.utils.run_test()` now allows passing `PipelineParams` instead
of individual parameters.
- Updated `daily-python` to 0.19.8.
- `PipelineTask` now waits for `StartFrame` to reach the end of the pipeline

View File

@@ -23,6 +23,7 @@ from pipecat.audio.interruptions.base_interruption_strategy import BaseInterrupt
from pipecat.clocks.base_clock import BaseClock
from pipecat.clocks.system_clock import SystemClock
from pipecat.frames.frames import (
BotInterruptionFrame,
BotSpeakingFrame,
CancelFrame,
CancelTaskFrame,
@@ -36,6 +37,7 @@ from pipecat.frames.frames import (
LLMFullResponseEndFrame,
MetricsFrame,
StartFrame,
StartInterruptionFrame,
StopFrame,
StopTaskFrame,
TranscriptionFrame,
@@ -632,7 +634,11 @@ class PipelineTask(BasePipelineTask):
if isinstance(frame, self._reached_upstream_types):
await self._call_event_handler("on_frame_reached_upstream", frame)
if isinstance(frame, EndTaskFrame):
if isinstance(frame, BotInterruptionFrame) and self.params.allow_interruptions:
# Tell the pipeline we should interrupt.
logger.debug("Bot interruption")
await self.queue_frame(StartInterruptionFrame())
elif isinstance(frame, EndTaskFrame):
# Tell the task we should end nicely.
await self.queue_frame(EndFrame())
elif isinstance(frame, CancelTaskFrame):

View File

@@ -8,7 +8,7 @@
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple
from typing import Awaitable, Callable, List, Optional, Sequence, Tuple
from pipecat.frames.frames import (
EndFrame,
@@ -128,7 +128,7 @@ async def run_test(
expected_up_frames: Optional[Sequence[type]] = None,
ignore_start: bool = True,
observers: Optional[List[BaseObserver]] = None,
start_metadata: Optional[Dict[str, Any]] = None,
pipeline_params: Optional[PipelineParams] = None,
send_end_frame: bool = True,
) -> Tuple[Sequence[Frame], Sequence[Frame]]:
"""Run a test pipeline with the specified processor and validate frame flow.
@@ -144,7 +144,7 @@ async def run_test(
expected_up_frames: Expected frame types flowing upstream (optional).
ignore_start: Whether to ignore StartFrames in frame validation.
observers: Optional list of observers to attach to the pipeline.
start_metadata: Optional metadata to include with the StartFrame.
pipeline_params: Pipeline parameters.
send_end_frame: Whether to send an EndFrame at the end of the test.
Returns:
@@ -154,7 +154,6 @@ async def run_test(
AssertionError: If the received frames don't match the expected frame types.
"""
observers = observers or []
start_metadata = start_metadata or {}
received_up = asyncio.Queue()
received_down = asyncio.Queue()
@@ -173,7 +172,7 @@ async def run_test(
task = PipelineTask(
pipeline,
params=PipelineParams(start_metadata=start_metadata),
params=pipeline_params,
observers=observers,
cancel_on_idle_timeout=False,
)

View File

@@ -22,7 +22,6 @@ from pipecat.audio.turn.base_turn_analyzer import (
)
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState
from pipecat.frames.frames import (
BotInterruptionFrame,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -289,8 +288,6 @@ class BaseInputTransport(FrameProcessor):
elif isinstance(frame, CancelFrame):
await self.cancel(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, BotInterruptionFrame):
await self._handle_bot_interruption(frame)
elif isinstance(frame, BotStartedSpeakingFrame):
await self._handle_bot_started_speaking(frame)
await self.push_frame(frame, direction)
@@ -335,13 +332,6 @@ class BaseInputTransport(FrameProcessor):
# Handle interruptions
#
async def _handle_bot_interruption(self, frame: BotInterruptionFrame):
"""Handle bot interruption frames."""
logger.debug("Bot interruption")
if self.interruptions_allowed:
await self._start_interruption()
await self.push_frame(StartInterruptionFrame())
async def _handle_user_interruption(self, frame: Frame):
"""Handle user interruption events based on speaking state."""
if isinstance(frame, UserStartedSpeakingFrame):

View File

@@ -12,6 +12,7 @@ from pipecat.frames.frames import (
KeypadEntry,
TranscriptionFrame,
)
from pipecat.pipeline.task import PipelineParams
from pipecat.processors.aggregators.dtmf_aggregator import DTMFAggregator
from pipecat.tests.utils import SleepFrame, run_test
@@ -69,6 +70,8 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
# TODO(aleix): we should handle StartInterruptionFrame
pipeline_params=PipelineParams(allow_interruptions=False),
)
# Find the TranscriptionFrames
@@ -105,6 +108,8 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
# TODO(aleix): we should handle StartInterruptionFrame
pipeline_params=PipelineParams(allow_interruptions=False),
)
transcription_frames = [
@@ -134,6 +139,8 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
# TODO(aleix): we should handle StartInterruptionFrame
pipeline_params=PipelineParams(allow_interruptions=False),
send_end_frame=False, # We're sending one in the test to test EndFrame logic
)

View File

@@ -65,7 +65,7 @@ class TestPipeline(unittest.IsolatedAsyncioTestCase):
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
ignore_start=False,
start_metadata={"foo": "bar"},
pipeline_params=PipelineParams(start_metadata={"foo": "bar"}),
)
assert "foo" in received_down[-1].metadata