PipelineTask: add add_observer() and remove_observer()
This commit is contained in:
@@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added `PipelineTask.add_observer()` and `PipelineTask.remove_observer()` to
|
||||
allow mangaging observers at runtime. This is useful for cases where the task
|
||||
is passed around to other code components that might want to observe the
|
||||
pipeline dynamically.
|
||||
|
||||
- Added `user_id` field to `TranscriptionMessage`. This allows identifying the
|
||||
user in a multi-user scenario. Note that this requires that
|
||||
`TranscriptionFrame` has the `user_id` properly set.
|
||||
|
||||
@@ -225,14 +225,16 @@ class PipelineTask(BaseTask):
|
||||
)
|
||||
observers = self._params.observers
|
||||
observers = observers or []
|
||||
self._turn_tracking_observer: Optional[TurnTrackingObserver] = None
|
||||
self._turn_trace_observer: Optional[TurnTraceObserver] = None
|
||||
if self._enable_turn_tracking:
|
||||
self._turn_tracking_observer = TurnTrackingObserver()
|
||||
observers = [self._turn_tracking_observer] + list(observers)
|
||||
if self._enable_turn_tracking and self._enable_tracing:
|
||||
observers.append(self._turn_tracking_observer)
|
||||
if self._enable_tracing and self._turn_tracking_observer:
|
||||
self._turn_trace_observer = TurnTraceObserver(
|
||||
self._turn_tracking_observer, conversation_id=self._conversation_id
|
||||
)
|
||||
observers = [self._turn_trace_observer] + list(observers)
|
||||
observers.append(self._turn_trace_observer)
|
||||
self._finished = False
|
||||
|
||||
# This queue receives frames coming from the pipeline upstream.
|
||||
@@ -297,12 +299,18 @@ class PipelineTask(BaseTask):
|
||||
@property
|
||||
def turn_tracking_observer(self) -> Optional[TurnTrackingObserver]:
|
||||
"""Return the turn tracking observer if enabled."""
|
||||
return getattr(self, "_turn_tracking_observer", None)
|
||||
return self._turn_tracking_observer
|
||||
|
||||
@property
|
||||
def turn_trace_observer(self) -> Optional[TurnTraceObserver]:
|
||||
"""Return the turn trace observer if enabled."""
|
||||
return getattr(self, "_turn_trace_observer", None)
|
||||
return self._turn_trace_observer
|
||||
|
||||
async def add_observer(self, observer: BaseObserver):
|
||||
await self._observer.add_observer(observer)
|
||||
|
||||
async def remove_observer(self, observer: BaseObserver):
|
||||
await self._observer.remove_observer(observer)
|
||||
|
||||
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
|
||||
self._task_manager.set_event_loop(loop)
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from attr import dataclass
|
||||
|
||||
@@ -44,7 +44,22 @@ class TaskObserver(BaseObserver):
|
||||
):
|
||||
self._observers = observers or []
|
||||
self._task_manager = task_manager
|
||||
self._proxies: List[Proxy] = []
|
||||
self._proxies: Dict[BaseObserver, Proxy] = {}
|
||||
|
||||
async def add_observer(self, observer: BaseObserver):
|
||||
proxy = self._create_proxy(observer)
|
||||
self._proxies[observer] = proxy
|
||||
self._observers.append(observer)
|
||||
|
||||
async def remove_observer(self, observer: BaseObserver):
|
||||
if observer in self._proxies:
|
||||
proxy = self._proxies[observer]
|
||||
# Remove the proxy so it doesn't get called anymore.
|
||||
del self._proxies[observer]
|
||||
# Cancel the proxy task right away.
|
||||
await self._task_manager.cancel_task(proxy.task)
|
||||
# Remove the observer.
|
||||
self._observers.remove(observer)
|
||||
|
||||
async def start(self):
|
||||
"""Starts all proxy observer tasks."""
|
||||
@@ -52,23 +67,27 @@ class TaskObserver(BaseObserver):
|
||||
|
||||
async def stop(self):
|
||||
"""Stops all proxy observer tasks."""
|
||||
for proxy in self._proxies:
|
||||
for proxy in self._proxies.values():
|
||||
await self._task_manager.cancel_task(proxy.task)
|
||||
|
||||
async def on_push_frame(self, data: FramePushed):
|
||||
for proxy in self._proxies:
|
||||
for proxy in self._proxies.values():
|
||||
await proxy.queue.put(data)
|
||||
|
||||
def _create_proxies(self, observers) -> List[Proxy]:
|
||||
proxies = []
|
||||
def _create_proxy(self, observer: BaseObserver) -> Proxy:
|
||||
queue = asyncio.Queue()
|
||||
task = self._task_manager.create_task(
|
||||
self._proxy_task_handler(queue, observer),
|
||||
f"TaskObserver::{observer.__class__.__name__}::_proxy_task_handler",
|
||||
)
|
||||
proxy = Proxy(queue=queue, task=task, observer=observer)
|
||||
return proxy
|
||||
|
||||
def _create_proxies(self, observers: List[BaseObserver]) -> Dict[BaseObserver, Proxy]:
|
||||
proxies = {}
|
||||
for observer in observers:
|
||||
queue = asyncio.Queue()
|
||||
task = self._task_manager.create_task(
|
||||
self._proxy_task_handler(queue, observer),
|
||||
f"TaskObserver::{observer.__class__.__name__}::_proxy_task_handler",
|
||||
)
|
||||
proxy = Proxy(queue=queue, task=task, observer=observer)
|
||||
proxies.append(proxy)
|
||||
proxy = self._create_proxy(observer)
|
||||
proxies[observer] = proxy
|
||||
return proxies
|
||||
|
||||
async def _proxy_task_handler(self, queue: asyncio.Queue, observer: BaseObserver):
|
||||
|
||||
@@ -8,13 +8,8 @@ import asyncio
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
HeartbeatFrame,
|
||||
StartFrame,
|
||||
StopFrame,
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.frames.frames import EndFrame, HeartbeatFrame, StartFrame, StopFrame, TextFrame
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -101,6 +96,70 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
await task.run()
|
||||
assert task.has_finished()
|
||||
|
||||
async def test_task_observers(self):
|
||||
frame_received = False
|
||||
|
||||
class CustomObserver(BaseObserver):
|
||||
async def on_push_frame(self, data: FramePushed):
|
||||
nonlocal frame_received
|
||||
|
||||
if isinstance(data.frame, TextFrame):
|
||||
frame_received = True
|
||||
|
||||
identity = IdentityFilter()
|
||||
pipeline = Pipeline([identity])
|
||||
task = PipelineTask(pipeline, observers=[CustomObserver()])
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
|
||||
await task.queue_frames([TextFrame(text="Hello Downstream!"), EndFrame()])
|
||||
await task.run()
|
||||
assert frame_received
|
||||
|
||||
async def test_task_add_observer(self):
|
||||
frame_received = False
|
||||
frame_add_count = 0
|
||||
|
||||
class CustomObserver(BaseObserver):
|
||||
async def on_push_frame(self, data: FramePushed):
|
||||
nonlocal frame_received
|
||||
|
||||
if isinstance(data.frame, TextFrame):
|
||||
frame_received = True
|
||||
|
||||
class CustomAddObserver(BaseObserver):
|
||||
async def on_push_frame(self, data: FramePushed):
|
||||
nonlocal frame_add_count
|
||||
|
||||
if isinstance(data.source, IdentityFilter) and isinstance(data.frame, TextFrame):
|
||||
frame_add_count += 1
|
||||
|
||||
identity = IdentityFilter()
|
||||
pipeline = Pipeline([identity])
|
||||
task = PipelineTask(pipeline, observers=[CustomObserver()])
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
|
||||
async def delayed_add_observer():
|
||||
observer = CustomAddObserver()
|
||||
# Wait after the pipeline is started and add an observer.
|
||||
await asyncio.sleep(0.1)
|
||||
await task.add_observer(observer)
|
||||
# Push a TextFrame and wait for the observer to pick it up.
|
||||
await task.queue_frame(TextFrame(text="Hello Downstream!"))
|
||||
await asyncio.sleep(0.1)
|
||||
# Remove the observer
|
||||
await task.remove_observer(observer)
|
||||
# Push another TextFrame. This time the counter should not
|
||||
# increments since we have removed the observer.
|
||||
await task.queue_frame(TextFrame(text="Hello Downstream!"))
|
||||
await asyncio.sleep(0.1)
|
||||
# Finally end the pipeline.
|
||||
await task.queue_frame(EndFrame())
|
||||
|
||||
await asyncio.gather(task.run(), delayed_add_observer())
|
||||
|
||||
assert frame_received
|
||||
assert frame_add_count == 1
|
||||
|
||||
async def test_task_started_ended_event_handler(self):
|
||||
start_received = False
|
||||
end_received = False
|
||||
|
||||
Reference in New Issue
Block a user