diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ca6ed7e6..776041481 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `on_frame_reached_upstream` and `on_frame_reached_downstream` event + handlers to `PipelineTask`. Those events will be called when a frame reaches + the beginning or end of the pipeline respectively. + - Added support for Chirp voices in `GoogleTTSService`. - Added a `flush_audio()` method to `FishTTSService` and `LmntTTSService`. @@ -80,7 +84,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- Fixed an issue in `RimeTTSService` where the last line of text sent didn't result in an audio output being generated. +- Fixed an issue in `RimeTTSService` where the last line of text sent didn't + result in an audio output being generated. ## [0.0.58] - 2025-02-26 diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index 1c9d2dff9..f20ef9d1f 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -119,12 +119,25 @@ class PipelineTaskSink(FrameProcessor): class PipelineTask(BaseTask): """Manages the execution of a pipeline, handling frame processing and task lifecycle. + It has a couple of event handlers `on_frame_reached_upstream` and + `on_frame_reached_downstream` that are called when upstream frames or + downstream frames reach both ends of pipeline. + + @task.event_handler("on_frame_reached_upstream") + async def on_frame_reached_upstream(task, frame): + ... + + @task.event_handler("on_frame_reached_downstream") + async def on_frame_reached_downstream(task, frame): + ... + Args: pipeline: The pipeline to execute. params: Configuration parameters for the pipeline. observers: List of observers for monitoring pipeline execution. clock: Clock implementation for timing operations. check_dangling_tasks: Whether to check for processors' tasks finishing properly. + """ def __init__( @@ -177,6 +190,9 @@ class PipelineTask(BaseTask): self._observer = TaskObserver(observers=observers, task_manager=self._task_manager) + self._register_event_handler("on_frame_reached_upstream") + self._register_event_handler("on_frame_reached_downstream") + @property def params(self) -> PipelineParams: """Returns the pipeline parameters of this task.""" @@ -356,6 +372,9 @@ class PipelineTask(BaseTask): """ while True: frame = await self._up_queue.get() + + await self._call_event_handler("on_frame_reached_upstream", frame) + if isinstance(frame, EndTaskFrame): # Tell the task we should end nicely. await self.queue_frame(EndFrame()) @@ -383,6 +402,9 @@ class PipelineTask(BaseTask): """ while True: frame = await self._down_queue.get() + + await self._call_event_handler("on_frame_reached_downstream", frame) + if isinstance(frame, (EndFrame, StopFrame)): self._pipeline_end_event.set() elif isinstance(frame, HeartbeatFrame): diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 07abef48e..a375812b2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -12,7 +12,7 @@ from pipecat.pipeline.parallel_pipeline import ParallelPipeline from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.filters.identity_filter import IdentityFilter -from pipecat.processors.frame_processor import FrameProcessor +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.tests.utils import HeartbeatsObserver, run_test @@ -94,6 +94,40 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase): await task.run() assert task.has_finished() + async def test_task_event_handlers(self): + upstream_received = False + downstream_received = False + + identity = IdentityFilter() + pipeline = Pipeline([identity]) + task = PipelineTask(pipeline) + task.set_event_loop(asyncio.get_event_loop()) + + @task.event_handler("on_frame_reached_upstream") + async def on_frame_reached_upstream(task, frame): + nonlocal upstream_received + if isinstance(frame, TextFrame) and frame.text == "Hello Upstream!": + upstream_received = True + + @task.event_handler("on_frame_reached_downstream") + async def on_frame_reached_downstream(task, frame): + nonlocal downstream_received + if isinstance(frame, TextFrame) and frame.text == "Hello Downstream!": + downstream_received = True + await identity.push_frame( + TextFrame(text="Hello Upstream!"), FrameDirection.UPSTREAM + ) + + await task.queue_frame(TextFrame(text="Hello Downstream!")) + + try: + await asyncio.wait_for(task.run(), timeout=1.0) + except asyncio.TimeoutError: + pass + + assert upstream_received + assert downstream_received + async def test_task_heartbeats(self): heartbeats_counter = 0