Merge pull request #1377 from pipecat-ai/aleix/task-upstream-downstream-filters
PipelineTask: only call event handlers if a filter is matched
This commit is contained in:
@@ -14,7 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- Added `on_frame_reached_upstream` and `on_frame_reached_downstream` event
|
||||
handlers to `PipelineTask`. Those events will be called when a frame reaches
|
||||
the beginning or end of the pipeline respectively.
|
||||
the beginning or end of the pipeline respectively. Note that by default, the
|
||||
event handlers will not be called unless a filter is set with
|
||||
`PipelineTask.set_reached_upstream_filter()` or
|
||||
`PipelineTask.set_reached_downstream_filter()`.
|
||||
|
||||
- Added support for Chirp voices in `GoogleTTSService`.
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional
|
||||
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Tuple, Type
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
@@ -121,7 +121,9 @@ class PipelineTask(BaseTask):
|
||||
|
||||
It has a couple of event handlers `on_frame_reached_upstream` and
|
||||
`on_frame_reached_downstream` that are called when upstream frames or
|
||||
downstream frames reach both ends of pipeline.
|
||||
downstream frames reach both ends of pipeline. By default, the events
|
||||
handlers will not be called unless some filters are set using
|
||||
`set_reached_upstream_filter` and `set_reached_downstream_filter`.
|
||||
|
||||
@task.event_handler("on_frame_reached_upstream")
|
||||
async def on_frame_reached_upstream(task, frame):
|
||||
@@ -180,16 +182,35 @@ class PipelineTask(BaseTask):
|
||||
# StopFrame) has been received in the down queue.
|
||||
self._pipeline_end_event = asyncio.Event()
|
||||
|
||||
# This is a source processor that we connect to the provided
|
||||
# pipeline. This source processor allows up to receive and react to
|
||||
# upstream frames.
|
||||
self._source = PipelineTaskSource(self._up_queue)
|
||||
self._source.link(pipeline)
|
||||
|
||||
# This is a sink processor that we connect to the provided
|
||||
# pipeline. This sink processor allows up to receive and react to
|
||||
# downstream frames.
|
||||
self._sink = PipelineTaskSink(self._down_queue)
|
||||
pipeline.link(self._sink)
|
||||
|
||||
# This task maneger will handle all the asyncio tasks created by this
|
||||
# PipelineTask and its frame processors.
|
||||
self._task_manager = task_manager or TaskManager()
|
||||
|
||||
# The task observer acts as a proxy to the provided observers. This way,
|
||||
# we only need to pass a single observer (using the StartFrame) which
|
||||
# then just acts as a proxy.
|
||||
self._observer = TaskObserver(observers=observers, task_manager=self._task_manager)
|
||||
|
||||
# These events can be used to check which frames make it to the source
|
||||
# or sink processors. Instead of calling the event handlers for every
|
||||
# frame the user needs to specify which events they are interested
|
||||
# in. This is mainly for efficiency reason because each event handler
|
||||
# creates a task and most likely you only care about one or two frame
|
||||
# types.
|
||||
self._reached_upstream_types: Tuple[Type[Frame], ...] = ()
|
||||
self._reached_downstream_types: Tuple[Type[Frame], ...] = ()
|
||||
self._register_event_handler("on_frame_reached_upstream")
|
||||
self._register_event_handler("on_frame_reached_downstream")
|
||||
|
||||
@@ -201,6 +222,20 @@ class PipelineTask(BaseTask):
|
||||
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
|
||||
self._task_manager.set_event_loop(loop)
|
||||
|
||||
def set_reached_upstream_filter(self, types: Tuple[Type[Frame], ...]):
|
||||
"""Sets which frames will be checked before calling the
|
||||
on_frame_reached_upstream event handler.
|
||||
|
||||
"""
|
||||
self._reached_upstream_types = types
|
||||
|
||||
def set_reached_downstream_filter(self, types: Tuple[Type[Frame], ...]):
|
||||
"""Sets which frames will be checked before calling the
|
||||
on_frame_reached_downstream event handler.
|
||||
|
||||
"""
|
||||
self._reached_downstream_types = types
|
||||
|
||||
def has_finished(self) -> bool:
|
||||
"""Indicates whether the tasks has finished. That is, all processors
|
||||
have stopped.
|
||||
@@ -373,7 +408,8 @@ class PipelineTask(BaseTask):
|
||||
while True:
|
||||
frame = await self._up_queue.get()
|
||||
|
||||
await self._call_event_handler("on_frame_reached_upstream", frame)
|
||||
if isinstance(frame, self._reached_upstream_types):
|
||||
await self._call_event_handler("on_frame_reached_upstream", frame)
|
||||
|
||||
if isinstance(frame, EndTaskFrame):
|
||||
# Tell the task we should end nicely.
|
||||
@@ -403,7 +439,8 @@ class PipelineTask(BaseTask):
|
||||
while True:
|
||||
frame = await self._down_queue.get()
|
||||
|
||||
await self._call_event_handler("on_frame_reached_downstream", frame)
|
||||
if isinstance(frame, self._reached_downstream_types):
|
||||
await self._call_event_handler("on_frame_reached_downstream", frame)
|
||||
|
||||
if isinstance(frame, (EndFrame, StopFrame)):
|
||||
self._pipeline_end_event.set()
|
||||
|
||||
@@ -102,6 +102,8 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
pipeline = Pipeline([identity])
|
||||
task = PipelineTask(pipeline)
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
task.set_reached_upstream_filter((TextFrame,))
|
||||
task.set_reached_downstream_filter((TextFrame,))
|
||||
|
||||
@task.event_handler("on_frame_reached_upstream")
|
||||
async def on_frame_reached_upstream(task, frame):
|
||||
|
||||
Reference in New Issue
Block a user