add FrameProcessor.setup() to setup processors before StartFrame

This commit is contained in:
Aleix Conchillo Flaqué
2025-05-12 10:06:27 -07:00
parent 5290161ac4
commit 175f352ea7
9 changed files with 149 additions and 83 deletions

View File

@@ -9,6 +9,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- A new function `FrameProcessor.setup()` has been added to allow setting up
frame processors before they receive a `StartFrame`. Internally, the sequence
is now: `FrameProcessor.setup()` is called, `StartFrame` is pushed from the
beginning of the pipeline, your regular pipeline operations run, `EndFrame` or
`CancelFrame` is pushed from the beginning of the pipeline, and finally
`FrameProcessor.cleanup()` is called.
- Allow passing observers to `run_test()` while running unit tests.
### Changed

View File

@@ -7,7 +7,6 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
@@ -20,16 +19,11 @@ from typing import (
)
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.clocks.base_clock import BaseClock
from pipecat.metrics.metrics import MetricsData
from pipecat.transcriptions.language import Language
from pipecat.utils.asyncio import BaseTaskManager
from pipecat.utils.time import nanoseconds_to_str
from pipecat.utils.utils import obj_count, obj_id
if TYPE_CHECKING:
from pipecat.observers.base_observer import BaseObserver
class KeypadEntry(str, Enum):
"""DTMF entries."""
@@ -447,14 +441,11 @@ class OutputDTMFFrame(DTMFFrame):
class StartFrame(SystemFrame):
"""This is the first frame that should be pushed down a pipeline."""
clock: BaseClock
task_manager: BaseTaskManager
audio_in_sample_rate: int = 16000
audio_out_sample_rate: int = 24000
allow_interruptions: bool = False
enable_metrics: bool = False
enable_usage_metrics: bool = False
observer: Optional["BaseObserver"] = None
report_only_initial_ttfb: bool = False

View File

@@ -20,7 +20,7 @@ from pipecat.frames.frames import (
)
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
class ParallelPipelineSource(FrameProcessor):
@@ -118,6 +118,12 @@ class ParallelPipeline(BasePipeline):
# Frame processor
#
async def setup(self, setup: FrameProcessorSetup):
    """Set up this parallel pipeline and everything it contains.

    The base-class setup runs first. After that the sources, the pipelines
    and the sinks are set up one group at a time; the members of a single
    group are gathered together with ``asyncio.gather``.

    Args:
        setup: Setup data forwarded unchanged to every contained processor.
    """
    await super().setup(setup)

    # Each group is awaited to completion before the next group is started.
    source_setups = [source.setup(setup) for source in self._sources]
    await asyncio.gather(*source_setups)

    pipeline_setups = [pipeline.setup(setup) for pipeline in self._pipelines]
    await asyncio.gather(*pipeline_setups)

    sink_setups = [sink.setup(setup) for sink in self._sinks]
    await asyncio.gather(*sink_setups)
async def cleanup(self):
await super().cleanup()
await asyncio.gather(*[s.cleanup() for s in self._sources])

View File

@@ -8,7 +8,7 @@ from typing import Callable, Coroutine, List
from pipecat.frames.frames import Frame
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
class PipelineSource(FrameProcessor):
@@ -70,6 +70,10 @@ class Pipeline(BasePipeline):
# Frame processor
#
async def setup(self, setup: FrameProcessorSetup):
    """Set up the pipeline itself, then the processors it holds.

    Args:
        setup: Setup data passed to the base class and on to
            ``_setup_processors()``.
    """
    # Base-class setup first, then the processors held by this pipeline.
    await super().setup(setup)
    await self._setup_processors(setup)
async def cleanup(self):
await super().cleanup()
await self._cleanup_processors()
@@ -82,6 +86,10 @@ class Pipeline(BasePipeline):
elif direction == FrameDirection.UPSTREAM:
await self._sink.queue_frame(frame, FrameDirection.UPSTREAM)
async def _setup_processors(self, setup: FrameProcessorSetup):
    """Set up every processor of this pipeline, one after another.

    Args:
        setup: Setup data handed to each processor's ``setup()``.
    """
    # Awaited sequentially: one processor's setup finishes before the next
    # one begins.
    for processor in self._processors:
        await processor.setup(setup)
async def _cleanup_processors(self):
for p in self._processors:
await p.cleanup()

View File

