Merge pull request #1829 from pipecat-ai/aleix/pipeline-task-add-observer

PipelineTask: add add_observer()
This commit is contained in:
Aleix Conchillo Flaqué
2025-05-20 13:18:24 -07:00
committed by GitHub
4 changed files with 116 additions and 25 deletions

View File

@@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added `PipelineTask.add_observer()` and `PipelineTask.remove_observer()` to
allow mangaging observers at runtime. This is useful for cases where the task
is passed around to other code components that might want to observe the
pipeline dynamically.
- Added `user_id` field to `TranscriptionMessage`. This allows identifying the
user in a multi-user scenario. Note that this requires that
`TranscriptionFrame` has the `user_id` properly set.

View File

@@ -225,14 +225,16 @@ class PipelineTask(BaseTask):
)
observers = self._params.observers
observers = observers or []
self._turn_tracking_observer: Optional[TurnTrackingObserver] = None
self._turn_trace_observer: Optional[TurnTraceObserver] = None
if self._enable_turn_tracking:
self._turn_tracking_observer = TurnTrackingObserver()
observers = [self._turn_tracking_observer] + list(observers)
if self._enable_turn_tracking and self._enable_tracing:
observers.append(self._turn_tracking_observer)
if self._enable_tracing and self._turn_tracking_observer:
self._turn_trace_observer = TurnTraceObserver(
self._turn_tracking_observer, conversation_id=self._conversation_id
)
observers = [self._turn_trace_observer] + list(observers)
observers.append(self._turn_trace_observer)
self._finished = False
# This queue receives frames coming from the pipeline upstream.
@@ -297,12 +299,18 @@ class PipelineTask(BaseTask):
@property
def turn_tracking_observer(self) -> Optional[TurnTrackingObserver]:
"""Return the turn tracking observer if enabled."""
return getattr(self, "_turn_tracking_observer", None)
return self._turn_tracking_observer
@property
def turn_trace_observer(self) -> Optional[TurnTraceObserver]:
"""Return the turn trace observer if enabled."""
return getattr(self, "_turn_trace_observer", None)
return self._turn_trace_observer
async def add_observer(self, observer: BaseObserver):
await self._observer.add_observer(observer)
async def remove_observer(self, observer: BaseObserver):
await self._observer.remove_observer(observer)
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
self._task_manager.set_event_loop(loop)

View File

@@ -6,7 +6,7 @@
import asyncio
import inspect
from typing import List, Optional
from typing import Dict, List, Optional
from attr import dataclass
@@ -44,7 +44,22 @@ class TaskObserver(BaseObserver):
):
self._observers = observers or []
self._task_manager = task_manager
self._proxies: List[Proxy] = []
self._proxies: Dict[BaseObserver, Proxy] = {}
async def add_observer(self, observer: BaseObserver):
proxy = self._create_proxy(observer)
self._proxies[observer] = proxy
self._observers.append(observer)
async def remove_observer(self, observer: BaseObserver):
if observer in self._proxies:
proxy = self._proxies[observer]
# Remove the proxy so it doesn't get called anymore.
del self._proxies[observer]
# Cancel the proxy task right away.
await self._task_manager.cancel_task(proxy.task)
# Remove the observer.
self._observers.remove(observer)
async def start(self):
"""Starts all proxy observer tasks."""
@@ -52,23 +67,27 @@ class TaskObserver(BaseObserver):
async def stop(self):
"""Stops all proxy observer tasks."""
for proxy in self._proxies:
for proxy in self._proxies.values():
await self._task_manager.cancel_task(proxy.task)
async def on_push_frame(self, data: FramePushed):
for proxy in self._proxies:
for proxy in self._proxies.values():
await proxy.queue.put(data)
def _create_proxies(self, observers) -> List[Proxy]:
proxies = []
def _create_proxy(self, observer: BaseObserver) -> Proxy:
queue = asyncio.Queue()
task = self._task_manager.create_task(
self._proxy_task_handler(queue, observer),
f"TaskObserver::{observer.__class__.__name__}::_proxy_task_handler",
)
proxy = Proxy(queue=queue, task=task, observer=observer)
return proxy
def _create_proxies(self, observers: List[BaseObserver]) -> Dict[BaseObserver, Proxy]:
proxies = {}
for observer in observers:
queue = asyncio.Queue()
task = self._task_manager.create_task(
self._proxy_task_handler(queue, observer),
f"TaskObserver::{observer.__class__.__name__}::_proxy_task_handler",
)
proxy = Proxy(queue=queue, task=task, observer=observer)
proxies.append(proxy)
proxy = self._create_proxy(observer)
proxies[observer] = proxy
return proxies
async def _proxy_task_handler(self, queue: asyncio.Queue, observer: BaseObserver):

View File

@@ -8,13 +8,8 @@ import asyncio
import time
import unittest
from pipecat.frames.frames import (
EndFrame,
HeartbeatFrame,
StartFrame,
StopFrame,
TextFrame,
)
from pipecat.frames.frames import EndFrame, HeartbeatFrame, StartFrame, StopFrame, TextFrame
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams, PipelineTask
@@ -101,6 +96,70 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
await task.run()
assert task.has_finished()
async def test_task_observers(self):
frame_received = False
class CustomObserver(BaseObserver):
async def on_push_frame(self, data: FramePushed):
nonlocal frame_received
if isinstance(data.frame, TextFrame):
frame_received = True
identity = IdentityFilter()
pipeline = Pipeline([identity])
task = PipelineTask(pipeline, observers=[CustomObserver()])
task.set_event_loop(asyncio.get_event_loop())
await task.queue_frames([TextFrame(text="Hello Downstream!"), EndFrame()])
await task.run()
assert frame_received
async def test_task_add_observer(self):
frame_received = False
frame_add_count = 0
class CustomObserver(BaseObserver):
async def on_push_frame(self, data: FramePushed):
nonlocal frame_received
if isinstance(data.frame, TextFrame):
frame_received = True
class CustomAddObserver(BaseObserver):
async def on_push_frame(self, data: FramePushed):
nonlocal frame_add_count
if isinstance(data.source, IdentityFilter) and isinstance(data.frame, TextFrame):
frame_add_count += 1
identity = IdentityFilter()
pipeline = Pipeline([identity])
task = PipelineTask(pipeline, observers=[CustomObserver()])
task.set_event_loop(asyncio.get_event_loop())
async def delayed_add_observer():
observer = CustomAddObserver()
# Wait after the pipeline is started and add an observer.
await asyncio.sleep(0.1)
await task.add_observer(observer)
# Push a TextFrame and wait for the observer to pick it up.
await task.queue_frame(TextFrame(text="Hello Downstream!"))
await asyncio.sleep(0.1)
# Remove the observer
await task.remove_observer(observer)
# Push another TextFrame. This time the counter should not
# increments since we have removed the observer.
await task.queue_frame(TextFrame(text="Hello Downstream!"))
await asyncio.sleep(0.1)
# Finally end the pipeline.
await task.queue_frame(EndFrame())
await asyncio.gather(task.run(), delayed_add_observer())
assert frame_received
assert frame_add_count == 1
async def test_task_started_ended_event_handler(self):
start_received = False
end_received = False