Merge pull request #1368 from pipecat-ai/aleix/pipelinetask-frame-event-handlers

PipelineTask: add on_frame_reached_upstream and on_frame_reached_downstream
This commit is contained in:
Aleix Conchillo Flaqué
2025-03-13 10:31:33 -07:00
committed by GitHub
3 changed files with 63 additions and 2 deletions

View File

@@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added `on_frame_reached_upstream` and `on_frame_reached_downstream` event
handlers to `PipelineTask`. Those events will be called when a frame reaches
the beginning or end of the pipeline respectively.
- Added support for Chirp voices in `GoogleTTSService`.
- Added a `flush_audio()` method to `FishTTSService` and `LmntTTSService`.
@@ -80,7 +84,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Fixed an issue in `RimeTTSService` where the last line of text sent didn't result in an audio output being generated.
- Fixed an issue in `RimeTTSService` where the last line of text sent didn't
result in an audio output being generated.
## [0.0.58] - 2025-02-26

View File

@@ -119,12 +119,25 @@ class PipelineTaskSink(FrameProcessor):
class PipelineTask(BaseTask):
"""Manages the execution of a pipeline, handling frame processing and task lifecycle.
It has a couple of event handlers `on_frame_reached_upstream` and
`on_frame_reached_downstream` that are called when upstream frames or
downstream frames reach both ends of pipeline.
@task.event_handler("on_frame_reached_upstream")
async def on_frame_reached_upstream(task, frame):
...
@task.event_handler("on_frame_reached_downstream")
async def on_frame_reached_downstream(task, frame):
...
Args:
pipeline: The pipeline to execute.
params: Configuration parameters for the pipeline.
observers: List of observers for monitoring pipeline execution.
clock: Clock implementation for timing operations.
check_dangling_tasks: Whether to check for processors' tasks finishing properly.
"""
def __init__(
@@ -177,6 +190,9 @@ class PipelineTask(BaseTask):
self._observer = TaskObserver(observers=observers, task_manager=self._task_manager)
self._register_event_handler("on_frame_reached_upstream")
self._register_event_handler("on_frame_reached_downstream")
@property
def params(self) -> PipelineParams:
"""Returns the pipeline parameters of this task."""
@@ -356,6 +372,9 @@ class PipelineTask(BaseTask):
"""
while True:
frame = await self._up_queue.get()
await self._call_event_handler("on_frame_reached_upstream", frame)
if isinstance(frame, EndTaskFrame):
# Tell the task we should end nicely.
await self.queue_frame(EndFrame())
@@ -383,6 +402,9 @@ class PipelineTask(BaseTask):
"""
while True:
frame = await self._down_queue.get()
await self._call_event_handler("on_frame_reached_downstream", frame)
if isinstance(frame, (EndFrame, StopFrame)):
self._pipeline_end_event.set()
elif isinstance(frame, HeartbeatFrame):

View File

@@ -12,7 +12,7 @@ from pipecat.pipeline.parallel_pipeline import ParallelPipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.filters.identity_filter import IdentityFilter
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests.utils import HeartbeatsObserver, run_test
@@ -94,6 +94,40 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
await task.run()
assert task.has_finished()
async def test_task_event_handlers(self):
upstream_received = False
downstream_received = False
identity = IdentityFilter()
pipeline = Pipeline([identity])
task = PipelineTask(pipeline)
task.set_event_loop(asyncio.get_event_loop())
@task.event_handler("on_frame_reached_upstream")
async def on_frame_reached_upstream(task, frame):
nonlocal upstream_received
if isinstance(frame, TextFrame) and frame.text == "Hello Upstream!":
upstream_received = True
@task.event_handler("on_frame_reached_downstream")
async def on_frame_reached_downstream(task, frame):
nonlocal downstream_received
if isinstance(frame, TextFrame) and frame.text == "Hello Downstream!":
downstream_received = True
await identity.push_frame(
TextFrame(text="Hello Upstream!"), FrameDirection.UPSTREAM
)
await task.queue_frame(TextFrame(text="Hello Downstream!"))
try:
await asyncio.wait_for(task.run(), timeout=1.0)
except asyncio.TimeoutError:
pass
assert upstream_received
assert downstream_received
async def test_task_heartbeats(self):
heartbeats_counter = 0