From 9bbce225ce0129a0efd504e07b2076317a937345 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 21 May 2025 15:16:48 -0700 Subject: [PATCH] BaseObserver: inherit from BaseObject so we can have events --- src/pipecat/observers/base_observer.py | 5 +-- .../observers/loggers/debug_log_observer.py | 3 ++ .../observers/turn_tracking_observer.py | 31 +------------------ src/pipecat/pipeline/task_observer.py | 9 ++++-- src/pipecat/processors/frameworks/rtvi.py | 6 ++-- src/pipecat/tests/utils.py | 2 ++ .../utils/tracing/turn_trace_observer.py | 6 ++-- 7 files changed, 24 insertions(+), 38 deletions(-) diff --git a/src/pipecat/observers/base_observer.py b/src/pipecat/observers/base_observer.py index f1a0c2a1b..077f6986b 100644 --- a/src/pipecat/observers/base_observer.py +++ b/src/pipecat/observers/base_observer.py @@ -4,12 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # -from abc import ABC, abstractmethod +from abc import abstractmethod from dataclasses import dataclass from typing_extensions import TYPE_CHECKING from pipecat.frames.frames import Frame +from pipecat.utils.base_object import BaseObject if TYPE_CHECKING: from pipecat.processors.frame_processor import FrameDirection, FrameProcessor @@ -39,7 +40,7 @@ class FramePushed: timestamp: int -class BaseObserver(ABC): +class BaseObserver(BaseObject): """This is the base class for pipeline frame observers. Observers can view all the frames that go through the pipeline without the need to inject processors in the pipeline. This can be useful, for example, to implement diff --git a/src/pipecat/observers/loggers/debug_log_observer.py b/src/pipecat/observers/loggers/debug_log_observer.py index 18048890e..1b75a3f7a 100644 --- a/src/pipecat/observers/loggers/debug_log_observer.py +++ b/src/pipecat/observers/loggers/debug_log_observer.py @@ -74,6 +74,7 @@ class DebugLogObserver(BaseObserver): Union[Tuple[Type[Frame], ...], Dict[Type[Frame], Optional[Tuple[Type, FrameEndpoint]]]] ] = None, exclude_fields: Optional[Set[str]] = None, + **kwargs, ): """Initialize the debug log observer. @@ -87,6 +88,8 @@ class DebugLogObserver(BaseObserver): exclude_fields: Set of field names to exclude from logging. If None, only binary data fields are excluded. """ + super().__init__(**kwargs) + # Process frame filters self.frame_filters = {} diff --git a/src/pipecat/observers/turn_tracking_observer.py b/src/pipecat/observers/turn_tracking_observer.py index 99abdaff6..956e46b55 100644 --- a/src/pipecat/observers/turn_tracking_observer.py +++ b/src/pipecat/observers/turn_tracking_observer.py @@ -30,7 +30,7 @@ class TurnTrackingObserver(BaseObserver): """ def __init__(self, max_frames=100, turn_end_timeout_secs=2.5, **kwargs): - super().__init__() + super().__init__(**kwargs) self._turn_count = 0 self._is_turn_active = False self._is_bot_speaking = False @@ -154,32 +154,3 @@ class TurnTrackingObserver(BaseObserver): status = "interrupted" if was_interrupted else "completed" logger.trace(f"Turn {self._turn_count} {status} after {duration:.2f}s") await self._call_event_handler("on_turn_ended", self._turn_count, duration, was_interrupted) - - def _register_event_handler(self, event_name): - """Register an event handler.""" - if not hasattr(self, "_event_handlers"): - self._event_handlers = {} - if event_name not in self._event_handlers: - self._event_handlers[event_name] = [] - - async def _call_event_handler(self, event_name, *args, **kwargs): - """Call registered event handlers.""" - if not hasattr(self, "_event_handlers"): - return - - if event_name in self._event_handlers: - for handler in self._event_handlers[event_name]: - await handler(self, *args, **kwargs) - - def event_handler(self, event_name): - """Decorator for registering event handlers.""" - - def decorator(func): - if not hasattr(self, "_event_handlers"): - self._event_handlers = {} - if event_name not in self._event_handlers: - self._event_handlers[event_name] = [] - self._event_handlers[event_name].append(func) - return func - - return decorator diff --git a/src/pipecat/pipeline/task_observer.py b/src/pipecat/pipeline/task_observer.py index c4d3d9a26..b7a54f58d 100644 --- a/src/pipecat/pipeline/task_observer.py +++ b/src/pipecat/pipeline/task_observer.py @@ -40,8 +40,13 @@ class TaskObserver(BaseObserver): """ def __init__( - self, *, observers: Optional[List[BaseObserver]] = None, task_manager: BaseTaskManager + self, + *, + observers: Optional[List[BaseObserver]] = None, + task_manager: BaseTaskManager, + **kwargs, ): + super().__init__(**kwargs) self._observers = observers or [] self._task_manager = task_manager self._proxies: Dict[BaseObserver, Proxy] = {} @@ -78,7 +83,7 @@ class TaskObserver(BaseObserver): queue = asyncio.Queue() task = self._task_manager.create_task( self._proxy_task_handler(queue, observer), - f"TaskObserver::{observer.__class__.__name__}::_proxy_task_handler", + f"TaskObserver::{observer}::_proxy_task_handler", ) proxy = Proxy(queue=queue, task=task, observer=observer) return proxy diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index f759102ae..449bf3e33 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -437,8 +437,10 @@ class RTVIObserver(BaseObserver): params (RTVIObserverParams): Settings to enable/disable specific messages. """ - def __init__(self, rtvi: "RTVIProcessor", *, params: Optional[RTVIObserverParams] = None): - super().__init__() + def __init__( + self, rtvi: "RTVIProcessor", *, params: Optional[RTVIObserverParams] = None, **kwargs + ): + super().__init__(**kwargs) self._rtvi = rtvi self._params = params or RTVIObserverParams() self._bot_transcription = "" diff --git a/src/pipecat/tests/utils.py b/src/pipecat/tests/utils.py index afa5a9949..3ea52bf26 100644 --- a/src/pipecat/tests/utils.py +++ b/src/pipecat/tests/utils.py @@ -38,7 +38,9 @@ class HeartbeatsObserver(BaseObserver): *, target: FrameProcessor, heartbeat_callback: Callable[[FrameProcessor, HeartbeatFrame], Awaitable[None]], + **kwargs, ): + super().__init__(**kwargs) self._target = target self._callback = heartbeat_callback diff --git a/src/pipecat/utils/tracing/turn_trace_observer.py b/src/pipecat/utils/tracing/turn_trace_observer.py index 4036d6b0f..26fc5b38a 100644 --- a/src/pipecat/utils/tracing/turn_trace_observer.py +++ b/src/pipecat/utils/tracing/turn_trace_observer.py @@ -34,8 +34,10 @@ class TurnTraceObserver(BaseObserver): conversation span that encapsulates the entire session. """ - def __init__(self, turn_tracker: TurnTrackingObserver, conversation_id: Optional[str] = None): - super().__init__() + def __init__( + self, turn_tracker: TurnTrackingObserver, conversation_id: Optional[str] = None, **kwargs + ): + super().__init__(**kwargs) self._turn_tracker = turn_tracker self._current_span: Optional["Span"] = None self._current_turn_number: int = 0