Compare commits
2 Commits
hush/realt
...
aleix/pipe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
901899aa19 | ||
|
|
947faf8a39 |
@@ -97,6 +97,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
|
||||
- `BotInterruptionFrame` frame handling has been moved from `BaseInputTransport`
|
||||
to `PipelineTask`. This allows any type of pipeline to handle
|
||||
`BotInterruptionFrame` without the need of an input transport.
|
||||
|
||||
- `pipeline.tests.utils.run_test()` now allows passing `PipelineParams` instead
|
||||
of individual parameters.
|
||||
|
||||
- Updated `daily-python` to 0.19.8.
|
||||
|
||||
- `PipelineTask` now waits for `StartFrame` to reach the end of the pipeline
|
||||
|
||||
@@ -23,6 +23,7 @@ from pipecat.audio.interruptions.base_interruption_strategy import BaseInterrupt
|
||||
from pipecat.clocks.base_clock import BaseClock
|
||||
from pipecat.clocks.system_clock import SystemClock
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
BotSpeakingFrame,
|
||||
CancelFrame,
|
||||
CancelTaskFrame,
|
||||
@@ -36,6 +37,7 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
MetricsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
StopFrame,
|
||||
StopTaskFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -632,7 +634,11 @@ class PipelineTask(BasePipelineTask):
|
||||
if isinstance(frame, self._reached_upstream_types):
|
||||
await self._call_event_handler("on_frame_reached_upstream", frame)
|
||||
|
||||
if isinstance(frame, EndTaskFrame):
|
||||
if isinstance(frame, BotInterruptionFrame) and self.params.allow_interruptions:
|
||||
# Tell the pipeline we should interrupt.
|
||||
logger.debug("Bot interruption")
|
||||
await self.queue_frame(StartInterruptionFrame())
|
||||
elif isinstance(frame, EndTaskFrame):
|
||||
# Tell the task we should end nicely.
|
||||
await self.queue_frame(EndFrame())
|
||||
elif isinstance(frame, CancelTaskFrame):
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple
|
||||
from typing import Awaitable, Callable, List, Optional, Sequence, Tuple
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
@@ -128,7 +128,7 @@ async def run_test(
|
||||
expected_up_frames: Optional[Sequence[type]] = None,
|
||||
ignore_start: bool = True,
|
||||
observers: Optional[List[BaseObserver]] = None,
|
||||
start_metadata: Optional[Dict[str, Any]] = None,
|
||||
pipeline_params: Optional[PipelineParams] = None,
|
||||
send_end_frame: bool = True,
|
||||
) -> Tuple[Sequence[Frame], Sequence[Frame]]:
|
||||
"""Run a test pipeline with the specified processor and validate frame flow.
|
||||
@@ -144,7 +144,7 @@ async def run_test(
|
||||
expected_up_frames: Expected frame types flowing upstream (optional).
|
||||
ignore_start: Whether to ignore StartFrames in frame validation.
|
||||
observers: Optional list of observers to attach to the pipeline.
|
||||
start_metadata: Optional metadata to include with the StartFrame.
|
||||
pipeline_params: Pipeline parameters.
|
||||
send_end_frame: Whether to send an EndFrame at the end of the test.
|
||||
|
||||
Returns:
|
||||
@@ -154,7 +154,6 @@ async def run_test(
|
||||
AssertionError: If the received frames don't match the expected frame types.
|
||||
"""
|
||||
observers = observers or []
|
||||
start_metadata = start_metadata or {}
|
||||
|
||||
received_up = asyncio.Queue()
|
||||
received_down = asyncio.Queue()
|
||||
@@ -173,7 +172,7 @@ async def run_test(
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(start_metadata=start_metadata),
|
||||
params=pipeline_params,
|
||||
observers=observers,
|
||||
cancel_on_idle_timeout=False,
|
||||
)
|
||||
|
||||
@@ -22,7 +22,6 @@ from pipecat.audio.turn.base_turn_analyzer import (
|
||||
)
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -289,8 +288,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self.cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, BotInterruptionFrame):
|
||||
await self._handle_bot_interruption(frame)
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
await self._handle_bot_started_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -335,13 +332,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
# Handle interruptions
|
||||
#
|
||||
|
||||
async def _handle_bot_interruption(self, frame: BotInterruptionFrame):
|
||||
"""Handle bot interruption frames."""
|
||||
logger.debug("Bot interruption")
|
||||
if self.interruptions_allowed:
|
||||
await self._start_interruption()
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
|
||||
async def _handle_user_interruption(self, frame: Frame):
|
||||
"""Handle user interruption events based on speaking state."""
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
|
||||
@@ -12,6 +12,7 @@ from pipecat.frames.frames import (
|
||||
KeypadEntry,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.pipeline.task import PipelineParams
|
||||
from pipecat.processors.aggregators.dtmf_aggregator import DTMFAggregator
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
|
||||
@@ -69,6 +70,8 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
# TODO(aleix): we should handle StartInterruptionFrame
|
||||
pipeline_params=PipelineParams(allow_interruptions=False),
|
||||
)
|
||||
|
||||
# Find the TranscriptionFrames
|
||||
@@ -105,6 +108,8 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
# TODO(aleix): we should handle StartInterruptionFrame
|
||||
pipeline_params=PipelineParams(allow_interruptions=False),
|
||||
)
|
||||
|
||||
transcription_frames = [
|
||||
@@ -134,6 +139,8 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
# TODO(aleix): we should handle StartInterruptionFrame
|
||||
pipeline_params=PipelineParams(allow_interruptions=False),
|
||||
send_end_frame=False, # We're sending one in the test to test EndFrame logic
|
||||
)
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ class TestPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
ignore_start=False,
|
||||
start_metadata={"foo": "bar"},
|
||||
pipeline_params=PipelineParams(start_metadata={"foo": "bar"}),
|
||||
)
|
||||
assert "foo" in received_down[-1].metadata
|
||||
|
||||
|
||||
Reference in New Issue
Block a user