@@ -14,7 +14,7 @@ from loguru import logger
from pipecat.frames.frames import ControlFrame, EndFrame, Frame, SystemFrame
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
@dataclass
@@ -103,6 +103,12 @@ class SyncParallelPipeline(BasePipeline):
# Frame processor
#
async def setup(self, setup: FrameProcessorSetup):
    """Set up this synchronous parallel pipeline and everything it contains.

    The base-class setup runs first. The sources, the pipelines and the sinks
    are then set up one group at a time. Source and sink entries are indexed
    with ``"processor"`` to reach the processor that gets set up.

    Args:
        setup: Setup data forwarded unchanged to every contained processor.
    """
    await super().setup(setup)

    # Each group is awaited to completion before the next group is started.
    source_setups = [entry["processor"].setup(setup) for entry in self._sources]
    await asyncio.gather(*source_setups)

    pipeline_setups = [pipeline.setup(setup) for pipeline in self._pipelines]
    await asyncio.gather(*pipeline_setups)

    sink_setups = [entry["processor"].setup(setup) for entry in self._sinks]
    await asyncio.gather(*sink_setups)
async def cleanup(self):
await super().cleanup()
await asyncio.gather(*[s["processor"].cleanup() for s in self._sources])

View File

@@ -33,7 +33,7 @@ from pipecat.observers.base_observer import BaseObserver
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.pipeline.base_task import BaseTask
from pipecat.pipeline.task_observer import TaskObserver
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
from pipecat.utils.asyncio import BaseTaskManager, TaskManager
HEARTBEAT_SECONDS = 1.0
@@ -294,8 +294,15 @@ class PipelineTask(BaseTask):
return
cleanup_pipeline = True
try:
# Setup processors.
await self._setup()
# Create all main tasks and wait for the main push task. This is the
# task that pushes frames to the very beginning of our pipeline (our
# controlled PipelineTaskSource processor).
push_task = await self._create_tasks()
await self._task_manager.wait_for_task(push_task)
# We have already cleaned up the pipeline inside the task.
cleanup_pipeline = False
except asyncio.CancelledError:
@@ -405,6 +412,16 @@ class PipelineTask(BaseTask):
await self._pipeline_end_event.wait()
self._pipeline_end_event.clear()
async def _setup(self):
    """Set up the task's source, pipeline and sink.

    A single ``FrameProcessorSetup`` carrying this task's clock, task manager
    and observer is built and handed to each of the three in turn.
    """
    shared_setup = FrameProcessorSetup(
        clock=self._clock,
        task_manager=self._task_manager,
        observer=self._observer,
    )

    # Source first, then the pipeline, then the sink.
    await self._source.setup(shared_setup)
    await self._pipeline.setup(shared_setup)
    await self._sink.setup(shared_setup)
async def _cleanup(self, cleanup_pipeline: bool):
# Cleanup base object.
await self.cleanup()
@@ -427,14 +444,11 @@ class PipelineTask(BaseTask):
self._maybe_start_idle_task()
start_frame = StartFrame(
clock=self._clock,
task_manager=self._task_manager,
allow_interruptions=self._params.allow_interruptions,
audio_in_sample_rate=self._params.audio_in_sample_rate,
audio_out_sample_rate=self._params.audio_out_sample_rate,
enable_metrics=self._params.enable_metrics,
enable_usage_metrics=self._params.enable_usage_metrics,
observer=self._observer,
report_only_initial_ttfb=self._params.report_only_initial_ttfb,
)
start_frame.metadata = self._params.start_metadata

View File

