PipelineTask: add on_frame_reached_upstream/on_frame_reached_downstream
This commit is contained in:
@@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added `on_frame_reached_upstream` and `on_frame_reached_downstream` event
|
||||
handlers to `PipelineTask`. Those events will be called when a frame reaches
|
||||
the beginning or end of the pipeline respectively.
|
||||
|
||||
- Added support for Chirp voices in `GoogleTTSService`.
|
||||
|
||||
- Added a `flush_audio()` method to `FishTTSService` and `LmntTTSService`.
|
||||
@@ -80,7 +84,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue in `RimeTTSService` where the last line of text sent didn't result in an audio output being generated.
|
||||
- Fixed an issue in `RimeTTSService` where the last line of text sent didn't
|
||||
result in an audio output being generated.
|
||||
|
||||
## [0.0.58] - 2025-02-26
|
||||
|
||||
|
||||
@@ -119,12 +119,25 @@ class PipelineTaskSink(FrameProcessor):
|
||||
class PipelineTask(BaseTask):
|
||||
"""Manages the execution of a pipeline, handling frame processing and task lifecycle.
|
||||
|
||||
It has a couple of event handlers `on_frame_reached_upstream` and
|
||||
`on_frame_reached_downstream` that are called when upstream frames or
|
||||
downstream frames reach both ends of pipeline.
|
||||
|
||||
@task.event_handler("on_frame_reached_upstream")
|
||||
async def on_frame_reached_upstream(task, frame):
|
||||
...
|
||||
|
||||
@task.event_handler("on_frame_reached_downstream")
|
||||
async def on_frame_reached_downstream(task, frame):
|
||||
...
|
||||
|
||||
Args:
|
||||
pipeline: The pipeline to execute.
|
||||
params: Configuration parameters for the pipeline.
|
||||
observers: List of observers for monitoring pipeline execution.
|
||||
clock: Clock implementation for timing operations.
|
||||
check_dangling_tasks: Whether to check for processors' tasks finishing properly.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -177,6 +190,9 @@ class PipelineTask(BaseTask):
|
||||
|
||||
self._observer = TaskObserver(observers=observers, task_manager=self._task_manager)
|
||||
|
||||
self._register_event_handler("on_frame_reached_upstream")
|
||||
self._register_event_handler("on_frame_reached_downstream")
|
||||
|
||||
@property
|
||||
def params(self) -> PipelineParams:
|
||||
"""Returns the pipeline parameters of this task."""
|
||||
@@ -356,6 +372,9 @@ class PipelineTask(BaseTask):
|
||||
"""
|
||||
while True:
|
||||
frame = await self._up_queue.get()
|
||||
|
||||
await self._call_event_handler("on_frame_reached_upstream", frame)
|
||||
|
||||
if isinstance(frame, EndTaskFrame):
|
||||
# Tell the task we should end nicely.
|
||||
await self.queue_frame(EndFrame())
|
||||
@@ -383,6 +402,9 @@ class PipelineTask(BaseTask):
|
||||
"""
|
||||
while True:
|
||||
frame = await self._down_queue.get()
|
||||
|
||||
await self._call_event_handler("on_frame_reached_downstream", frame)
|
||||
|
||||
if isinstance(frame, (EndFrame, StopFrame)):
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, HeartbeatFrame):
|
||||
|
||||
@@ -12,7 +12,7 @@ from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.filters.identity_filter import IdentityFilter
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests.utils import HeartbeatsObserver, run_test
|
||||
|
||||
|
||||
@@ -94,6 +94,40 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
await task.run()
|
||||
assert task.has_finished()
|
||||
|
||||
async def test_task_event_handlers(self):
|
||||
upstream_received = False
|
||||
downstream_received = False
|
||||
|
||||
identity = IdentityFilter()
|
||||
pipeline = Pipeline([identity])
|
||||
task = PipelineTask(pipeline)
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
|
||||
@task.event_handler("on_frame_reached_upstream")
|
||||
async def on_frame_reached_upstream(task, frame):
|
||||
nonlocal upstream_received
|
||||
if isinstance(frame, TextFrame) and frame.text == "Hello Upstream!":
|
||||
upstream_received = True
|
||||
|
||||
@task.event_handler("on_frame_reached_downstream")
|
||||
async def on_frame_reached_downstream(task, frame):
|
||||
nonlocal downstream_received
|
||||
if isinstance(frame, TextFrame) and frame.text == "Hello Downstream!":
|
||||
downstream_received = True
|
||||
await identity.push_frame(
|
||||
TextFrame(text="Hello Upstream!"), FrameDirection.UPSTREAM
|
||||
)
|
||||
|
||||
await task.queue_frame(TextFrame(text="Hello Downstream!"))
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(task.run(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
assert upstream_received
|
||||
assert downstream_received
|
||||
|
||||
async def test_task_heartbeats(self):
|
||||
heartbeats_counter = 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user