Merge pull request #1377 from pipecat-ai/aleix/task-upstream-downstream-filters

PipelineTask: only call event handlers if a filter is matched
This commit is contained in:
Aleix Conchillo Flaqué
2025-03-14 08:49:24 -07:00
committed by GitHub
3 changed files with 47 additions and 5 deletions

View File

@@ -14,7 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `on_frame_reached_upstream` and `on_frame_reached_downstream` event
handlers to `PipelineTask`. Those events will be called when a frame reaches
the beginning or end of the pipeline respectively.
the beginning or end of the pipeline respectively. Note that by default, the
event handlers will not be called unless a filter is set with
`PipelineTask.set_reached_upstream_filter()` or
`PipelineTask.set_reached_downstream_filter()`.
- Added support for Chirp voices in `GoogleTTSService`.

View File

@@ -5,7 +5,7 @@
#
import asyncio
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Tuple, Type
from loguru import logger
from pydantic import BaseModel, ConfigDict
@@ -121,7 +121,9 @@ class PipelineTask(BaseTask):
It has a couple of event handlers `on_frame_reached_upstream` and
`on_frame_reached_downstream` that are called when upstream frames or
downstream frames reach both ends of pipeline.
downstream frames reach both ends of pipeline. By default, the events
handlers will not be called unless some filters are set using
`set_reached_upstream_filter` and `set_reached_downstream_filter`.
@task.event_handler("on_frame_reached_upstream")
async def on_frame_reached_upstream(task, frame):
@@ -180,16 +182,35 @@ class PipelineTask(BaseTask):
# StopFrame) has been received in the down queue.
self._pipeline_end_event = asyncio.Event()
# This is a source processor that we connect to the provided
# pipeline. This source processor allows up to receive and react to
# upstream frames.
self._source = PipelineTaskSource(self._up_queue)
self._source.link(pipeline)
# This is a sink processor that we connect to the provided
# pipeline. This sink processor allows up to receive and react to
# downstream frames.
self._sink = PipelineTaskSink(self._down_queue)
pipeline.link(self._sink)
# This task maneger will handle all the asyncio tasks created by this
# PipelineTask and its frame processors.
self._task_manager = task_manager or TaskManager()
# The task observer acts as a proxy to the provided observers. This way,
# we only need to pass a single observer (using the StartFrame) which
# then just acts as a proxy.
self._observer = TaskObserver(observers=observers, task_manager=self._task_manager)
# These events can be used to check which frames make it to the source
# or sink processors. Instead of calling the event handlers for every
# frame the user needs to specify which events they are interested
# in. This is mainly for efficiency reason because each event handler
# creates a task and most likely you only care about one or two frame
# types.
self._reached_upstream_types: Tuple[Type[Frame], ...] = ()
self._reached_downstream_types: Tuple[Type[Frame], ...] = ()
self._register_event_handler("on_frame_reached_upstream")
self._register_event_handler("on_frame_reached_downstream")
@@ -201,6 +222,20 @@ class PipelineTask(BaseTask):
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
self._task_manager.set_event_loop(loop)
def set_reached_upstream_filter(self, types: Tuple[Type[Frame], ...]):
"""Sets which frames will be checked before calling the
on_frame_reached_upstream event handler.
"""
self._reached_upstream_types = types
def set_reached_downstream_filter(self, types: Tuple[Type[Frame], ...]):
"""Sets which frames will be checked before calling the
on_frame_reached_downstream event handler.
"""
self._reached_downstream_types = types
def has_finished(self) -> bool:
"""Indicates whether the tasks has finished. That is, all processors
have stopped.
@@ -373,7 +408,8 @@ class PipelineTask(BaseTask):
while True:
frame = await self._up_queue.get()
await self._call_event_handler("on_frame_reached_upstream", frame)
if isinstance(frame, self._reached_upstream_types):
await self._call_event_handler("on_frame_reached_upstream", frame)
if isinstance(frame, EndTaskFrame):
# Tell the task we should end nicely.
@@ -403,7 +439,8 @@ class PipelineTask(BaseTask):
while True:
frame = await self._down_queue.get()
await self._call_event_handler("on_frame_reached_downstream", frame)
if isinstance(frame, self._reached_downstream_types):
await self._call_event_handler("on_frame_reached_downstream", frame)
if isinstance(frame, (EndFrame, StopFrame)):
self._pipeline_end_event.set()

View File

@@ -102,6 +102,8 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
pipeline = Pipeline([identity])
task = PipelineTask(pipeline)
task.set_event_loop(asyncio.get_event_loop())
task.set_reached_upstream_filter((TextFrame,))
task.set_reached_downstream_filter((TextFrame,))
@task.event_handler("on_frame_reached_upstream")
async def on_frame_reached_upstream(task, frame):