diff --git a/CHANGELOG.md b/CHANGELOG.md index 8df67e742..1bb74c358 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- A new function `FrameProcessor.setup()` has been added to allow setting up + frame processors before receiving a `StartFrame`. This is what's happening + internally: `FrameProcessor.setup()` is called, `StartFrame` is pushed from + the beginning of the pipeline, your regular pipeline operations, `EndFrame` or + `CancelFrame` are pushed from the beginning of the pipeline and finally + `FrameProcessor.cleanup()` is called. + - Allow passing observers to `run_test()` while running unit tests. ### Changed diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 8d3f38459..2db21495b 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -7,7 +7,6 @@ from dataclasses import dataclass, field from enum import Enum from typing import ( - TYPE_CHECKING, Any, Awaitable, Callable, @@ -20,16 +19,11 @@ from typing import ( ) from pipecat.audio.vad.vad_analyzer import VADParams -from pipecat.clocks.base_clock import BaseClock from pipecat.metrics.metrics import MetricsData from pipecat.transcriptions.language import Language -from pipecat.utils.asyncio import BaseTaskManager from pipecat.utils.time import nanoseconds_to_str from pipecat.utils.utils import obj_count, obj_id -if TYPE_CHECKING: - from pipecat.observers.base_observer import BaseObserver - class KeypadEntry(str, Enum): """DTMF entries.""" @@ -447,14 +441,11 @@ class OutputDTMFFrame(DTMFFrame): class StartFrame(SystemFrame): """This is the first frame that should be pushed down a pipeline.""" - clock: BaseClock - task_manager: BaseTaskManager audio_in_sample_rate: int = 16000 audio_out_sample_rate: int = 24000 allow_interruptions: bool = False enable_metrics: bool = False enable_usage_metrics: bool = False - observer: Optional["BaseObserver"] = None report_only_initial_ttfb: bool = False diff --git a/src/pipecat/pipeline/parallel_pipeline.py b/src/pipecat/pipeline/parallel_pipeline.py index c6f4de2df..c82330107 100644 --- a/src/pipecat/pipeline/parallel_pipeline.py +++ b/src/pipecat/pipeline/parallel_pipeline.py @@ -20,7 +20,7 @@ from pipecat.frames.frames import ( ) from pipecat.pipeline.base_pipeline import BasePipeline from pipecat.pipeline.pipeline import Pipeline -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup class ParallelPipelineSource(FrameProcessor): @@ -118,6 +118,12 @@ class ParallelPipeline(BasePipeline): # Frame processor # + async def setup(self, setup: FrameProcessorSetup): + await super().setup(setup) + await asyncio.gather(*[s.setup(setup) for s in self._sources]) + await asyncio.gather(*[p.setup(setup) for p in self._pipelines]) + await asyncio.gather(*[s.setup(setup) for s in self._sinks]) + async def cleanup(self): await super().cleanup() await asyncio.gather(*[s.cleanup() for s in self._sources]) diff --git a/src/pipecat/pipeline/pipeline.py b/src/pipecat/pipeline/pipeline.py index 270e100bd..c10a32e0a 100644 --- a/src/pipecat/pipeline/pipeline.py +++ b/src/pipecat/pipeline/pipeline.py @@ -8,7 +8,7 @@ from typing import Callable, Coroutine, List from pipecat.frames.frames import Frame from pipecat.pipeline.base_pipeline import BasePipeline -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup class PipelineSource(FrameProcessor): @@ -70,6 +70,10 @@ class Pipeline(BasePipeline): # Frame processor # + async def setup(self, setup: FrameProcessorSetup): + await super().setup(setup) + await self._setup_processors(setup) + async def cleanup(self): await super().cleanup() await self._cleanup_processors() @@ -82,6 +86,10 @@ class Pipeline(BasePipeline): elif direction == FrameDirection.UPSTREAM: await self._sink.queue_frame(frame, FrameDirection.UPSTREAM) + async def _setup_processors(self, setup: FrameProcessorSetup): + for p in self._processors: + await p.setup(setup) + async def _cleanup_processors(self): for p in self._processors: await p.cleanup() diff --git a/src/pipecat/pipeline/sync_parallel_pipeline.py b/src/pipecat/pipeline/sync_parallel_pipeline.py index 2870a470d..4cf9f5033 100644 --- a/src/pipecat/pipeline/sync_parallel_pipeline.py +++ b/src/pipecat/pipeline/sync_parallel_pipeline.py @@ -14,7 +14,7 @@ from loguru import logger from pipecat.frames.frames import ControlFrame, EndFrame, Frame, SystemFrame from pipecat.pipeline.base_pipeline import BasePipeline from pipecat.pipeline.pipeline import Pipeline -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup @dataclass @@ -103,6 +103,12 @@ class SyncParallelPipeline(BasePipeline): # Frame processor # + async def setup(self, setup: FrameProcessorSetup): + await super().setup(setup) + await asyncio.gather(*[s["processor"].setup(setup) for s in self._sources]) + await asyncio.gather(*[p.setup(setup) for p in self._pipelines]) + await asyncio.gather(*[s["processor"].setup(setup) for s in self._sinks]) + async def cleanup(self): await super().cleanup() await asyncio.gather(*[s["processor"].cleanup() for s in self._sources]) diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index c40173899..679bb02e4 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -33,7 +33,7 @@ from pipecat.observers.base_observer import BaseObserver from pipecat.pipeline.base_pipeline import BasePipeline from pipecat.pipeline.base_task import BaseTask from pipecat.pipeline.task_observer import TaskObserver -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup from pipecat.utils.asyncio import BaseTaskManager, TaskManager HEARTBEAT_SECONDS = 1.0 @@ -294,8 +294,15 @@ class PipelineTask(BaseTask): return cleanup_pipeline = True try: + # Setup processors. + await self._setup() + + # Create all main tasks and wait of the main push task. This is the + # task that pushes frames to the very beginning of our pipeline (our + # controlled PipelineTaskSource processor). push_task = await self._create_tasks() await self._task_manager.wait_for_task(push_task) + # We have already cleaned up the pipeline inside the task. cleanup_pipeline = False except asyncio.CancelledError: @@ -405,6 +412,16 @@ class PipelineTask(BaseTask): await self._pipeline_end_event.wait() self._pipeline_end_event.clear() + async def _setup(self): + setup = FrameProcessorSetup( + clock=self._clock, + task_manager=self._task_manager, + observer=self._observer, + ) + await self._source.setup(setup) + await self._pipeline.setup(setup) + await self._sink.setup(setup) + async def _cleanup(self, cleanup_pipeline: bool): # Cleanup base object. await self.cleanup() @@ -427,14 +444,11 @@ class PipelineTask(BaseTask): self._maybe_start_idle_task() start_frame = StartFrame( - clock=self._clock, - task_manager=self._task_manager, allow_interruptions=self._params.allow_interruptions, audio_in_sample_rate=self._params.audio_in_sample_rate, audio_out_sample_rate=self._params.audio_out_sample_rate, enable_metrics=self._params.enable_metrics, enable_usage_metrics=self._params.enable_usage_metrics, - observer=self._observer, report_only_initial_ttfb=self._params.report_only_initial_ttfb, ) start_frame.metadata = self._params.start_metadata diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index 97cc24378..b444b4c58 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -5,6 +5,7 @@ # import asyncio +from dataclasses import dataclass from enum import Enum from typing import Awaitable, Callable, Coroutine, Optional @@ -21,7 +22,7 @@ from pipecat.frames.frames import ( SystemFrame, ) from pipecat.metrics.metrics import LLMTokenUsage, MetricsData -from pipecat.observers.base_observer import FramePushed +from pipecat.observers.base_observer import BaseObserver, FramePushed from pipecat.processors.metrics.frame_processor_metrics import FrameProcessorMetrics from pipecat.utils.asyncio import BaseTaskManager from pipecat.utils.base_object import BaseObject @@ -32,6 +33,13 @@ class FrameDirection(Enum): UPSTREAM = 2 +@dataclass +class FrameProcessorSetup: + clock: BaseClock + task_manager: BaseTaskManager + observer: Optional[BaseObserver] = None + + class FrameProcessor(BaseObject): def __init__( self, @@ -51,12 +59,17 @@ class FrameProcessor(BaseObject): # Task Manager self._task_manager: Optional[BaseTaskManager] = None + # Observer + self._observer: Optional[BaseObserver] = None + # Other properties self._allow_interruptions = False self._enable_metrics = False self._enable_usage_metrics = False self._report_only_initial_ttfb = False - self._observer = None + + # Indicates whether we have received the StartFrame. + self.__started = False # Cancellation is done through CancelFrame (a system frame). This could # cause other events being triggered (e.g. closing a transport) which @@ -167,6 +180,11 @@ class FrameProcessor(BaseObject): raise Exception(f"{self} TaskManager is still not initialized.") await self._task_manager.wait_for_task(task, timeout) + async def setup(self, setup: FrameProcessorSetup): + self._clock = setup.clock + self._task_manager = setup.task_manager + self._observer = setup.observer + async def cleanup(self): await super().cleanup() await self.__cancel_input_task() @@ -227,13 +245,6 @@ class FrameProcessor(BaseObject): async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, StartFrame): - self._clock = frame.clock - self._task_manager = frame.task_manager - self._allow_interruptions = frame.allow_interruptions - self._enable_metrics = frame.enable_metrics - self._enable_usage_metrics = frame.enable_usage_metrics - self._report_only_initial_ttfb = frame.report_only_initial_ttfb - self._observer = frame.observer await self.__start(frame) elif isinstance(frame, StartInterruptionFrame): await self._start_interruption() @@ -247,7 +258,7 @@ class FrameProcessor(BaseObject): await self.push_frame(error, FrameDirection.UPSTREAM) async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): - if not self._check_ready(frame): + if not self._check_started(frame): return if isinstance(frame, SystemFrame): @@ -256,6 +267,11 @@ class FrameProcessor(BaseObject): await self.__push_queue.put((frame, direction)) async def __start(self, frame: StartFrame): + self.__started = True + self._allow_interruptions = frame.allow_interruptions + self._enable_metrics = frame.enable_metrics + self._enable_usage_metrics = frame.enable_usage_metrics + self._report_only_initial_ttfb = frame.report_only_initial_ttfb self.__create_input_task() self.__create_push_task() @@ -323,15 +339,10 @@ class FrameProcessor(BaseObject): await self.push_error(ErrorFrame(str(e))) raise - def _check_ready(self, frame: Frame): - # If we are trying to push a frame but we still have no clock, it means - # we didn't process a StartFrame. - if not self._clock: - logger.error( - f"{self} not properly initialized, missing super().process_frame(frame, direction)?" - ) - return False - return True + def _check_started(self, frame: Frame): + if not self.__started: + logger.error(f"{self} Trying to process {frame} but StartFrame not received yet") + return self.__started def __create_input_task(self): if not self.__input_frame_task: diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py index 4528662a6..b31396433 100644 --- a/src/pipecat/transports/services/daily.py +++ b/src/pipecat/transports/services/daily.py @@ -34,7 +34,7 @@ from pipecat.frames.frames import ( UserImageRawFrame, UserImageRequestFrame, ) -from pipecat.processors.frame_processor import FrameDirection +from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSetup from pipecat.transcriptions.language import Language from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import BaseOutputTransport @@ -410,7 +410,25 @@ class DailyTransportClient(EventHandler): if not destination and self._camera: self._camera.write_frame(frame.image) - async def setup(self, frame: StartFrame): + async def setup(self, setup: FrameProcessorSetup): + if self._task_manager: + return + + self._task_manager = setup.task_manager + self._callback_task = self._task_manager.create_task( + self._callback_task_handler(), + f"{self}::callback_task", + ) + + async def cleanup(self): + if self._callback_task and self._task_manager: + await self._task_manager.cancel_task(self._callback_task) + self._callback_task = None + # Make sure we don't block the event loop in case `client.release()` + # takes extra time. + await self._get_event_loop().run_in_executor(self._executor, self._cleanup) + + async def start(self, frame: StartFrame): self._in_sample_rate = self._params.audio_in_sample_rate or frame.audio_in_sample_rate self._out_sample_rate = self._params.audio_out_sample_rate or frame.audio_out_sample_rate @@ -439,13 +457,6 @@ class DailyTransportClient(EventHandler): ) Daily.select_speaker_device(self._speaker_name()) - if not self._task_manager: - self._task_manager = frame.task_manager - self._callback_task = self._task_manager.create_task( - self._callback_task_handler(), - f"{self}::callback_task", - ) - async def join(self): # Transport already joined or joining, ignore. if self._joined or self._joining: @@ -612,14 +623,6 @@ class DailyTransportClient(EventHandler): self._client.leave(completion=completion_callback(future)) return await asyncio.wait_for(future, timeout=10) - async def cleanup(self): - if self._callback_task and self._task_manager: - await self._task_manager.cancel_task(self._callback_task) - self._callback_task = None - # Make sure we don't block the event loop in case `client.release()` - # takes extra time. - await self._get_event_loop().run_in_executor(self._executor, self._cleanup) - def _cleanup(self): if self._client: self._client.release() @@ -952,6 +955,15 @@ class DailyInputTransport(BaseInputTransport): logger.debug(f"Start receiving audio") self._audio_in_task = self.create_task(self._audio_in_task_handler()) + async def setup(self, setup: FrameProcessorSetup): + await super().setup(setup) + await self._client.setup(setup) + + async def cleanup(self): + await super().cleanup() + await self._client.cleanup() + await self._transport.cleanup() + async def start(self, frame: StartFrame): if self._initialized: return @@ -962,7 +974,7 @@ class DailyInputTransport(BaseInputTransport): await super().start(frame) # Setup client. - await self._client.setup(frame) + await self._client.start(frame) # Join the room. await self._client.join() @@ -993,11 +1005,6 @@ class DailyInputTransport(BaseInputTransport): await self.cancel_task(self._audio_in_task) self._audio_in_task = None - async def cleanup(self): - await super().cleanup() - await self._client.cleanup() - await self._transport.cleanup() - # # FrameProcessor # @@ -1139,6 +1146,15 @@ class DailyOutputTransport(BaseOutputTransport): # Whether we have seen a StartFrame already. self._initialized = False + async def setup(self, setup: FrameProcessorSetup): + await super().setup(setup) + await self._client.setup(setup) + + async def cleanup(self): + await super().cleanup() + await self._client.cleanup() + await self._transport.cleanup() + async def start(self, frame: StartFrame): if self._initialized: return @@ -1149,7 +1165,7 @@ class DailyOutputTransport(BaseOutputTransport): await super().start(frame) # Setup client. - await self._client.setup(frame) + await self._client.start(frame) # Join the room. await self._client.join() @@ -1169,11 +1185,6 @@ class DailyOutputTransport(BaseOutputTransport): # Leave the room. await self._client.leave() - async def cleanup(self): - await super().cleanup() - await self._client.cleanup() - await self._transport.cleanup() - async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame): await self._client.send_message(frame) diff --git a/src/pipecat/transports/services/livekit.py b/src/pipecat/transports/services/livekit.py index 36cc5d604..4f215d04f 100644 --- a/src/pipecat/transports/services/livekit.py +++ b/src/pipecat/transports/services/livekit.py @@ -23,7 +23,7 @@ from pipecat.frames.frames import ( TransportMessageFrame, TransportMessageUrgentFrame, ) -from pipecat.processors.frame_processor import FrameDirection +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import BaseOutputTransport from pipecat.transports.base_transport import BaseTransport, TransportParams @@ -100,20 +100,27 @@ class LiveKitTransportClient: raise Exception(f"{self}: missing room object (pipeline not started?)") return self._room - async def setup(self, frame: StartFrame): - self._out_sample_rate = self._params.audio_out_sample_rate or frame.audio_out_sample_rate - if not self._task_manager: - self._task_manager = frame.task_manager - self._room = rtc.Room(loop=self._task_manager.get_event_loop()) + async def setup(self, setup: FrameProcessorSetup): + if self._task_manager: + return - # Set up room event handlers - self.room.on("participant_connected")(self._on_participant_connected_wrapper) - self.room.on("participant_disconnected")(self._on_participant_disconnected_wrapper) - self.room.on("track_subscribed")(self._on_track_subscribed_wrapper) - self.room.on("track_unsubscribed")(self._on_track_unsubscribed_wrapper) - self.room.on("data_received")(self._on_data_received_wrapper) - self.room.on("connected")(self._on_connected_wrapper) - self.room.on("disconnected")(self._on_disconnected_wrapper) + self._task_manager = setup.task_manager + self._room = rtc.Room(loop=self._task_manager.get_event_loop()) + + # Set up room event handlers + self.room.on("participant_connected")(self._on_participant_connected_wrapper) + self.room.on("participant_disconnected")(self._on_participant_disconnected_wrapper) + self.room.on("track_subscribed")(self._on_track_subscribed_wrapper) + self.room.on("track_unsubscribed")(self._on_track_unsubscribed_wrapper) + self.room.on("data_received")(self._on_data_received_wrapper) + self.room.on("connected")(self._on_connected_wrapper) + self.room.on("disconnected")(self._on_disconnected_wrapper) + + async def cleanup(self): + await self.disconnect() + + async def start(self, frame: StartFrame): + self._out_sample_rate = self._params.audio_out_sample_rate or frame.audio_out_sample_rate @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)) async def connect(self): @@ -333,9 +340,6 @@ class LiveKitTransportClient: else: logger.warning(f"Received unexpected event type: {type(event)}") - async def cleanup(self): - await self.disconnect() - async def get_next_audio_frame(self): frame, participant_id = await self._audio_queue.get() return frame, participant_id @@ -366,7 +370,7 @@ class LiveKitInputTransport(BaseInputTransport): async def start(self, frame: StartFrame): await super().start(frame) - await self._client.setup(frame) + await self._client.start(frame) await self._client.connect() if not self._audio_in_task and self._params.audio_in_enabled: self._audio_in_task = self.create_task(self._audio_in_task_handler()) @@ -386,6 +390,10 @@ class LiveKitInputTransport(BaseInputTransport): if self._audio_in_task and self._params.audio_in_enabled: await self.cancel_task(self._audio_in_task) + async def setup(self, setup: FrameProcessorSetup): + await super().setup(setup) + await self._client.setup(setup) + async def cleanup(self): await super().cleanup() await self._transport.cleanup() @@ -440,7 +448,7 @@ class LiveKitOutputTransport(BaseOutputTransport): async def start(self, frame: StartFrame): await super().start(frame) - await self._client.setup(frame) + await self._client.start(frame) await self._client.connect() await self.set_transport_ready(frame) logger.info("LiveKitOutputTransport started") @@ -454,6 +462,10 @@ class LiveKitOutputTransport(BaseOutputTransport): await super().cancel(frame) await self._client.disconnect() + async def setup(self, setup: FrameProcessorSetup): + await super().setup(setup) + await self._client.setup(setup) + async def cleanup(self): await super().cleanup() await self._transport.cleanup()