Merge pull request #1866 from pipecat-ai/aleix/base-observers-are-base-objects
BaseObserver: inherit from BaseObject so we can have events
This commit is contained in:
@@ -4,12 +4,13 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing_extensions import TYPE_CHECKING
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
@@ -39,7 +40,7 @@ class FramePushed:
|
||||
timestamp: int
|
||||
|
||||
|
||||
class BaseObserver(ABC):
|
||||
class BaseObserver(BaseObject):
|
||||
"""This is the base class for pipeline frame observers. Observers can view
|
||||
all the frames that go through the pipeline without the need to inject
|
||||
processors in the pipeline. This can be useful, for example, to implement
|
||||
|
||||
@@ -74,6 +74,7 @@ class DebugLogObserver(BaseObserver):
|
||||
Union[Tuple[Type[Frame], ...], Dict[Type[Frame], Optional[Tuple[Type, FrameEndpoint]]]]
|
||||
] = None,
|
||||
exclude_fields: Optional[Set[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the debug log observer.
|
||||
|
||||
@@ -87,6 +88,8 @@ class DebugLogObserver(BaseObserver):
|
||||
exclude_fields: Set of field names to exclude from logging. If None, only binary
|
||||
data fields are excluded.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Process frame filters
|
||||
self.frame_filters = {}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ class TurnTrackingObserver(BaseObserver):
|
||||
"""
|
||||
|
||||
def __init__(self, max_frames=100, turn_end_timeout_secs=2.5, **kwargs):
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self._turn_count = 0
|
||||
self._is_turn_active = False
|
||||
self._is_bot_speaking = False
|
||||
@@ -154,32 +154,3 @@ class TurnTrackingObserver(BaseObserver):
|
||||
status = "interrupted" if was_interrupted else "completed"
|
||||
logger.trace(f"Turn {self._turn_count} {status} after {duration:.2f}s")
|
||||
await self._call_event_handler("on_turn_ended", self._turn_count, duration, was_interrupted)
|
||||
|
||||
def _register_event_handler(self, event_name):
|
||||
"""Register an event handler."""
|
||||
if not hasattr(self, "_event_handlers"):
|
||||
self._event_handlers = {}
|
||||
if event_name not in self._event_handlers:
|
||||
self._event_handlers[event_name] = []
|
||||
|
||||
async def _call_event_handler(self, event_name, *args, **kwargs):
|
||||
"""Call registered event handlers."""
|
||||
if not hasattr(self, "_event_handlers"):
|
||||
return
|
||||
|
||||
if event_name in self._event_handlers:
|
||||
for handler in self._event_handlers[event_name]:
|
||||
await handler(self, *args, **kwargs)
|
||||
|
||||
def event_handler(self, event_name):
|
||||
"""Decorator for registering event handlers."""
|
||||
|
||||
def decorator(func):
|
||||
if not hasattr(self, "_event_handlers"):
|
||||
self._event_handlers = {}
|
||||
if event_name not in self._event_handlers:
|
||||
self._event_handlers[event_name] = []
|
||||
self._event_handlers[event_name].append(func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -40,8 +40,13 @@ class TaskObserver(BaseObserver):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *, observers: Optional[List[BaseObserver]] = None, task_manager: BaseTaskManager
|
||||
self,
|
||||
*,
|
||||
observers: Optional[List[BaseObserver]] = None,
|
||||
task_manager: BaseTaskManager,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._observers = observers or []
|
||||
self._task_manager = task_manager
|
||||
self._proxies: Dict[BaseObserver, Proxy] = {}
|
||||
@@ -78,7 +83,7 @@ class TaskObserver(BaseObserver):
|
||||
queue = asyncio.Queue()
|
||||
task = self._task_manager.create_task(
|
||||
self._proxy_task_handler(queue, observer),
|
||||
f"TaskObserver::{observer.__class__.__name__}::_proxy_task_handler",
|
||||
f"TaskObserver::{observer}::_proxy_task_handler",
|
||||
)
|
||||
proxy = Proxy(queue=queue, task=task, observer=observer)
|
||||
return proxy
|
||||
|
||||
@@ -437,8 +437,10 @@ class RTVIObserver(BaseObserver):
|
||||
params (RTVIObserverParams): Settings to enable/disable specific messages.
|
||||
"""
|
||||
|
||||
def __init__(self, rtvi: "RTVIProcessor", *, params: Optional[RTVIObserverParams] = None):
|
||||
super().__init__()
|
||||
def __init__(
|
||||
self, rtvi: "RTVIProcessor", *, params: Optional[RTVIObserverParams] = None, **kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._rtvi = rtvi
|
||||
self._params = params or RTVIObserverParams()
|
||||
self._bot_transcription = ""
|
||||
|
||||
@@ -38,7 +38,9 @@ class HeartbeatsObserver(BaseObserver):
|
||||
*,
|
||||
target: FrameProcessor,
|
||||
heartbeat_callback: Callable[[FrameProcessor, HeartbeatFrame], Awaitable[None]],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._target = target
|
||||
self._callback = heartbeat_callback
|
||||
|
||||
|
||||
@@ -34,8 +34,10 @@ class TurnTraceObserver(BaseObserver):
|
||||
conversation span that encapsulates the entire session.
|
||||
"""
|
||||
|
||||
def __init__(self, turn_tracker: TurnTrackingObserver, conversation_id: Optional[str] = None):
|
||||
super().__init__()
|
||||
def __init__(
|
||||
self, turn_tracker: TurnTrackingObserver, conversation_id: Optional[str] = None, **kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._turn_tracker = turn_tracker
|
||||
self._current_span: Optional["Span"] = None
|
||||
self._current_turn_number: int = 0
|
||||
|
||||
Reference in New Issue
Block a user