diff --git a/changelog/3881.added.2.md b/changelog/3881.added.2.md new file mode 100644 index 000000000..a5bda94c1 --- /dev/null +++ b/changelog/3881.added.2.md @@ -0,0 +1 @@ +- Added `ClientConnectedFrame`, a new `SystemFrame` pushed by all transports (Daily, LiveKit, FastAPI WebSocket, WebSocket Server, SmallWebRTC, HeyGen, Tavus) when a client connects. Enables observers to track transport readiness timing. diff --git a/changelog/3881.added.3.md b/changelog/3881.added.3.md new file mode 100644 index 000000000..cad26e876 --- /dev/null +++ b/changelog/3881.added.3.md @@ -0,0 +1 @@ +- Added `BotConnectedFrame` for SFU transports and `on_transport_timing_report` event to `StartupTimingObserver` with bot and client connection timing. diff --git a/changelog/3881.added.md b/changelog/3881.added.md new file mode 100644 index 000000000..cbf6d0293 --- /dev/null +++ b/changelog/3881.added.md @@ -0,0 +1 @@ +- Added `StartupTimingObserver` for measuring how long each processor's `start()` method takes during pipeline startup. Also measures transport readiness — the time from `StartFrame` to first client connection — via the `on_transport_timing_report` event. Useful for diagnosing cold start slowness and identifying initialization bottlenecks.
diff --git a/examples/foundational/29-turn-tracking-observer.py b/examples/foundational/29-turn-tracking-observer.py index 321197db2..4af28f1ed 100644 --- a/examples/foundational/29-turn-tracking-observer.py +++ b/examples/foundational/29-turn-tracking-observer.py @@ -12,6 +12,7 @@ from loguru import logger from pipecat.audio.vad.silero import SileroVADAnalyzer from pipecat.frames.frames import LLMRunFrame +from pipecat.observers.startup_timing_observer import StartupTimingObserver from pipecat.observers.user_bot_latency_observer import UserBotLatencyObserver from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner @@ -87,8 +88,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): ] ) - # Create latency tracking observer latency_observer = UserBotLatencyObserver() + startup_observer = StartupTimingObserver() task = PipelineTask( pipeline, @@ -97,14 +98,25 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): enable_usage_metrics=True, ), idle_timeout_secs=runner_args.pipeline_idle_timeout_secs, - observers=[latency_observer], + observers=[latency_observer, startup_observer], ) - # Log latency measurements using the event handler @latency_observer.event_handler("on_latency_measured") async def on_latency_measured(observer, latency_seconds): logger.info(f"⏱️ User-to-bot latency: {latency_seconds:.3f}s") + @startup_observer.event_handler("on_startup_timing_report") + async def on_startup_timing_report(observer, report): + logger.info(f"Total startup: {report.total_duration_secs:.3f}s") + for timing in report.processor_timings: + logger.info(f" {timing.processor_name}: {timing.duration_secs:.3f}s") + + @startup_observer.event_handler("on_transport_timing_report") + async def on_transport_timing_report(observer, report): + if report.bot_connected_secs is not None: + logger.info(f"Bot connected: {report.bot_connected_secs:.3f}s") + logger.info(f"Client connected: 
{report.client_connected_secs:.3f}s") + turn_observer = task.turn_tracking_observer if turn_observer: diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 126f3c001..86778e564 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -1910,6 +1910,29 @@ class StopFrame(ControlFrame, UninterruptibleFrame): pass +@dataclass +class BotConnectedFrame(SystemFrame): + """Frame indicating the bot has connected to the transport service. + + Pushed downstream by SFU transports (Daily, LiveKit, HeyGen, Tavus) + when the bot successfully joins the room. Non-SFU transports do not + emit this frame. + """ + + pass + + +@dataclass +class ClientConnectedFrame(SystemFrame): + """Frame indicating that a client has connected to the transport. + + Pushed downstream by the input transport when a client (participant) + connects. Used by observers to measure transport readiness timing. + """ + + pass + + @dataclass class OutputTransportReadyFrame(ControlFrame): """Frame indicating that the output transport is ready. diff --git a/src/pipecat/observers/base_observer.py b/src/pipecat/observers/base_observer.py index 78e36fec8..70c79224a 100644 --- a/src/pipecat/observers/base_observer.py +++ b/src/pipecat/observers/base_observer.py @@ -100,3 +100,11 @@ class BaseObserver(BaseObject): data: The event data containing details about the frame transfer. """ pass + + async def on_pipeline_started(self): + """Called when the pipeline has fully started. + + Fired after the ``StartFrame`` has been processed by all processors + in the pipeline, including nested ``ParallelPipeline`` branches. 
+ """ + pass diff --git a/src/pipecat/observers/startup_timing_observer.py b/src/pipecat/observers/startup_timing_observer.py new file mode 100644 index 000000000..a1ea04d47 --- /dev/null +++ b/src/pipecat/observers/startup_timing_observer.py @@ -0,0 +1,328 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Observer for tracking pipeline startup timing. + +This module provides an observer that measures how long each processor's +``start()`` method takes during pipeline startup. It works by tracking +when a ``StartFrame`` arrives at a processor (``on_process_frame``) versus +when it leaves (``on_push_frame``), giving the exact ``start()`` duration +for each processor in the pipeline. + +It also measures transport timing — the time from ``StartFrame`` to the +first ``BotConnectedFrame`` (SFU transports only) and ``ClientConnectedFrame`` +— via a separate ``on_transport_timing_report`` event. + +Example:: + + observer = StartupTimingObserver() + + @observer.event_handler("on_startup_timing_report") + async def on_report(observer, report): + for t in report.processor_timings: + print(f"{t.processor_name}: {t.duration_secs:.3f}s") + + @observer.event_handler("on_transport_timing_report") + async def on_transport(observer, report): + if report.bot_connected_secs is not None: + print(f"Bot connected in {report.bot_connected_secs:.3f}s") + print(f"Client connected in {report.client_connected_secs:.3f}s") + + task = PipelineTask(pipeline, observers=[observer]) +""" + +import time +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Type + +from pydantic import BaseModel, Field + +from pipecat.frames.frames import BotConnectedFrame, ClientConnectedFrame, StartFrame +from pipecat.observers.base_observer import BaseObserver, FrameProcessed, FramePushed +from pipecat.pipeline.base_pipeline import BasePipeline +from pipecat.pipeline.pipeline import PipelineSource +from 
pipecat.processors.frame_processor import FrameProcessor + +# Internal pipeline types excluded from tracking by default. +_INTERNAL_TYPES = (PipelineSource, BasePipeline) + + +@dataclass +class _ArrivalInfo: + """Internal record of when a StartFrame arrived at a processor.""" + + processor: FrameProcessor + arrival_ts_ns: int + + +class ProcessorStartupTiming(BaseModel): + """Startup timing for a single processor. + + Parameters: + processor_name: The name of the processor. + start_offset_secs: Offset in seconds from the StartFrame to when this + processor's start() began. + duration_secs: How long the processor's start() took, in seconds. + """ + + processor_name: str + start_offset_secs: float + duration_secs: float + + +class StartupTimingReport(BaseModel): + """Report of startup timings for all measured processors. + + Parameters: + start_time: Unix timestamp when the first processor began starting. + total_duration_secs: Sum of the measured processors' start() durations, in seconds. + processor_timings: Per-processor timing data, in pipeline order. + """ + + start_time: float + total_duration_secs: float + processor_timings: List[ProcessorStartupTiming] = Field(default_factory=list) + + +class TransportTimingReport(BaseModel): + """Time from pipeline start to transport connection milestones. + + Parameters: + start_time: Unix timestamp of the StartFrame (pipeline start). + bot_connected_secs: Seconds from StartFrame to first BotConnectedFrame + (only set for SFU transports). + client_connected_secs: Seconds from StartFrame to first ClientConnectedFrame. + """ + + start_time: float + bot_connected_secs: Optional[float] = None + client_connected_secs: Optional[float] = None + + +class StartupTimingObserver(BaseObserver): + """Observer that measures processor startup times during pipeline initialization.
+ + Tracks how long each processor's ``start()`` method takes by measuring the + time between when a ``StartFrame`` arrives at a processor and when it is + pushed downstream. This captures WebSocket connections, API authentication, + model loading, and other initialization work. + + Also measures transport timing, the time from ``StartFrame`` to connection + milestones: + + - ``bot_connected_secs``: When the bot joins the transport room + (SFU transports only, triggered by ``BotConnectedFrame``). + - ``client_connected_secs``: When a remote participant connects + (triggered by ``ClientConnectedFrame``). + + By default, internal pipeline processors (``PipelineSource``, ``Pipeline``) + are excluded from the report. Pass ``processor_types`` to measure only + specific types. + + Event handlers available: + + - on_startup_timing_report: Called once after startup completes with the full + timing report. + - on_transport_timing_report: Called once when the first client connects with a + TransportTimingReport containing client_connected_secs and bot_connected_secs + (if available). + + Example:: + + observer = StartupTimingObserver( + processor_types=(STTService, TTSService) + ) + + @observer.event_handler("on_startup_timing_report") + async def on_report(observer, report): + for t in report.processor_timings: + logger.info(f"{t.processor_name}: {t.duration_secs:.3f}s") + + @observer.event_handler("on_transport_timing_report") + async def on_transport(observer, report): + if report.bot_connected_secs is not None: + logger.info(f"Bot connected in {report.bot_connected_secs:.3f}s") + logger.info(f"Client connected in {report.client_connected_secs:.3f}s") + + task = PipelineTask(pipeline, observers=[observer]) + + Args: + processor_types: Optional tuple of processor types to measure. If None, + all non-internal processors are measured. 
+ """ + + def __init__( + self, + *, + processor_types: Optional[Tuple[Type[FrameProcessor], ...]] = None, + **kwargs, + ): + """Initialize the startup timing observer. + + Args: + processor_types: Optional tuple of processor types to measure. + If None, all non-internal processors are measured. + **kwargs: Additional arguments passed to parent class. + """ + super().__init__(**kwargs) + self._processor_types = processor_types + + # Map processor ID -> arrival info. + self._arrivals: Dict[int, _ArrivalInfo] = {} + + # Collected timings in pipeline order. + self._timings: List[ProcessorStartupTiming] = [] + + # Lock onto the first StartFrame we see (by frame ID). + self._start_frame_id: Optional[str] = None + + # Whether we've already emitted the startup timing report. + self._startup_timing_reported = False + + # Whether we've already measured transport timing. + self._transport_timing_reported = False + + # Timestamp (ns) when we first see a StartFrame arrive at a processor. + self._start_frame_arrival_ns: Optional[int] = None + + # Bot connected timing (stored for inclusion in the transport report). + self._bot_connected_secs: Optional[float] = None + + # Wall clock time when the StartFrame was first seen. + self._start_wall_clock: Optional[float] = None + + self._register_event_handler("on_startup_timing_report") + self._register_event_handler("on_transport_timing_report") + + def _should_track(self, processor: FrameProcessor) -> bool: + """Check if a processor should be tracked for timing. + + Args: + processor: The processor to check. + + Returns: + True if the processor matches the filter or no filter is set. + """ + if self._processor_types is not None: + return isinstance(processor, self._processor_types) + # Default: exclude internal pipeline plumbing. + return not isinstance(processor, _INTERNAL_TYPES) + + async def on_pipeline_started(self): + """Emit the startup timing report when the pipeline has fully started. 
+ + Called by the ``PipelineTask`` after the ``StartFrame`` has been + processed by all processors, including nested ``ParallelPipeline`` + branches. + """ + if self._timings: + await self._emit_report() + + async def on_process_frame(self, data: FrameProcessed): + """Record when a StartFrame arrives at a processor. + + Args: + data: The frame processing event data. + """ + if self._startup_timing_reported: + return + + if not isinstance(data.frame, StartFrame): + return + + # Lock onto the first StartFrame. + if self._start_frame_id is None: + self._start_frame_id = data.frame.id + self._start_frame_arrival_ns = data.timestamp + self._start_wall_clock = time.time() + elif data.frame.id != self._start_frame_id: + return + + if self._should_track(data.processor): + self._arrivals[data.processor.id] = _ArrivalInfo( + processor=data.processor, arrival_ts_ns=data.timestamp + ) + + async def on_push_frame(self, data: FramePushed): + """Record when a StartFrame leaves a processor and compute the delta. + + Also handles ``BotConnectedFrame`` and ``ClientConnectedFrame`` to + measure transport timing. + + Args: + data: The frame push event data. 
+ """ + if isinstance(data.frame, BotConnectedFrame): + self._handle_bot_connected(data) + return + + if isinstance(data.frame, ClientConnectedFrame): + await self._handle_client_connected(data) + return + + if self._startup_timing_reported: + return + + if not isinstance(data.frame, StartFrame): + return + + if self._start_frame_id is not None and data.frame.id != self._start_frame_id: + return + + arrival = self._arrivals.pop(data.source.id, None) + if arrival is None: + return + + duration_ns = data.timestamp - arrival.arrival_ts_ns + duration_secs = duration_ns / 1e9 + start_offset_secs = (arrival.arrival_ts_ns - self._start_frame_arrival_ns) / 1e9 + + self._timings.append( + ProcessorStartupTiming( + processor_name=arrival.processor.name, + start_offset_secs=start_offset_secs, + duration_secs=duration_secs, + ) + ) + + def _handle_bot_connected(self, data: FramePushed): + """Record bot connected timing on first BotConnectedFrame.""" + if self._bot_connected_secs is not None or self._start_frame_arrival_ns is None: + return + + delta_ns = data.timestamp - self._start_frame_arrival_ns + self._bot_connected_secs = delta_ns / 1e9 + + async def _handle_client_connected(self, data: FramePushed): + """Emit transport timing report on first ClientConnectedFrame.""" + if self._transport_timing_reported or self._start_frame_arrival_ns is None: + return + + self._transport_timing_reported = True + delta_ns = data.timestamp - self._start_frame_arrival_ns + client_connected_secs = delta_ns / 1e9 + report = TransportTimingReport( + start_time=self._start_wall_clock or 0.0, + bot_connected_secs=self._bot_connected_secs, + client_connected_secs=client_connected_secs, + ) + await self._call_event_handler("on_transport_timing_report", report) + + async def _emit_report(self): + """Build and emit the startup timing report.""" + if self._startup_timing_reported: + return + self._startup_timing_reported = True + + total = sum(t.duration_secs for t in self._timings) + + report = 
StartupTimingReport( + start_time=self._start_wall_clock or 0.0, + total_duration_secs=total, + processor_timings=self._timings, + ) + + await self._call_event_handler("on_startup_timing_report", report) diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index deae6290c..906d55eb6 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -915,6 +915,7 @@ class PipelineTask(BasePipelineTask): if isinstance(frame, StartFrame): await self._call_event_handler("on_pipeline_started", frame) + await self._observer.on_pipeline_started() # Start heartbeat tasks now that StartFrame has been processed # by all processors in the pipeline diff --git a/src/pipecat/pipeline/task_observer.py b/src/pipecat/pipeline/task_observer.py index 4d33fd60e..dc2040e07 100644 --- a/src/pipecat/pipeline/task_observer.py +++ b/src/pipecat/pipeline/task_observer.py @@ -39,6 +39,12 @@ class Proxy: observer: BaseObserver +class _PipelineStartedSignal: + """Internal sentinel queued to observers when the pipeline has started.""" + + pass + + class TaskObserver(BaseObserver): """Proxy observer that manages multiple observers without blocking the pipeline. @@ -129,6 +135,10 @@ class TaskObserver(BaseObserver): for proxy in self._proxies: await proxy.cleanup() + async def on_pipeline_started(self): + """Forward pipeline started signal to all managed observers.""" + await self._send_to_proxy(_PipelineStartedSignal()) + async def on_process_frame(self, data: FrameProcessed): """Queue frame data for all managed observers. 
@@ -186,7 +196,9 @@ class TaskObserver(BaseObserver): while True: data = await queue.get() - if isinstance(data, FramePushed): + if isinstance(data, _PipelineStartedSignal): + await observer.on_pipeline_started() + elif isinstance(data, FramePushed): if on_push_frame_deprecated: await observer.on_push_frame( data.source, data.destination, data.frame, data.direction, data.timestamp diff --git a/src/pipecat/services/heygen/client.py b/src/pipecat/services/heygen/client.py index 4018d3858..6d45d6114 100644 --- a/src/pipecat/services/heygen/client.py +++ b/src/pipecat/services/heygen/client.py @@ -62,10 +62,12 @@ class HeyGenCallbacks(BaseModel): """Callback handlers for HeyGen events. Parameters: - on_participant_connected: Called when a participant connects - on_participant_disconnected: Called when a participant disconnects + on_connected: Called when the bot connects to the LiveKit room. + on_participant_connected: Called when a participant connects. + on_participant_disconnected: Called when a participant disconnects. """ + on_connected: Callable[[], Awaitable[None]] on_participant_connected: Callable[[str], Awaitable[None]] on_participant_disconnected: Callable[[str], Awaitable[None]] @@ -251,6 +253,7 @@ class HeyGenClient: logger.debug(f"HeyGenClient send_interval: {self._send_interval}") await self._ws_connect() await self._livekit_connect() + self._call_event_callback(self._callbacks.on_connected) async def stop(self) -> None: """Stop the client and terminate all connections. 
diff --git a/src/pipecat/services/heygen/video.py b/src/pipecat/services/heygen/video.py index b97f4a5ed..7f3624f35 100644 --- a/src/pipecat/services/heygen/video.py +++ b/src/pipecat/services/heygen/video.py @@ -128,6 +128,7 @@ class HeyGenVideoService(AIService): session_request=self._session_request, service_type=self._service_type, callbacks=HeyGenCallbacks( + on_connected=self._on_connected, on_participant_connected=self._on_participant_connected, on_participant_disconnected=self._on_participant_disconnected, ), @@ -144,6 +145,10 @@ class HeyGenVideoService(AIService): await self._client.cleanup() self._client = None + async def _on_connected(self): + """Handle bot connected to LiveKit room.""" + logger.info("HeyGen bot connected to LiveKit room") + async def _on_participant_connected(self, participant_id: str): """Handle participant connected events.""" logger.info(f"Participant connected {participant_id}") diff --git a/src/pipecat/services/tavus/video.py b/src/pipecat/services/tavus/video.py index d9f259797..8c63ff354 100644 --- a/src/pipecat/services/tavus/video.py +++ b/src/pipecat/services/tavus/video.py @@ -94,6 +94,7 @@ class TavusVideoService(AIService): """ await super().setup(setup) callbacks = TavusCallbacks( + on_connected=self._on_joined, on_participant_joined=self._on_participant_joined, on_participant_left=self._on_participant_left, ) @@ -119,6 +120,10 @@ class TavusVideoService(AIService): await self._client.cleanup() self._client = None + async def _on_joined(self, data): + """Handle bot joined the Daily room.""" + logger.info("Tavus bot joined Daily room") + async def _on_participant_left(self, participant, reason): """Handle participant leaving the session.""" participant_id = participant["id"] diff --git a/src/pipecat/transports/daily/transport.py b/src/pipecat/transports/daily/transport.py index 9575fd51b..dc9868426 100644 --- a/src/pipecat/transports/daily/transport.py +++ b/src/pipecat/transports/daily/transport.py @@ -24,7 +24,9 @@ from
pydantic import BaseModel from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADParams from pipecat.frames.frames import ( + BotConnectedFrame, CancelFrame, + ClientConnectedFrame, DataFrame, EndFrame, Frame, @@ -2070,6 +2072,8 @@ class DailyTransport(BaseTransport): Event handlers available: - on_joined: Called when the bot joins the room. Args: (data: dict) + - on_connected: Called when the bot connects to the room (alias for + on_joined). Args: (data: dict) - on_left: Called when the bot leaves the room. - on_before_leave: [sync] Called just before the bot leaves the room. - on_error: Called when a transport error occurs. Args: (error: str) @@ -2187,6 +2191,7 @@ class DailyTransport(BaseTransport): # Register supported handlers. The user will only be able to register # these handlers. self._register_event_handler("on_active_speaker_changed") + self._register_event_handler("on_connected") self._register_event_handler("on_joined") self._register_event_handler("on_left") self._register_event_handler("on_error") @@ -2578,6 +2583,10 @@ class DailyTransport(BaseTransport): if error: await self._on_error(f"Unable to start transcription: {error}") await self._call_event_handler("on_joined", data) + # Also call on_connected for compatibility with other transports + await self._call_event_handler("on_connected", data) + if self._input: + await self._input.push_frame(BotConnectedFrame()) async def _on_left(self): """Handle room left events.""" @@ -2716,6 +2725,8 @@ class DailyTransport(BaseTransport): await self._call_event_handler("on_participant_joined", participant) # Also call on_client_connected for compatibility with other transports await self._call_event_handler("on_client_connected", participant) + if self._input: + await self._input.push_frame(ClientConnectedFrame()) async def _on_participant_left(self, participant, reason): """Handle participant left events.""" diff --git a/src/pipecat/transports/heygen/transport.py 
b/src/pipecat/transports/heygen/transport.py index dbeded3e5..d79d0080e 100644 --- a/src/pipecat/transports/heygen/transport.py +++ b/src/pipecat/transports/heygen/transport.py @@ -23,9 +23,11 @@ from loguru import logger from pipecat.frames.frames import ( AudioRawFrame, + BotConnectedFrame, BotStartedSpeakingFrame, BotStoppedSpeakingFrame, CancelFrame, + ClientConnectedFrame, EndFrame, Frame, InputAudioRawFrame, @@ -339,6 +341,7 @@ class HeyGenTransport(BaseTransport): session_request=session_request, service_type=service_type, callbacks=HeyGenCallbacks( + on_connected=self._on_connected, on_participant_connected=self._on_participant_connected, on_participant_disconnected=self._on_participant_disconnected, ), @@ -349,9 +352,16 @@ class HeyGenTransport(BaseTransport): # Register supported handlers. The user will only be able to register # these handlers. + self._register_event_handler("on_connected") self._register_event_handler("on_client_connected") self._register_event_handler("on_client_disconnected") + async def _on_connected(self): + """Handle bot connected to LiveKit room.""" + await self._call_event_handler("on_connected") + if self._input: + await self._input.push_frame(BotConnectedFrame()) + async def _on_participant_disconnected(self, participant_id: str): logger.debug(f"HeyGen participant {participant_id} disconnected") if participant_id != "heygen": @@ -387,6 +397,8 @@ class HeyGenTransport(BaseTransport): async def _on_client_connected(self, participant: Any): """Handle client connected events.""" await self._call_event_handler("on_client_connected", participant) + if self._input: + await self._input.push_frame(ClientConnectedFrame()) async def _on_client_disconnected(self, participant: Any): """Handle client disconnected events.""" diff --git a/src/pipecat/transports/livekit/transport.py b/src/pipecat/transports/livekit/transport.py index 1902e7cd3..7e9c1de35 100644 --- a/src/pipecat/transports/livekit/transport.py +++ 
b/src/pipecat/transports/livekit/transport.py @@ -23,7 +23,9 @@ from pipecat.audio.utils import create_stream_resampler from pipecat.audio.vad.vad_analyzer import VADAnalyzer from pipecat.frames.frames import ( AudioRawFrame, + BotConnectedFrame, CancelFrame, + ClientConnectedFrame, EndFrame, ImageRawFrame, OutputAudioRawFrame, @@ -1131,6 +1133,8 @@ class LiveKitTransport(BaseTransport): async def _on_connected(self): """Handle room connected events.""" await self._call_event_handler("on_connected") + if self._input: + await self._input.push_frame(BotConnectedFrame()) async def _on_disconnected(self): """Handle room disconnected events.""" @@ -1143,6 +1147,8 @@ class LiveKitTransport(BaseTransport): async def _on_participant_connected(self, participant_id: str): """Handle participant connected events.""" await self._call_event_handler("on_participant_connected", participant_id) + if self._input: + await self._input.push_frame(ClientConnectedFrame()) async def _on_participant_disconnected(self, participant_id: str): """Handle participant disconnected events.""" diff --git a/src/pipecat/transports/smallwebrtc/transport.py b/src/pipecat/transports/smallwebrtc/transport.py index dc91588a3..36f883278 100644 --- a/src/pipecat/transports/smallwebrtc/transport.py +++ b/src/pipecat/transports/smallwebrtc/transport.py @@ -23,6 +23,7 @@ from pydantic import BaseModel from pipecat.frames.frames import ( CancelFrame, + ClientConnectedFrame, EndFrame, Frame, InputAudioRawFrame, @@ -964,6 +965,8 @@ class SmallWebRTCTransport(BaseTransport): async def _on_client_connected(self, webrtc_connection): """Handle client connection events.""" await self._call_event_handler("on_client_connected", webrtc_connection) + if self._input: + await self._input.push_frame(ClientConnectedFrame()) async def _on_client_disconnected(self, webrtc_connection): """Handle client disconnection events.""" diff --git a/src/pipecat/transports/tavus/transport.py b/src/pipecat/transports/tavus/transport.py 
index dd63cb790..cb6844250 100644 --- a/src/pipecat/transports/tavus/transport.py +++ b/src/pipecat/transports/tavus/transport.py @@ -21,7 +21,9 @@ from loguru import logger from pydantic import BaseModel from pipecat.frames.frames import ( + BotConnectedFrame, CancelFrame, + ClientConnectedFrame, EndFrame, Frame, InputAudioRawFrame, @@ -132,10 +134,12 @@ class TavusCallbacks(BaseModel): """Callback handlers for Tavus events. Parameters: + on_connected: Called when the bot connects to the room. on_participant_joined: Called when a participant joins the conversation. on_participant_left: Called when a participant leaves the conversation. """ + on_connected: Callable[[Mapping[str, Any]], Awaitable[None]] on_participant_joined: Callable[[Mapping[str, Any]], Awaitable[None]] on_participant_left: Callable[[Mapping[str, Any], str], Awaitable[None]] @@ -270,6 +274,7 @@ class TavusTransportClient: async def _on_joined(self, data): """Handle joined event.""" logger.debug("TavusTransportClient joined!") + await self._callbacks.on_connected(data) async def _on_left(self): """Handle left event.""" @@ -664,6 +669,7 @@ class TavusTransport(BaseTransport): Event handlers available: + - on_connected(transport, data): Bot connected to the room - on_client_connected(transport, participant): Participant connected to the session - on_client_disconnected(transport, participant): Participant disconnected from the session @@ -702,6 +708,7 @@ class TavusTransport(BaseTransport): self._params = params callbacks = TavusCallbacks( + on_connected=self._on_joined, on_participant_joined=self._on_participant_joined, on_participant_left=self._on_participant_left, ) @@ -720,9 +727,16 @@ class TavusTransport(BaseTransport): # Register supported handlers. The user will only be able to register # these handlers. 
+ self._register_event_handler("on_connected") self._register_event_handler("on_client_connected") self._register_event_handler("on_client_disconnected") + async def _on_joined(self, data): + """Handle bot joined room event.""" + await self._call_event_handler("on_connected", data) + if self._input: + await self._input.push_frame(BotConnectedFrame()) + async def _on_participant_left(self, participant, reason): """Handle participant left events.""" persona_name = await self._client.get_persona_name() @@ -786,6 +800,8 @@ class TavusTransport(BaseTransport): async def _on_client_connected(self, participant: Any): """Handle client connected events.""" await self._call_event_handler("on_client_connected", participant) + if self._input: + await self._input.push_frame(ClientConnectedFrame()) async def _on_client_disconnected(self, participant: Any): """Handle client disconnected events.""" diff --git a/src/pipecat/transports/websocket/fastapi.py b/src/pipecat/transports/websocket/fastapi.py index f52123e52..0fde2b9ae 100644 --- a/src/pipecat/transports/websocket/fastapi.py +++ b/src/pipecat/transports/websocket/fastapi.py @@ -23,6 +23,7 @@ from pydantic import BaseModel from pipecat.frames.frames import ( CancelFrame, + ClientConnectedFrame, EndFrame, Frame, InputAudioRawFrame, @@ -260,6 +261,7 @@ class FastAPIWebsocketInputTransport(BaseInputTransport): if not self._monitor_websocket_task and self._params.session_timeout: self._monitor_websocket_task = self.create_task(self._monitor_websocket()) await self._client.trigger_client_connected() + await self.push_frame(ClientConnectedFrame()) if not self._receive_task: self._receive_task = self.create_task(self._receive_messages()) await self.set_transport_ready(frame) diff --git a/src/pipecat/transports/websocket/server.py b/src/pipecat/transports/websocket/server.py index e5f628fa4..fa3645d37 100644 --- a/src/pipecat/transports/websocket/server.py +++ b/src/pipecat/transports/websocket/server.py @@ -22,11 +22,11 @@ from 
pydantic import BaseModel from pipecat.frames.frames import ( CancelFrame, + ClientConnectedFrame, EndFrame, Frame, InputAudioRawFrame, InputTransportMessageFrame, - InputTransportMessageUrgentFrame, InterruptionFrame, OutputAudioRawFrame, OutputTransportMessageFrame, @@ -504,6 +504,8 @@ class WebsocketServerTransport(BaseTransport): if self._output: await self._output.set_client_connection(websocket) await self._call_event_handler("on_client_connected", websocket) + if self._input: + await self._input.push_frame(ClientConnectedFrame()) else: logger.error("A WebsocketServerTransport output is missing in the pipeline") diff --git a/tests/test_startup_timing_observer.py b/tests/test_startup_timing_observer.py new file mode 100644 index 000000000..6355c6081 --- /dev/null +++ b/tests/test_startup_timing_observer.py @@ -0,0 +1,337 @@ +import asyncio +import unittest + +from pipecat.frames.frames import ( + BotConnectedFrame, + ClientConnectedFrame, + Frame, + StartFrame, + TextFrame, +) +from pipecat.observers.startup_timing_observer import ( + StartupTimingObserver, + StartupTimingReport, + TransportTimingReport, +) +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.tests.utils import run_test + + +class SlowStartProcessor(FrameProcessor): + """A processor that sleeps during start to simulate slow initialization.""" + + def __init__(self, delay: float = 0.1, **kwargs): + super().__init__(**kwargs) + self._delay = delay + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + if isinstance(frame, StartFrame): + await asyncio.sleep(self._delay) + await self.push_frame(frame, direction) + + +class FastProcessor(FrameProcessor): + """A processor with no start delay.""" + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + await self.push_frame(frame, direction) + + +class 
TestStartupTimingObserver(unittest.IsolatedAsyncioTestCase): + """Tests for StartupTimingObserver.""" + + async def test_timing_reported(self): + """Test that startup timing is measured and reported.""" + observer = StartupTimingObserver() + processor = SlowStartProcessor(delay=0.1) + + reports = [] + + @observer.event_handler("on_startup_timing_report") + async def on_report(obs, report): + reports.append(report) + + frames_to_send = [TextFrame(text="hello")] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=[TextFrame], + observers=[observer], + ) + + self.assertEqual(len(reports), 1) + report = reports[0] + self.assertGreater(report.total_duration_secs, 0) + self.assertGreater(len(report.processor_timings), 0) + + # Find our slow processor in the timings. + slow_timings = [ + t for t in report.processor_timings if "SlowStartProcessor" in t.processor_name + ] + self.assertEqual(len(slow_timings), 1) + self.assertGreaterEqual(slow_timings[0].duration_secs, 0.05) + + async def test_processor_types_filter(self): + """Test that processor_types filter limits which processors appear.""" + observer = StartupTimingObserver(processor_types=(SlowStartProcessor,)) + processor = SlowStartProcessor(delay=0.05) + + reports = [] + + @observer.event_handler("on_startup_timing_report") + async def on_report(obs, report): + reports.append(report) + + frames_to_send = [TextFrame(text="hello")] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=[TextFrame], + observers=[observer], + ) + + self.assertEqual(len(reports), 1) + report = reports[0] + + # Only SlowStartProcessor should be in the timings. 
+ for t in report.processor_timings: + self.assertIn("SlowStartProcessor", t.processor_name) + + async def test_report_emits_once(self): + """Test that the report is emitted only once even with multiple frames.""" + observer = StartupTimingObserver() + processor = FastProcessor() + + reports = [] + + @observer.event_handler("on_startup_timing_report") + async def on_report(obs, report): + reports.append(report) + + frames_to_send = [ + TextFrame(text="first"), + TextFrame(text="second"), + TextFrame(text="third"), + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=[TextFrame, TextFrame, TextFrame], + observers=[observer], + ) + + self.assertEqual(len(reports), 1) + + async def test_event_handler_receives_report(self): + """Test that the event handler receives a proper StartupTimingReport.""" + observer = StartupTimingObserver() + processor = SlowStartProcessor(delay=0.05) + + reports = [] + + @observer.event_handler("on_startup_timing_report") + async def on_report(obs, report): + reports.append(report) + + frames_to_send = [TextFrame(text="hello")] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=[TextFrame], + observers=[observer], + ) + + self.assertEqual(len(reports), 1) + report = reports[0] + self.assertIsInstance(report, StartupTimingReport) + self.assertIsInstance(report.total_duration_secs, float) + self.assertGreater(report.start_time, 0) + for timing in report.processor_timings: + self.assertIsInstance(timing.processor_name, str) + self.assertIsInstance(timing.duration_secs, float) + self.assertGreaterEqual(timing.start_offset_secs, 0) + + async def test_excludes_internal_processors(self): + """Test that internal pipeline processors are excluded by default.""" + observer = StartupTimingObserver() + processor = FastProcessor() + + reports = [] + + @observer.event_handler("on_startup_timing_report") + async def on_report(obs, report): + reports.append(report) + + 
frames_to_send = [TextFrame(text="hello")] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=[TextFrame], + observers=[observer], + ) + + self.assertEqual(len(reports), 1) + report = reports[0] + + # No internal processors (PipelineSource, PipelineSink, Pipeline) in the report. + internal_names = ("Pipeline#", "PipelineTask#") + for t in report.processor_timings: + for prefix in internal_names: + self.assertNotIn( + prefix, + t.processor_name, + f"Internal processor {t.processor_name} should be excluded by default", + ) + + async def test_transport_timing_client_only(self): + """Test that ClientConnectedFrame emits on_transport_timing_report.""" + observer = StartupTimingObserver() + processor = FastProcessor() + + transport_reports = [] + + @observer.event_handler("on_transport_timing_report") + async def on_transport(obs, report): + transport_reports.append(report) + + frames_to_send = [ClientConnectedFrame(), TextFrame(text="hello")] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=[ClientConnectedFrame, TextFrame], + observers=[observer], + ) + + self.assertEqual(len(transport_reports), 1) + report = transport_reports[0] + self.assertIsInstance(report, TransportTimingReport) + self.assertGreater(report.start_time, 0) + self.assertGreater(report.client_connected_secs, 0) + self.assertIsNone(report.bot_connected_secs) + + async def test_transport_timing_only_first_client(self): + """Test that only the first ClientConnectedFrame triggers the event.""" + observer = StartupTimingObserver() + processor = FastProcessor() + + transport_reports = [] + + @observer.event_handler("on_transport_timing_report") + async def on_transport(obs, report): + transport_reports.append(report) + + frames_to_send = [ + ClientConnectedFrame(), + ClientConnectedFrame(), + TextFrame(text="hello"), + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=[ClientConnectedFrame, 
ClientConnectedFrame, TextFrame], + observers=[observer], + ) + + self.assertEqual(len(transport_reports), 1) + + async def test_transport_timing_without_start_frame(self): + """Test that ClientConnectedFrame before StartFrame does not crash.""" + observer = StartupTimingObserver() + + # Directly call on_push_frame with a ClientConnectedFrame before any + # StartFrame has been seen. This should be a no-op (no crash). + from pipecat.observers.base_observer import FramePushed + + processor = FastProcessor() + destination = FastProcessor() + data = FramePushed( + source=processor, + destination=destination, + frame=ClientConnectedFrame(), + direction=FrameDirection.DOWNSTREAM, + timestamp=1000, + ) + await observer.on_push_frame(data) + + # No event should have been emitted. + self.assertFalse(observer._transport_timing_reported) + + async def test_bot_and_client_connected(self): + """Test that BotConnectedFrame timing is included in the transport report.""" + observer = StartupTimingObserver() + processor = FastProcessor() + + transport_reports = [] + + @observer.event_handler("on_transport_timing_report") + async def on_transport(obs, report): + transport_reports.append(report) + + frames_to_send = [ + BotConnectedFrame(), + ClientConnectedFrame(), + TextFrame(text="hello"), + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=[BotConnectedFrame, ClientConnectedFrame, TextFrame], + observers=[observer], + ) + + self.assertEqual(len(transport_reports), 1) + report = transport_reports[0] + self.assertGreater(report.client_connected_secs, 0) + self.assertIsNotNone(report.bot_connected_secs) + self.assertGreater(report.bot_connected_secs, 0) + + # Client connected should be >= bot connected. 
+ self.assertGreaterEqual(report.client_connected_secs, report.bot_connected_secs) + + async def test_bot_connected_only_first(self): + """Test that only the first BotConnectedFrame is recorded.""" + observer = StartupTimingObserver() + processor = FastProcessor() + + transport_reports = [] + + @observer.event_handler("on_transport_timing_report") + async def on_transport(obs, report): + transport_reports.append(report) + + frames_to_send = [ + BotConnectedFrame(), + BotConnectedFrame(), + ClientConnectedFrame(), + TextFrame(text="hello"), + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=[ + BotConnectedFrame, + BotConnectedFrame, + ClientConnectedFrame, + TextFrame, + ], + observers=[observer], + ) + + # Only one transport report, with bot timing from first frame. + self.assertEqual(len(transport_reports), 1) + self.assertIsNotNone(transport_reports[0].bot_connected_secs) + + +if __name__ == "__main__": + unittest.main()