@@ -5,6 +5,7 @@
#
import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import Awaitable, Callable, Coroutine, Optional
@@ -21,7 +22,7 @@ from pipecat.frames.frames import (
SystemFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage, MetricsData
from pipecat.observers.base_observer import FramePushed
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.processors.metrics.frame_processor_metrics import FrameProcessorMetrics
from pipecat.utils.asyncio import BaseTaskManager
from pipecat.utils.base_object import BaseObject
@@ -32,6 +33,13 @@ class FrameDirection(Enum):
UPSTREAM = 2
@dataclass
class FrameProcessorSetup:
    """Objects handed to ``FrameProcessor.setup()``.

    ``FrameProcessor.setup()`` copies each field onto the processor.
    """

    # Clock the processor stores as ``_clock``.
    clock: BaseClock
    # Task manager the processor stores as ``_task_manager``.
    task_manager: BaseTaskManager
    # Observer the processor stores as ``_observer``; may be left unset.
    observer: Optional[BaseObserver] = None
class FrameProcessor(BaseObject):
def __init__(
self,
@@ -51,12 +59,17 @@ class FrameProcessor(BaseObject):
# Task Manager
self._task_manager: Optional[BaseTaskManager] = None
# Observer
self._observer: Optional[BaseObserver] = None
# Other properties
self._allow_interruptions = False
self._enable_metrics = False
self._enable_usage_metrics = False
self._report_only_initial_ttfb = False
self._observer = None
# Indicates whether we have received the StartFrame.
self.__started = False
# Cancellation is done through CancelFrame (a system frame). This could
# cause other events being triggered (e.g. closing a transport) which
@@ -167,6 +180,11 @@ class FrameProcessor(BaseObject):
raise Exception(f"{self} TaskManager is still not initialized.")
await self._task_manager.wait_for_task(task, timeout)
async def setup(self, setup: FrameProcessorSetup):
    """Store the clock, task manager and observer carried by ``setup``.

    Args:
        setup: Provides the ``clock``, ``task_manager`` and ``observer``
            that this processor keeps.
    """
    # Copied field by field, in the order clock, task manager, observer.
    self._clock = setup.clock
    self._task_manager = setup.task_manager
    self._observer = setup.observer
async def cleanup(self):
await super().cleanup()
await self.__cancel_input_task()
@@ -227,13 +245,6 @@ class FrameProcessor(BaseObject):
async def process_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, StartFrame):
self._clock = frame.clock
self._task_manager = frame.task_manager
self._allow_interruptions = frame.allow_interruptions
self._enable_metrics = frame.enable_metrics
self._enable_usage_metrics = frame.enable_usage_metrics
self._report_only_initial_ttfb = frame.report_only_initial_ttfb
self._observer = frame.observer
await self.__start(frame)
elif isinstance(frame, StartInterruptionFrame):
await self._start_interruption()
@@ -247,7 +258,7 @@ class FrameProcessor(BaseObject):
await self.push_frame(error, FrameDirection.UPSTREAM)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
if not self._check_ready(frame):
if not self._check_started(frame):
return
if isinstance(frame, SystemFrame):
@@ -256,6 +267,11 @@ class FrameProcessor(BaseObject):
await self.__push_queue.put((frame, direction))
async def __start(self, frame: StartFrame):
self.__started = True
self._allow_interruptions = frame.allow_interruptions
self._enable_metrics = frame.enable_metrics
self._enable_usage_metrics = frame.enable_usage_metrics
self._report_only_initial_ttfb = frame.report_only_initial_ttfb
self.__create_input_task()
self.__create_push_task()
@@ -323,15 +339,10 @@ class FrameProcessor(BaseObject):
await self.push_error(ErrorFrame(str(e)))
raise
def _check_ready(self, frame: Frame):
# If we are trying to push a frame but we still have no clock, it means
# we didn't process a StartFrame.
if not self._clock:
logger.error(
f"{self} not properly initialized, missing super().process_frame(frame, direction)?"
)
return False
return True
def _check_started(self, frame: Frame):
    """Report whether this processor has been started.

    Logs an error naming ``frame`` when the started flag is not set.

    Args:
        frame: The frame being handled; only used in the error message.

    Returns:
        The value of the processor's started flag.
    """
    started = self.__started
    if not started:
        logger.error(f"{self} Trying to process {frame} but StartFrame not received yet")
    return started
def __create_input_task(self):
if not self.__input_frame_task:

View File

@@ -34,7 +34,7 @@ from pipecat.frames.frames import (
UserImageRawFrame,
UserImageRequestFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSetup
from pipecat.transcriptions.language import Language
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
@@ -410,7 +410,25 @@ class DailyTransportClient(EventHandler):
if not destination and self._camera:
self._camera.write_frame(frame.image)
async def setup(self, frame: StartFrame):
async def setup(self, setup: FrameProcessorSetup):
    """Store the task manager and create the callback task, once.

    When a task manager has already been stored the call returns without
    doing anything.

    Args:
        setup: Provides the ``task_manager`` used to create the callback task.
    """
    if not self._task_manager:
        self._task_manager = setup.task_manager
        self._callback_task = self._task_manager.create_task(
            self._callback_task_handler(),
            f"{self}::callback_task",
        )
async def cleanup(self):
    """Cancel the callback task, then run ``_cleanup`` in the executor."""
    if self._callback_task and self._task_manager:
        await self._task_manager.cancel_task(self._callback_task)
        self._callback_task = None

    # `_cleanup` runs in the executor so that a slow `client.release()` does
    # not block the event loop.
    loop = self._get_event_loop()
    await loop.run_in_executor(self._executor, self._cleanup)
async def start(self, frame: StartFrame):
self._in_sample_rate = self._params.audio_in_sample_rate or frame.audio_in_sample_rate
self._out_sample_rate = self._params.audio_out_sample_rate or frame.audio_out_sample_rate
@@ -439,13 +457,6 @@ class DailyTransportClient(EventHandler):
)
Daily.select_speaker_device(self._speaker_name())
if not self._task_manager:
self._task_manager = frame.task_manager
self._callback_task = self._task_manager.create_task(
self._callback_task_handler(),
f"{self}::callback_task",
)
async def join(self):
# Transport already joined or joining, ignore.
if self._joined or self._joining:
@@ -612,14 +623,6 @@ class DailyTransportClient(EventHandler):
self._client.leave(completion=completion_callback(future))
return await asyncio.wait_for(future, timeout=10)
async def cleanup(self):
if self._callback_task and self._task_manager:
await self._task_manager.cancel_task(self._callback_task)
self._callback_task = None
# Make sure we don't block the event loop in case `client.release()`
# takes extra time.
await self._get_event_loop().run_in_executor(self._executor, self._cleanup)
def _cleanup(self):
if self._client:
self._client.release()
@@ -952,6 +955,15 @@ class DailyInputTransport(BaseInputTransport):
logger.debug(f"Start receiving audio")
self._audio_in_task = self.create_task(self._audio_in_task_handler())
async def setup(self, setup: FrameProcessorSetup):
    """Set up the base transport, then the client.

    Args:
        setup: Setup data passed to the base class and to ``self._client``.
    """
    # Base-class setup first, then the client.
    await super().setup(setup)
    await self._client.setup(setup)
async def cleanup(self):
    """Clean up the base class, then ``self._client``, then ``self._transport``."""
    await super().cleanup()
    # The client is cleaned up before the transport.
    await self._client.cleanup()
    await self._transport.cleanup()
async def start(self, frame: StartFrame):
if self._initialized:
return
@@ -962,7 +974,7 @@ class DailyInputTransport(BaseInputTransport):
await super().start(frame)
# Setup client.
await self._client.setup(frame)
await self._client.start(frame)
# Join the room.
await self._client.join()
@@ -993,11 +1005,6 @@ class DailyInputTransport(BaseInputTransport):
await self.cancel_task(self._audio_in_task)
self._audio_in_task = None
async def cleanup(self):
await super().cleanup()
await self._client.cleanup()
await self._transport.cleanup()
#
# FrameProcessor
#
@@ -1139,6 +1146,15 @@ class DailyOutputTransport(BaseOutputTransport):
# Whether we have seen a StartFrame already.
self._initialized = False
async def setup(self, setup: FrameProcessorSetup):
    """Set up the base transport, then the client.

    Args:
        setup: Setup data passed to the base class and to ``self._client``.
    """
    # Base-class setup first, then the client.
    await super().setup(setup)
    await self._client.setup(setup)
async def cleanup(self):
    """Clean up the base class, then ``self._client``, then ``self._transport``."""
    await super().cleanup()
    # The client is cleaned up before the transport.
    await self._client.cleanup()
    await self._transport.cleanup()
async def start(self, frame: StartFrame):
if self._initialized:
return
@@ -1149,7 +1165,7 @@ class DailyOutputTransport(BaseOutputTransport):
await super().start(frame)
# Setup client.
await self._client.setup(frame)
await self._client.start(frame)
# Join the room.
await self._client.join()
@@ -1169,11 +1185,6 @@ class DailyOutputTransport(BaseOutputTransport):
# Leave the room.
await self._client.leave()
async def cleanup(self):
await super().cleanup()
await self._client.cleanup()
await self._transport.cleanup()
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
await self._client.send_message(frame)

View File

@@ -23,7 +23,7 @@ from pipecat.frames.frames import (
TransportMessageFrame,
TransportMessageUrgentFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
@@ -100,20 +100,27 @@ class LiveKitTransportClient:
raise Exception(f"{self}: missing room object (pipeline not started?)")
return self._room
async def setup(self, frame: StartFrame):
self._out_sample_rate = self._params.audio_out_sample_rate or frame.audio_out_sample_rate
if not self._task_manager:
self._task_manager = frame.task_manager
self._room = rtc.Room(loop=self._task_manager.get_event_loop())
async def setup(self, setup: FrameProcessorSetup):
if self._task_manager:
return
# Set up room event handlers
self.room.on("participant_connected")(self._on_participant_connected_wrapper)
self.room.on("participant_disconnected")(self._on_participant_disconnected_wrapper)
self.room.on("track_subscribed")(self._on_track_subscribed_wrapper)
self.room.on("track_unsubscribed")(self._on_track_unsubscribed_wrapper)
self.room.on("data_received")(self._on_data_received_wrapper)
self.room.on("connected")(self._on_connected_wrapper)
self.room.on("disconnected")(self._on_disconnected_wrapper)
self._task_manager = setup.task_manager
self._room = rtc.Room(loop=self._task_manager.get_event_loop())
# Set up room event handlers
self.room.on("participant_connected")(self._on_participant_connected_wrapper)
self.room.on("participant_disconnected")(self._on_participant_disconnected_wrapper)
self.room.on("track_subscribed")(self._on_track_subscribed_wrapper)
self.room.on("track_unsubscribed")(self._on_track_unsubscribed_wrapper)
self.room.on("data_received")(self._on_data_received_wrapper)
self.room.on("connected")(self._on_connected_wrapper)
self.room.on("disconnected")(self._on_disconnected_wrapper)
async def cleanup(self):
    """Clean up the client by awaiting ``disconnect()``."""
    await self.disconnect()
async def start(self, frame: StartFrame):
    """Choose the output sample rate for this client.

    The rate configured in ``self._params`` is used when it is truthy;
    otherwise the rate carried by ``frame`` is used.

    Args:
        frame: Start frame providing the fallback ``audio_out_sample_rate``.
    """
    configured_rate = self._params.audio_out_sample_rate
    self._out_sample_rate = configured_rate or frame.audio_out_sample_rate
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
async def connect(self):
@@ -333,9 +340,6 @@ class LiveKitTransportClient:
else:
logger.warning(f"Received unexpected event type: {type(event)}")
async def cleanup(self):
await self.disconnect()
async def get_next_audio_frame(self):
    """Wait for the next item on the audio queue and return it.

    Returns:
        A ``(frame, participant_id)`` tuple unpacked from the queued item.
    """
    queued = await self._audio_queue.get()
    # Unpack so that an item without exactly two elements still raises here.
    audio_frame, sender_id = queued
    return audio_frame, sender_id
@@ -366,7 +370,7 @@ class LiveKitInputTransport(BaseInputTransport):
async def start(self, frame: StartFrame):
await super().start(frame)
await self._client.setup(frame)
await self._client.start(frame)
await self._client.connect()
if not self._audio_in_task and self._params.audio_in_enabled:
self._audio_in_task = self.create_task(self._audio_in_task_handler())
@@ -386,6 +390,10 @@ class LiveKitInputTransport(BaseInputTransport):
if self._audio_in_task and self._params.audio_in_enabled:
await self.cancel_task(self._audio_in_task)
async def setup(self, setup: FrameProcessorSetup):
    """Set up the base transport, then the client.

    Args:
        setup: Setup data passed to the base class and to ``self._client``.
    """
    # Base-class setup first, then the client.
    await super().setup(setup)
    await self._client.setup(setup)
async def cleanup(self):
await super().cleanup()
await self._transport.cleanup()
@@ -440,7 +448,7 @@ class LiveKitOutputTransport(BaseOutputTransport):
async def start(self, frame: StartFrame):
await super().start(frame)
await self._client.setup(frame)
await self._client.start(frame)
await self._client.connect()
await self.set_transport_ready(frame)
logger.info("LiveKitOutputTransport started")
@@ -454,6 +462,10 @@ class LiveKitOutputTransport(BaseOutputTransport):
await super().cancel(frame)
await self._client.disconnect()
async def setup(self, setup: FrameProcessorSetup):
    """Set up the base transport, then the client.

    Args:
        setup: Setup data passed to the base class and to ``self._client``.
    """
    # Base-class setup first, then the client.
    await super().setup(setup)
    await self._client.setup(setup)
async def cleanup(self):
await super().cleanup()
await self._transport.cleanup()