Merge pull request #1866 from pipecat-ai/aleix/base-observers-are-base-objects

BaseObserver: inherit from BaseObject so we can have events
This commit is contained in:
Aleix Conchillo Flaqué
2025-05-21 16:07:38 -07:00
committed by GitHub
7 changed files with 24 additions and 38 deletions

View File

@@ -4,12 +4,13 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
from abc import ABC, abstractmethod
from abc import abstractmethod
from dataclasses import dataclass
from typing_extensions import TYPE_CHECKING
from pipecat.frames.frames import Frame
from pipecat.utils.base_object import BaseObject
if TYPE_CHECKING:
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
@@ -39,7 +40,7 @@ class FramePushed:
timestamp: int
class BaseObserver(ABC):
class BaseObserver(BaseObject):
"""This is the base class for pipeline frame observers. Observers can view
all the frames that go through the pipeline without the need to inject
processors in the pipeline. This can be useful, for example, to implement

View File

@@ -74,6 +74,7 @@ class DebugLogObserver(BaseObserver):
Union[Tuple[Type[Frame], ...], Dict[Type[Frame], Optional[Tuple[Type, FrameEndpoint]]]]
] = None,
exclude_fields: Optional[Set[str]] = None,
**kwargs,
):
"""Initialize the debug log observer.
@@ -87,6 +88,8 @@ class DebugLogObserver(BaseObserver):
exclude_fields: Set of field names to exclude from logging. If None, only binary
data fields are excluded.
"""
super().__init__(**kwargs)
# Process frame filters
self.frame_filters = {}

View File

@@ -30,7 +30,7 @@ class TurnTrackingObserver(BaseObserver):
"""
def __init__(self, max_frames=100, turn_end_timeout_secs=2.5, **kwargs):
super().__init__()
super().__init__(**kwargs)
self._turn_count = 0
self._is_turn_active = False
self._is_bot_speaking = False
@@ -154,32 +154,3 @@ class TurnTrackingObserver(BaseObserver):
status = "interrupted" if was_interrupted else "completed"
logger.trace(f"Turn {self._turn_count} {status} after {duration:.2f}s")
await self._call_event_handler("on_turn_ended", self._turn_count, duration, was_interrupted)
def _register_event_handler(self, event_name):
"""Register an event handler."""
if not hasattr(self, "_event_handlers"):
self._event_handlers = {}
if event_name not in self._event_handlers:
self._event_handlers[event_name] = []
async def _call_event_handler(self, event_name, *args, **kwargs):
"""Call registered event handlers."""
if not hasattr(self, "_event_handlers"):
return
if event_name in self._event_handlers:
for handler in self._event_handlers[event_name]:
await handler(self, *args, **kwargs)
def event_handler(self, event_name):
"""Decorator for registering event handlers."""
def decorator(func):
if not hasattr(self, "_event_handlers"):
self._event_handlers = {}
if event_name not in self._event_handlers:
self._event_handlers[event_name] = []
self._event_handlers[event_name].append(func)
return func
return decorator

View File

@@ -40,8 +40,13 @@ class TaskObserver(BaseObserver):
"""
def __init__(
self, *, observers: Optional[List[BaseObserver]] = None, task_manager: BaseTaskManager
self,
*,
observers: Optional[List[BaseObserver]] = None,
task_manager: BaseTaskManager,
**kwargs,
):
super().__init__(**kwargs)
self._observers = observers or []
self._task_manager = task_manager
self._proxies: Dict[BaseObserver, Proxy] = {}
@@ -78,7 +83,7 @@ class TaskObserver(BaseObserver):
queue = asyncio.Queue()
task = self._task_manager.create_task(
self._proxy_task_handler(queue, observer),
f"TaskObserver::{observer.__class__.__name__}::_proxy_task_handler",
f"TaskObserver::{observer}::_proxy_task_handler",
)
proxy = Proxy(queue=queue, task=task, observer=observer)
return proxy

View File

@@ -437,8 +437,10 @@ class RTVIObserver(BaseObserver):
params (RTVIObserverParams): Settings to enable/disable specific messages.
"""
def __init__(self, rtvi: "RTVIProcessor", *, params: Optional[RTVIObserverParams] = None):
super().__init__()
def __init__(
self, rtvi: "RTVIProcessor", *, params: Optional[RTVIObserverParams] = None, **kwargs
):
super().__init__(**kwargs)
self._rtvi = rtvi
self._params = params or RTVIObserverParams()
self._bot_transcription = ""

View File

@@ -38,7 +38,9 @@ class HeartbeatsObserver(BaseObserver):
*,
target: FrameProcessor,
heartbeat_callback: Callable[[FrameProcessor, HeartbeatFrame], Awaitable[None]],
**kwargs,
):
super().__init__(**kwargs)
self._target = target
self._callback = heartbeat_callback

View File

@@ -34,8 +34,10 @@ class TurnTraceObserver(BaseObserver):
conversation span that encapsulates the entire session.
"""
def __init__(self, turn_tracker: TurnTrackingObserver, conversation_id: Optional[str] = None):
super().__init__()
def __init__(
self, turn_tracker: TurnTrackingObserver, conversation_id: Optional[str] = None, **kwargs
):
super().__init__(**kwargs)
self._turn_tracker = turn_tracker
self._current_span: Optional["Span"] = None
self._current_turn_number: int = 0