From d333094149c7d5f955d2d4e00dd4289d2f865a7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Sat, 17 May 2025 17:19:15 -0700 Subject: [PATCH] PipelineTask: add add_observer() and remove_observer() --- CHANGELOG.md | 5 ++ src/pipecat/pipeline/task.py | 18 +++++-- src/pipecat/pipeline/task_observer.py | 45 ++++++++++++----- tests/test_pipeline.py | 73 ++++++++++++++++++++++++--- 4 files changed, 116 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d8db0d65d..6dde0bae5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `PipelineTask.add_observer()` and `PipelineTask.remove_observer()` to + allow managing observers at runtime. This is useful for cases where the task + is passed around to other code components that might want to observe the + pipeline dynamically. + - Added `user_id` field to `TranscriptionMessage`. This allows identifying the user in a multi-user scenario. Note that this requires that `TranscriptionFrame` has the `user_id` properly set. 
diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index 6deee59f8..e47c0f849 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -225,14 +225,16 @@ class PipelineTask(BaseTask): ) observers = self._params.observers observers = observers or [] + self._turn_tracking_observer: Optional[TurnTrackingObserver] = None + self._turn_trace_observer: Optional[TurnTraceObserver] = None if self._enable_turn_tracking: self._turn_tracking_observer = TurnTrackingObserver() - observers = [self._turn_tracking_observer] + list(observers) - if self._enable_turn_tracking and self._enable_tracing: + observers.append(self._turn_tracking_observer) + if self._enable_tracing and self._turn_tracking_observer: self._turn_trace_observer = TurnTraceObserver( self._turn_tracking_observer, conversation_id=self._conversation_id ) - observers = [self._turn_trace_observer] + list(observers) + observers.append(self._turn_trace_observer) self._finished = False # This queue receives frames coming from the pipeline upstream. 
@@ -297,12 +299,18 @@ class PipelineTask(BaseTask): @property def turn_tracking_observer(self) -> Optional[TurnTrackingObserver]: """Return the turn tracking observer if enabled.""" - return getattr(self, "_turn_tracking_observer", None) + return self._turn_tracking_observer @property def turn_trace_observer(self) -> Optional[TurnTraceObserver]: """Return the turn trace observer if enabled.""" - return getattr(self, "_turn_trace_observer", None) + return self._turn_trace_observer + + async def add_observer(self, observer: BaseObserver): + await self._observer.add_observer(observer) + + async def remove_observer(self, observer: BaseObserver): + await self._observer.remove_observer(observer) def set_event_loop(self, loop: asyncio.AbstractEventLoop): self._task_manager.set_event_loop(loop) diff --git a/src/pipecat/pipeline/task_observer.py b/src/pipecat/pipeline/task_observer.py index 127ea3466..c4d3d9a26 100644 --- a/src/pipecat/pipeline/task_observer.py +++ b/src/pipecat/pipeline/task_observer.py @@ -6,7 +6,7 @@ import asyncio import inspect -from typing import List, Optional +from typing import Dict, List, Optional from attr import dataclass @@ -44,7 +44,22 @@ class TaskObserver(BaseObserver): ): self._observers = observers or [] self._task_manager = task_manager - self._proxies: List[Proxy] = [] + self._proxies: Dict[BaseObserver, Proxy] = {} + + async def add_observer(self, observer: BaseObserver): + proxy = self._create_proxy(observer) + self._proxies[observer] = proxy + self._observers.append(observer) + + async def remove_observer(self, observer: BaseObserver): + if observer in self._proxies: + proxy = self._proxies[observer] + # Remove the proxy so it doesn't get called anymore. + del self._proxies[observer] + # Cancel the proxy task right away. + await self._task_manager.cancel_task(proxy.task) + # Remove the observer. 
+ self._observers.remove(observer) async def start(self): """Starts all proxy observer tasks.""" @@ -52,23 +67,27 @@ class TaskObserver(BaseObserver): async def stop(self): """Stops all proxy observer tasks.""" - for proxy in self._proxies: + for proxy in self._proxies.values(): await self._task_manager.cancel_task(proxy.task) async def on_push_frame(self, data: FramePushed): - for proxy in self._proxies: + for proxy in self._proxies.values(): await proxy.queue.put(data) - def _create_proxies(self, observers) -> List[Proxy]: - proxies = [] + def _create_proxy(self, observer: BaseObserver) -> Proxy: + queue = asyncio.Queue() + task = self._task_manager.create_task( + self._proxy_task_handler(queue, observer), + f"TaskObserver::{observer.__class__.__name__}::_proxy_task_handler", + ) + proxy = Proxy(queue=queue, task=task, observer=observer) + return proxy + + def _create_proxies(self, observers: List[BaseObserver]) -> Dict[BaseObserver, Proxy]: + proxies = {} for observer in observers: - queue = asyncio.Queue() - task = self._task_manager.create_task( - self._proxy_task_handler(queue, observer), - f"TaskObserver::{observer.__class__.__name__}::_proxy_task_handler", - ) - proxy = Proxy(queue=queue, task=task, observer=observer) - proxies.append(proxy) + proxy = self._create_proxy(observer) + proxies[observer] = proxy return proxies async def _proxy_task_handler(self, queue: asyncio.Queue, observer: BaseObserver): diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index e02f4012d..9bc737b1a 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -8,13 +8,8 @@ import asyncio import time import unittest -from pipecat.frames.frames import ( - EndFrame, - HeartbeatFrame, - StartFrame, - StopFrame, - TextFrame, -) +from pipecat.frames.frames import EndFrame, HeartbeatFrame, StartFrame, StopFrame, TextFrame +from pipecat.observers.base_observer import BaseObserver, FramePushed from pipecat.pipeline.parallel_pipeline import ParallelPipeline from 
pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.task import PipelineParams, PipelineTask @@ -101,6 +96,70 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase): await task.run() assert task.has_finished() + async def test_task_observers(self): + frame_received = False + + class CustomObserver(BaseObserver): + async def on_push_frame(self, data: FramePushed): + nonlocal frame_received + + if isinstance(data.frame, TextFrame): + frame_received = True + + identity = IdentityFilter() + pipeline = Pipeline([identity]) + task = PipelineTask(pipeline, observers=[CustomObserver()]) + task.set_event_loop(asyncio.get_event_loop()) + + await task.queue_frames([TextFrame(text="Hello Downstream!"), EndFrame()]) + await task.run() + assert frame_received + + async def test_task_add_observer(self): + frame_received = False + frame_add_count = 0 + + class CustomObserver(BaseObserver): + async def on_push_frame(self, data: FramePushed): + nonlocal frame_received + + if isinstance(data.frame, TextFrame): + frame_received = True + + class CustomAddObserver(BaseObserver): + async def on_push_frame(self, data: FramePushed): + nonlocal frame_add_count + + if isinstance(data.source, IdentityFilter) and isinstance(data.frame, TextFrame): + frame_add_count += 1 + + identity = IdentityFilter() + pipeline = Pipeline([identity]) + task = PipelineTask(pipeline, observers=[CustomObserver()]) + task.set_event_loop(asyncio.get_event_loop()) + + async def delayed_add_observer(): + observer = CustomAddObserver() + # Wait after the pipeline is started and add an observer. + await asyncio.sleep(0.1) + await task.add_observer(observer) + # Push a TextFrame and wait for the observer to pick it up. + await task.queue_frame(TextFrame(text="Hello Downstream!")) + await asyncio.sleep(0.1) + # Remove the observer + await task.remove_observer(observer) + # Push another TextFrame. This time the counter should not + # increment since we have removed the observer. 
+ await task.queue_frame(TextFrame(text="Hello Downstream!")) + await asyncio.sleep(0.1) + # Finally end the pipeline. + await task.queue_frame(EndFrame()) + + await asyncio.gather(task.run(), delayed_add_observer()) + + assert frame_received + assert frame_add_count == 1 + async def test_task_started_ended_event_handler(self): start_received = False end_received = False