Merge branch 'main' into filipi/pipeline_freeze

# Conflicts:
#	CHANGELOG.md
#	src/pipecat/pipeline/task.py
#	src/pipecat/processors/frame_processor.py
#	src/pipecat/transports/base_input.py
This commit is contained in:
Filipi Fuchter
2025-06-24 17:11:21 -03:00
25 changed files with 400 additions and 116 deletions

View File

@@ -12,6 +12,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added logging and improved error handling to help diagnose and prevent potential
Pipeline freezes.
- Introduce task watchdog timers. Watchdog timers are used to detect if a
Pipecat task is taking longer than expected (by default 5 seconds). It is
possible to change the default watchdog timer timeout by using the
`watchdog_timeout` constructor argument when creating a `PipelineTask`. With
watchdog timers it is also possible to log how long each processing step is
taking (e.g. processing an element from a queue inside a task). This is done
with the `enable_watchdog_logging` constructor argument when creating a
`PipelineTask.` It is also possible to control these two values per each frame
processor. That is, you can set set `enable_watchdog_logging` and
`watchdog_timeout` when creating any frame processor through their constructor
arguments. Finally, you can also set these values per task. So, if you are
writing a frame processor that creates multiple tasks and you only want to
enable logging for one of them, you can do so by passing the same argument
names to the `FrameProcessor.create_task()` function. Note that watchdog
timers only work with Pipecat tasks but not if you use `asycio.create_task()`
or similar.
- Added `lexicon_names` parameter to `AWSPollyTTSService.InputParams`.
- Added reconnection logic and audio buffer management to `GladiaSTTService`.

View File

@@ -6,18 +6,21 @@
import asyncio
from abc import abstractmethod
from dataclasses import dataclass
from typing import AsyncIterable, Iterable
from pipecat.frames.frames import Frame
from pipecat.utils.base_object import BaseObject
class BaseTask(BaseObject):
@abstractmethod
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
"""Sets the event loop that this task will run on."""
pass
@dataclass
class PipelineTaskParams:
"""Specific configuration for the pipeline task."""
loop: asyncio.AbstractEventLoop
class BasePipelineTask(BaseObject):
@abstractmethod
def has_finished(self) -> bool:
"""Indicates whether the tasks has finished. That is, all processors
@@ -40,7 +43,7 @@ class BaseTask(BaseObject):
pass
@abstractmethod
async def run(self):
async def run(self, params: PipelineTaskParams):
"""Starts running the given pipeline."""
pass

View File

@@ -202,14 +202,18 @@ class ParallelPipeline(BasePipeline):
async def _process_up_queue(self):
while True:
frame = await self._up_queue.get()
self.start_watchdog()
await self._parallel_push_frame(frame, FrameDirection.UPSTREAM)
self._up_queue.task_done()
self.reset_watchdog()
async def _process_down_queue(self):
running = True
while running:
frame = await self._down_queue.get()
self.start_watchdog()
endframe_counter = self._endframe_counter.get(frame.id, 0)
# If we have a counter, decrement it.
@@ -224,3 +228,5 @@ class ParallelPipeline(BasePipeline):
running = not (endframe_counter == 0 and isinstance(frame, EndFrame))
self._down_queue.task_done()
self.reset_watchdog()

View File

@@ -11,6 +11,7 @@ from typing import Optional
from loguru import logger
from pipecat.pipeline.base_task import PipelineTaskParams
from pipecat.pipeline.task import PipelineTask
from pipecat.utils.base_object import BaseObject
@@ -37,8 +38,8 @@ class PipelineRunner(BaseObject):
async def run(self, task: PipelineTask):
logger.debug(f"Runner {self} started running {task}")
self._tasks[task.name] = task
task.set_event_loop(self._loop)
await task.run()
params = PipelineTaskParams(loop=self._loop)
await task.run(params)
del self._tasks[task.name]
# Cleanup base object.

View File

@@ -7,7 +7,7 @@
import asyncio
import time
from collections import deque
from typing import Any, AsyncIterable, Deque, Dict, Iterable, List, Optional, Sequence, Tuple, Type
from typing import Any, AsyncIterable, Deque, Dict, Iterable, List, Optional, Tuple, Type
from loguru import logger
from pydantic import BaseModel, ConfigDict, Field
@@ -35,10 +35,10 @@ from pipecat.metrics.metrics import ProcessingMetricsData, TTFBMetricsData
from pipecat.observers.base_observer import BaseObserver
from pipecat.observers.turn_tracking_observer import TurnTrackingObserver
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.pipeline.base_task import BaseTask
from pipecat.pipeline.base_task import BasePipelineTask, PipelineTaskParams
from pipecat.pipeline.task_observer import TaskObserver
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
from pipecat.utils.asyncio import BaseTaskManager, TaskManager
from pipecat.utils.asyncio import WATCHDOG_TIMEOUT, BaseTaskManager, TaskManager, TaskManagerParams
from pipecat.utils.tracing.setup import is_tracing_available
from pipecat.utils.tracing.turn_trace_observer import TurnTraceObserver
@@ -47,7 +47,10 @@ HEARTBEAT_MONITOR_SECONDS = HEARTBEAT_SECONDS * 10
class PipelineParams(BaseModel):
"""Configuration parameters for pipeline execution.
"""Configuration parameters for pipeline execution. These parameters are
usually passed to all frame processors using through `StartFrame`. For other
generic pipeline task parameters use `PipelineTask` constructor arguments
instead.
Attributes:
allow_interruptions: Whether to allow pipeline interruptions.
@@ -62,6 +65,7 @@ class PipelineParams(BaseModel):
send_initial_empty_metrics: Whether to send initial empty metrics.
start_metadata: Additional metadata for pipeline start.
interruption_strategies: Strategies for bot interruption behavior.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -73,11 +77,11 @@ class PipelineParams(BaseModel):
enable_metrics: bool = False
enable_usage_metrics: bool = False
heartbeats_period_secs: float = HEARTBEAT_SECONDS
interruption_strategies: List[BaseInterruptionStrategy] = Field(default_factory=list)
observers: List[BaseObserver] = Field(default_factory=list)
report_only_initial_ttfb: bool = False
send_initial_empty_metrics: bool = True
start_metadata: Dict[str, Any] = Field(default_factory=dict)
interruption_strategies: List[BaseInterruptionStrategy] = Field(default_factory=list)
class PipelineTaskSource(FrameProcessor):
@@ -127,7 +131,7 @@ class PipelineTaskSink(FrameProcessor):
await self._down_queue.put(frame)
class PipelineTask(BaseTask):
class PipelineTask(BasePipelineTask):
"""Manages the execution of a pipeline, handling frame processing and task lifecycle.
It has a couple of event handlers `on_frame_reached_upstream` and
@@ -174,21 +178,24 @@ class PipelineTask(BaseTask):
Args:
pipeline: The pipeline to execute.
params: Configuration parameters for the pipeline.
observers: List of observers for monitoring pipeline execution.
clock: Clock implementation for timing operations.
additional_span_attributes: Optional dictionary of attributes to propagate as
OpenTelemetry conversation span attributes.
cancel_on_idle_timeout: Whether the pipeline task should be cancelled if
the idle timeout is reached.
check_dangling_tasks: Whether to check for processors' tasks finishing properly.
clock: Clock implementation for timing operations.
conversation_id: Optional custom ID for the conversation.
enable_tracing: Whether to enable tracing.
enable_turn_tracking: Whether to enable turn tracking.
enable_watchdog_logging: Whether to print task processing times.
idle_timeout_frames: A tuple with the frames that should trigger an idle
timeout if not received withing `idle_timeout_seconds`.
idle_timeout_secs: Timeout (in seconds) to consider pipeline idle or
None. If a pipeline is idle the pipeline task will be cancelled
automatically.
idle_timeout_frames: A tuple with the frames that should trigger an idle
timeout if not received withing `idle_timeout_seconds`.
cancel_on_idle_timeout: Whether the pipeline task should be cancelled if
the idle timeout is reached.
enable_turn_tracking: Whether to enable turn tracking.
enable_turn_tracing: Whether to enable turn tracing.
conversation_id: Optional custom ID for the conversation.
additional_span_attributes: Optional dictionary of attributes to propagate as
OpenTelemetry conversation span attributes.
observers: List of observers for monitoring pipeline execution.
watchdog_timeout_secs: Watchdog timer timeout (in seconds). A warning
will be logged if the watchdog timer is not reset before this timeout.
"""
def __init__(
@@ -196,33 +203,37 @@ class PipelineTask(BaseTask):
pipeline: BasePipeline,
*,
params: Optional[PipelineParams] = None,
observers: Optional[List[BaseObserver]] = None,
clock: Optional[BaseClock] = None,
task_manager: Optional[BaseTaskManager] = None,
additional_span_attributes: Optional[dict] = None,
cancel_on_idle_timeout: bool = True,
check_dangling_tasks: bool = True,
idle_timeout_secs: Optional[float] = 300,
clock: Optional[BaseClock] = None,
conversation_id: Optional[str] = None,
enable_tracing: bool = False,
enable_turn_tracking: bool = True,
enable_watchdog_logging: bool = False,
idle_timeout_frames: Tuple[Type[Frame], ...] = (
BotSpeakingFrame,
LLMFullResponseEndFrame,
),
cancel_on_idle_timeout: bool = True,
enable_turn_tracking: bool = True,
enable_tracing: bool = False,
conversation_id: Optional[str] = None,
additional_span_attributes: Optional[dict] = None,
idle_timeout_secs: Optional[float] = 300,
observers: Optional[List[BaseObserver]] = None,
task_manager: Optional[BaseTaskManager] = None,
watchdog_timeout_secs: float = WATCHDOG_TIMEOUT,
):
super().__init__()
self._pipeline = pipeline
self._clock = clock or SystemClock()
self._params = params or PipelineParams()
self._check_dangling_tasks = check_dangling_tasks
self._idle_timeout_secs = idle_timeout_secs
self._idle_timeout_frames = idle_timeout_frames
self._cancel_on_idle_timeout = cancel_on_idle_timeout
self._enable_turn_tracking = enable_turn_tracking
self._enable_tracing = enable_tracing and is_tracing_available()
self._conversation_id = conversation_id
self._additional_span_attributes = additional_span_attributes or {}
self._cancel_on_idle_timeout = cancel_on_idle_timeout
self._check_dangling_tasks = check_dangling_tasks
self._clock = clock or SystemClock()
self._conversation_id = conversation_id
self._enable_tracing = enable_tracing and is_tracing_available()
self._enable_turn_tracking = enable_turn_tracking
self._enable_watchdog_logging = enable_watchdog_logging
self._idle_timeout_frames = idle_timeout_frames
self._idle_timeout_secs = idle_timeout_secs
self._watchdog_timeout_secs = watchdog_timeout_secs
if self._params.observers:
import warnings
@@ -324,9 +335,6 @@ class PipelineTask(BaseTask):
async def remove_observer(self, observer: BaseObserver):
await self._observer.remove_observer(observer)
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
self._task_manager.set_event_loop(loop)
def set_reached_upstream_filter(self, types: Tuple[Type[Frame], ...]):
"""Sets which frames will be checked before calling the
on_frame_reached_upstream event handler.
@@ -360,14 +368,14 @@ class PipelineTask(BaseTask):
"""Stops the running pipeline immediately."""
await self._cancel()
async def run(self):
async def run(self, params: PipelineTaskParams):
"""Starts and manages the pipeline execution until completion or cancellation."""
if self.has_finished():
return
cleanup_pipeline = True
try:
# Setup processors.
await self._setup()
await self._setup(params)
# Create all main tasks and wait of the main push task. This is the
# task that pushes frames to the very beginning of our pipeline (our
@@ -487,7 +495,14 @@ class PipelineTask(BaseTask):
await self._pipeline_end_event.wait()
self._pipeline_end_event.clear()
async def _setup(self):
async def _setup(self, params: PipelineTaskParams):
mgr_params = TaskManagerParams(
loop=params.loop,
enable_watchdog_logging=self._enable_watchdog_logging,
watchdog_timeout=self._watchdog_timeout_secs,
)
self._task_manager.setup(mgr_params)
setup = FrameProcessorSetup(
clock=self._clock,
task_manager=self._task_manager,
@@ -511,6 +526,8 @@ class PipelineTask(BaseTask):
await self._pipeline.cleanup()
await self._sink.cleanup()
await self._task_manager.cleanup()
async def _process_push_queue(self):
"""This is the task that runs the pipeline for the first time by sending
a StartFrame and by pushing any other frames queued by the user. It runs

View File

@@ -61,5 +61,7 @@ class ConsumerProcessor(FrameProcessor):
async def _consumer_task_handler(self):
while True:
frame = await self._queue.get()
self.start_watchdog()
new_frame = await self._transformer(frame)
await self.push_frame(new_frame, self._direction)
self.reset_watchdog()

View File

@@ -51,6 +51,8 @@ class FrameProcessor(BaseObject):
*,
name: Optional[str] = None,
metrics: Optional[FrameProcessorMetrics] = None,
enable_watchdog_logging: Optional[bool] = None,
watchdog_timeout: Optional[bool] = None,
**kwargs,
):
super().__init__(name=name)
@@ -58,6 +60,12 @@ class FrameProcessor(BaseObject):
self._prev: Optional["FrameProcessor"] = None
self._next: Optional["FrameProcessor"] = None
# Enable watchdog logging for all tasks created by this frame processor.
self._enable_watchdog_logging = enable_watchdog_logging
# Allow this frame processor to control their tasks timeout.
self._watchdog_timeout = watchdog_timeout
# Clock
self._clock: Optional[BaseClock] = None
@@ -171,24 +179,40 @@ class FrameProcessor(BaseObject):
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
def create_task(self, coroutine: Coroutine, name: Optional[str] = None) -> asyncio.Task:
if not self._task_manager:
raise Exception(f"{self} TaskManager is still not initialized.")
def create_task(
self,
coroutine: Coroutine,
name: Optional[str] = None,
*,
enable_watchdog_logging: Optional[bool] = None,
watchdog_timeout: Optional[float] = None,
) -> asyncio.Task:
if name:
name = f"{self}::{name}"
else:
name = f"{self}::{coroutine.cr_code.co_name}"
return self._task_manager.create_task(coroutine, name)
return self.get_task_manager().create_task(
coroutine,
name,
enable_watchdog_logging=(
enable_watchdog_logging
if enable_watchdog_logging
else self._enable_watchdog_logging
),
watchdog_timeout=watchdog_timeout if watchdog_timeout else self._watchdog_timeout,
)
async def cancel_task(self, task: asyncio.Task, timeout: Optional[float] = None):
if not self._task_manager:
raise Exception(f"{self} TaskManager is still not initialized.")
await self._task_manager.cancel_task(task, timeout)
await self.get_task_manager().cancel_task(task, timeout)
async def wait_for_task(self, task: asyncio.Task, timeout: Optional[float] = None):
if not self._task_manager:
raise Exception(f"{self} TaskManager is still not initialized.")
await self._task_manager.wait_for_task(task, timeout)
await self.get_task_manager().wait_for_task(task, timeout)
def start_watchdog(self):
self.get_task_manager().start_watchdog(asyncio.current_task())
def reset_watchdog(self):
self.get_task_manager().reset_watchdog(asyncio.current_task())
async def setup(self, setup: FrameProcessorSetup):
self._clock = setup.clock
@@ -206,9 +230,7 @@ class FrameProcessor(BaseObject):
logger.debug(f"Linking {self} -> {self._next}")
def get_event_loop(self) -> asyncio.AbstractEventLoop:
if not self._task_manager:
raise Exception(f"{self} TaskManager is still not initialized.")
return self._task_manager.get_event_loop()
return self.get_task_manager().get_event_loop()
def set_parent(self, parent: "FrameProcessor"):
self._parent = parent
@@ -388,6 +410,7 @@ class FrameProcessor(BaseObject):
(frame, direction, callback) = await self.__input_queue.get()
try:
self.start_watchdog()
# Process the frame.
await self.process_frame(frame, direction)
# If this frame has an associated callback, call it now.
@@ -398,6 +421,7 @@ class FrameProcessor(BaseObject):
await self.push_error(ErrorFrame(str(e)))
finally:
self.__input_queue.task_done()
self.reset_watchdog()
def __create_push_task(self):
if not self.__push_frame_task:
@@ -412,5 +436,7 @@ class FrameProcessor(BaseObject):
async def __push_frame_task_handler(self):
while True:
(frame, direction) = await self.__push_queue.get()
self.start_watchdog()
await self.__internal_push_frame(frame, direction)
self.__push_queue.task_done()
self.reset_watchdog()

View File

@@ -783,14 +783,18 @@ class RTVIProcessor(FrameProcessor):
async def _action_task_handler(self):
while True:
frame = await self._action_queue.get()
self.start_watchdog()
await self._handle_action(frame.message_id, frame.rtvi_action_run)
self._action_queue.task_done()
self.reset_watchdog()
async def _message_task_handler(self):
while True:
message = await self._message_queue.get()
self.start_watchdog()
await self._handle_message(message)
self._message_queue.task_done()
self.reset_watchdog()
async def _handle_transport_message(self, frame: TransportMessageUrgentFrame):
try:

View File

@@ -190,6 +190,7 @@ class AssemblyAISTTService(STTService):
while self._connected:
try:
message = await self._websocket.recv()
self.start_watchdog()
data = json.loads(message)
await self._handle_message(data)
except websockets.exceptions.ConnectionClosedOK:
@@ -197,6 +198,8 @@ class AssemblyAISTTService(STTService):
except Exception as e:
logger.error(f"Error processing WebSocket message: {e}")
break
finally:
self.reset_watchdog()
except Exception as e:
logger.error(f"Fatal error in receive handler: {e}")

View File

@@ -285,6 +285,9 @@ class AWSTranscribeSTTService(STTService):
try:
response = await self._ws_client.recv()
self.start_watchdog()
headers, payload = decode_event(response)
if headers.get(":message-type") == "event":
@@ -342,3 +345,5 @@ class AWSTranscribeSTTService(STTService):
except Exception as e:
logger.error(f"{self} Unexpected error in receive loop: {e}")
break
finally:
self.reset_watchdog()

View File

@@ -699,6 +699,8 @@ class AWSNovaSonicLLMService(LLMService):
output = await self._stream.await_output()
result = await output[1].receive()
self.start_watchdog()
if result.value and result.value.bytes_:
response_data = result.value.bytes_.decode("utf-8")
json_data = json.loads(response_data)
@@ -731,6 +733,8 @@ class AWSNovaSonicLLMService(LLMService):
logger.error(f"{self} error processing responses: {e}")
if self._wants_connection:
await self.reset_conversation()
finally:
self.reset_watchdog()
async def _handle_completion_start_event(self, event_json):
pass

View File

@@ -687,6 +687,8 @@ class GeminiMultimodalLiveLLMService(LLMService):
async def _receive_task_handler(self):
async for message in self._websocket:
self.start_watchdog()
evt = events.parse_server_event(message)
# logger.debug(f"Received event: {message[:500]}")
# logger.debug(f"Received event: {evt}")
@@ -708,8 +710,8 @@ class GeminiMultimodalLiveLLMService(LLMService):
await self._handle_evt_error(evt)
# errors are fatal, so exit the receive loop
return
else:
pass
self.reset_watchdog()
#
#

View File

@@ -502,6 +502,8 @@ class GladiaSTTService(STTService):
async def _receive_task_handler(self):
try:
async for message in self._websocket:
self.start_watchdog()
content = json.loads(message)
# Handle audio chunk acknowledgments
@@ -559,11 +561,15 @@ class GladiaSTTService(STTService):
translation, "", time_now_iso8601(), translated_language
)
)
self.reset_watchdog()
except websockets.exceptions.ConnectionClosed:
# Expected when closing the connection
pass
except Exception as e:
logger.error(f"Error in Gladia WebSocket handler: {e}")
finally:
self.reset_watchdog()
async def _maybe_reconnect(self) -> bool:
"""Handle exponential backoff reconnection logic."""

View File

@@ -747,9 +747,12 @@ class GoogleSTTService(STTService):
try:
while True:
try:
self.start_watchdog()
if self._request_queue.empty():
# wait for 10ms in case we don't have audio
await asyncio.sleep(0.01)
self.reset_watchdog()
continue
# Start bi-directional streaming
@@ -760,12 +763,13 @@ class GoogleSTTService(STTService):
# Process responses
await self._process_responses(streaming_recognize)
self.reset_watchdog()
# If we're here, check if we need to reconnect
if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
logger.debug("Reconnecting stream after timeout")
# Reset stream start time
self._stream_start_time = int(time.time() * 1000)
continue
else:
# Normal stream end
break
@@ -775,7 +779,8 @@ class GoogleSTTService(STTService):
await asyncio.sleep(1) # Brief delay before reconnecting
self._stream_start_time = int(time.time() * 1000)
continue
finally:
self.reset_watchdog()
except Exception as e:
logger.error(f"Error in streaming task: {e}")
@@ -800,12 +805,16 @@ class GoogleSTTService(STTService):
"""Process streaming recognition responses."""
try:
async for response in streaming_recognize:
self.start_watchdog()
# Check streaming limit
if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
logger.debug("Stream timeout reached in response processing")
self.reset_watchdog()
break
if not response.results:
self.reset_watchdog()
continue
for result in response.results:
@@ -848,8 +857,10 @@ class GoogleSTTService(STTService):
)
)
self.reset_watchdog()
except Exception as e:
logger.error(f"Error processing Google STT responses: {e}")
# Re-raise the exception to let it propagate (e.g. in the case of a timeout, propagate to _stream_audio to reconnect)
self.reset_watchdog()
# Re-raise the exception to let it propagate (e.g. in the case of a
# timeout, propagate to _stream_audio to reconnect)
raise

View File

@@ -203,12 +203,11 @@ class ResponseCancelEvent(ClientEvent):
class ServerEvent(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
event_id: str
type: str
class Config:
arbitrary_types_allowed = True
class SessionCreatedEvent(ServerEvent):
type: Literal["session.created"]

View File

@@ -370,6 +370,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
async def _receive_task_handler(self):
async for message in self._websocket:
self.start_watchdog()
evt = events.parse_server_event(message)
if evt.type == "session.created":
await self._handle_evt_session_created(evt)
@@ -400,6 +401,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
await self._handle_evt_error(evt)
# errors are fatal, so exit the receive loop
return
self.reset_watchdog()
@traced_openai_realtime(operation="llm_setup")
async def _handle_evt_session_created(self, evt):

View File

@@ -224,11 +224,13 @@ class RivaSTTService(STTService):
streaming_config=self._config,
)
for response in responses:
self.start_watchdog()
if not response.results:
continue
asyncio.run_coroutine_threadsafe(
self._response_queue.put(response), self.get_event_loop()
)
self.reset_watchdog()
async def _thread_task_handler(self):
try:
@@ -283,7 +285,9 @@ class RivaSTTService(STTService):
async def _response_task_handler(self):
while True:
response = await self._response_queue.get()
self.start_watchdog()
await self._handle_response(response)
self.reset_watchdog()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
await self.start_ttfb_metrics()

View File

@@ -62,6 +62,7 @@ class SimliVideoService(FrameProcessor):
async def _consume_and_process_audio(self):
await self._pipecat_resampler_event.wait()
async for audio_frame in self._simli_client.getAudioStreamIterator():
self.start_watchdog()
resampled_frames = self._pipecat_resampler.resample(audio_frame)
for resampled_frame in resampled_frames:
audio_array = resampled_frame.to_ndarray()
@@ -74,10 +75,12 @@ class SimliVideoService(FrameProcessor):
num_channels=1,
),
)
self.reset_watchdog()
async def _consume_and_process_video(self):
await self._pipecat_resampler_event.wait()
async for video_frame in self._simli_client.getVideoStreamIterator(targetFormat="rgb24"):
self.start_watchdog()
# Process the video frame
convertedFrame: OutputImageRawFrame = OutputImageRawFrame(
image=video_frame.to_rgb().to_image().tobytes(),
@@ -86,6 +89,7 @@ class SimliVideoService(FrameProcessor):
)
convertedFrame.pts = video_frame.pts
await self.push_frame(convertedFrame)
self.reset_watchdog()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

View File

@@ -217,5 +217,7 @@ class TavusVideoService(AIService):
async def _send_task_handler(self):
while True:
frame = await self._queue.get()
if isinstance(frame, OutputAudioRawFrame):
self.start_watchdog()
if isinstance(frame, OutputAudioRawFrame) and self._client:
await self._client.write_audio_frame(frame)
self.reset_watchdog()

View File

@@ -368,6 +368,8 @@ class BaseInputTransport(FrameProcessor):
self._audio_in_queue.get(), timeout=AUDIO_INPUT_TIMEOUT_SECS
)
self.start_watchdog()
# If an audio filter is available, run it before VAD.
if self._params.audio_in_filter:
frame.audio = await self._params.audio_in_filter.filter(frame.audio)
@@ -396,6 +398,8 @@ class BaseInputTransport(FrameProcessor):
self._params.turn_analyzer.clear()
await self._handle_user_interruption(UserStoppedSpeakingFrame())
self.reset_watchdog()
async def _handle_prediction_result(self, result: MetricsData):
"""Handle a prediction result event from the turn analyzer.

View File

@@ -182,6 +182,8 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
if not self._params.serializer:
continue
self.start_watchdog()
frame = await self._params.serializer.deserialize(message)
if not frame:
@@ -191,9 +193,13 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
await self.push_audio_frame(frame)
else:
await self.push_frame(frame)
self.reset_watchdog()
except Exception as e:
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
self.reset_watchdog()
await self._client.trigger_client_disconnected()
async def _monitor_websocket(self):

View File

@@ -423,8 +423,10 @@ class SmallWebRTCInputTransport(BaseInputTransport):
async def _receive_audio(self):
try:
async for audio_frame in self._client.read_audio_frame():
self.start_watchdog()
if audio_frame:
await self.push_audio_frame(audio_frame)
self.reset_watchdog()
except Exception as e:
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
@@ -432,6 +434,7 @@ class SmallWebRTCInputTransport(BaseInputTransport):
async def _receive_video(self):
try:
async for video_frame in self._client.read_video_frame():
self.start_watchdog()
if video_frame:
await self.push_video_frame(video_frame)
@@ -450,6 +453,7 @@ class SmallWebRTCInputTransport(BaseInputTransport):
await self.push_video_frame(image_frame)
# Remove from pending requests
del self._image_requests[req_id]
self.reset_watchdog()
except Exception as e:
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")

View File

@@ -415,6 +415,7 @@ class LiveKitInputTransport(BaseInputTransport):
logger.info("Audio input task started")
while True:
audio_data = await self._client.get_next_audio_frame()
self.start_watchdog()
if audio_data:
audio_frame_event, participant_id = audio_data
pipecat_audio_frame = await self._convert_livekit_audio_to_pipecat(
@@ -427,6 +428,7 @@ class LiveKitInputTransport(BaseInputTransport):
num_channels=pipecat_audio_frame.num_channels,
)
await self.push_audio_frame(input_audio_frame)
self.reset_watchdog()
async def _convert_livekit_audio_to_pipecat(
self, audio_frame_event: rtc.AudioFrameEvent

View File

@@ -5,15 +5,30 @@
#
import asyncio
import time
from abc import ABC, abstractmethod
from typing import Coroutine, Dict, Optional, Sequence, Set
from dataclasses import dataclass
from typing import Coroutine, Dict, List, Optional, Sequence
from loguru import logger
WATCHDOG_TIMEOUT = 5.0
@dataclass
class TaskManagerParams:
loop: asyncio.AbstractEventLoop
enable_watchdog_logging: bool = False
watchdog_timeout: float = WATCHDOG_TIMEOUT
class BaseTaskManager(ABC):
@abstractmethod
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
def setup(self, params: TaskManagerParams):
pass
@abstractmethod
async def cleanup(self):
pass
@abstractmethod
@@ -21,7 +36,14 @@ class BaseTaskManager(ABC):
pass
@abstractmethod
def create_task(self, coroutine: Coroutine, name: str) -> asyncio.Task:
def create_task(
self,
coroutine: Coroutine,
name: str,
*,
enable_watchdog_logging: Optional[bool] = None,
watchdog_timeout: Optional[float] = None,
) -> asyncio.Task:
"""
Creates and schedules a new asyncio Task that runs the given coroutine.
@@ -31,6 +53,8 @@ class BaseTaskManager(ABC):
loop (asyncio.AbstractEventLoop): The event loop to use for creating the task.
coroutine (Coroutine): The coroutine to be executed within the task.
name (str): The name to assign to the task for identification.
enable_watchdog_logging(bool): whether this task should log watchdog processing times.
watchdog_timeout(float): watchdog timer timeout for this task.
Returns:
asyncio.Task: The created task object.
@@ -73,21 +97,64 @@ class BaseTaskManager(ABC):
"""Returns the list of currently created/registered tasks."""
pass
@abstractmethod
def start_watchdog(self, task: asyncio.Task):
"""Starts the given task watchdog timer. If not reset, a warning will be
logged indicating the task is stalling.
"""
pass
@abstractmethod
def reset_watchdog(self, task: asyncio.Task):
"""Resets the given task watchdog timer. If not reset, a warning will be
logged indicating the task is stalling.
"""
pass
@dataclass
class TaskData:
task: asyncio.Task
watchdog_start: asyncio.Event
watchdog_timer: asyncio.Event
enable_watchdog_logging: bool
watchdog_timeout: float
class TaskManager(BaseTaskManager):
def __init__(self) -> None:
self._tasks: Dict[str, asyncio.Task] = {}
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._tasks: Dict[str, TaskData] = {}
self._params: Optional[TaskManagerParams] = None
self._watchdog_tasks: List[asyncio.Task] = []
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
self._loop = loop
def setup(self, params: TaskManagerParams):
if not self._params:
self._params = params
async def cleanup(self):
for task in self._watchdog_tasks:
try:
task.cancel()
await task
except asyncio.CancelledError:
# This is expected, no need to re-raise.
pass
def get_event_loop(self) -> asyncio.AbstractEventLoop:
if not self._loop:
raise Exception("TaskManager missing event loop, use TaskManager.set_event_loop().")
return self._loop
if not self._params:
raise Exception("TaskManager is not setup: unable to get event loop")
return self._params.loop
def create_task(self, coroutine: Coroutine, name: str) -> asyncio.Task:
def create_task(
self,
coroutine: Coroutine,
name: str,
*,
enable_watchdog_logging: Optional[bool] = None,
watchdog_timeout: Optional[float] = None,
) -> asyncio.Task:
"""
Creates and schedules a new asyncio Task that runs the given coroutine.
@@ -97,6 +164,8 @@ class TaskManager(BaseTaskManager):
loop (asyncio.AbstractEventLoop): The event loop to use for creating the task.
coroutine (Coroutine): The coroutine to be executed within the task.
name (str): The name to assign to the task for identification.
enable_watchdog_logging(bool): whether this task should log watchdog processing time.
watchdog_timeout(float): watchdog timer timeout for this task.
Returns:
asyncio.Task: The created task object.
@@ -112,12 +181,26 @@ class TaskManager(BaseTaskManager):
except Exception as e:
logger.exception(f"{name}: unexpected exception: {e}")
if not self._loop:
raise Exception("TaskManager missing event loop, use TaskManager.set_event_loop().")
if not self._params:
raise Exception("TaskManager is not setup: unable to get event loop")
task = self._loop.create_task(run_coroutine())
task = self._params.loop.create_task(run_coroutine())
task.set_name(name)
self._add_task(task)
self._add_task(
TaskData(
task=task,
watchdog_start=asyncio.Event(),
watchdog_timer=asyncio.Event(),
enable_watchdog_logging=(
enable_watchdog_logging
if enable_watchdog_logging
else self._params.enable_watchdog_logging
),
watchdog_timeout=(
watchdog_timeout if watchdog_timeout else self._params.watchdog_timeout
),
)
)
logger.trace(f"{name}: task created")
return task
@@ -165,6 +248,8 @@ class TaskManager(BaseTaskManager):
name = task.get_name()
task.cancel()
try:
# Make sure to reset watchdog if a task is cancelled.
self.reset_watchdog(task)
if timeout:
await asyncio.wait_for(task, timeout=timeout)
else:
@@ -184,11 +269,43 @@ class TaskManager(BaseTaskManager):
def current_tasks(self) -> Sequence[asyncio.Task]:
"""Returns the list of currently created/registered tasks."""
return list(self._tasks.values())
return [data.task for data in self._tasks.values()]
def _add_task(self, task: asyncio.Task):
def start_watchdog(self, task: asyncio.Task):
"""Starts the given task watchdog timer. If not reset, a warning will be
logged indicating the task is stalling. If the timer was already started
a warning will be logged.
"""
name = task.get_name()
self._tasks[name] = task
if name in self._tasks:
if self._tasks[name].watchdog_start.is_set():
logger.warning(f"Watchdog timer for task {name} already started")
else:
self._tasks[name].watchdog_timer.clear()
self._tasks[name].watchdog_start.set()
else:
logger.warning(f"Unable to start watchdog timer: task {name} does not exist")
def reset_watchdog(self, task: asyncio.Task):
"""Resets the given task watchdog timer. If not reset, a warning will be
logged indicating the task is stalling.
"""
name = task.get_name()
if name in self._tasks:
self._tasks[name].watchdog_start.clear()
self._tasks[name].watchdog_timer.set()
else:
logger.warning(f"Unable to reset watchdog timer: task {name} does not exist")
def _add_task(self, task_data: TaskData):
name = task_data.task.get_name()
self._tasks[name] = task_data
watchdog_task = self.get_event_loop().create_task(
self._watchdog_task_handler(self._tasks[name])
)
self._watchdog_tasks.append(watchdog_task)
def _remove_task(self, task: asyncio.Task):
name = task.get_name()
@@ -196,3 +313,33 @@ class TaskManager(BaseTaskManager):
del self._tasks[name]
except KeyError as e:
logger.trace(f"{name}: unable to remove task (already removed?): {e}")
async def _watchdog_task_handler(self, task_data: TaskData):
name = task_data.task.get_name()
start = task_data.watchdog_start
timer = task_data.watchdog_timer
enable_watchdog_logging = task_data.enable_watchdog_logging
watchdog_timeout = task_data.watchdog_timeout
async def wait_for_reset():
waiting = True
while waiting:
try:
start_time = time.time()
await asyncio.wait_for(timer.wait(), timeout=watchdog_timeout)
total_time = time.time() - start_time
if enable_watchdog_logging:
logger.debug(f"{name} task processing time: {total_time:.20f}")
waiting = False
except asyncio.TimeoutError:
logger.warning(
f"{name}: task is taking too long {WATCHDOG_TIMEOUT} second(s) (forgot to reset watchdog?)"
)
finally:
timer.clear()
while True:
# Wait for the user to start the watchdog timer.
await start.wait()
# Now, waiting for the task to finish.
await wait_for_reset()

View File

@@ -17,6 +17,7 @@ from pipecat.frames.frames import (
TextFrame,
)
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.pipeline.base_task import PipelineTaskParams
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams, PipelineTask
@@ -96,11 +97,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
async def test_task_single(self):
pipeline = Pipeline([IdentityFilter()])
task = PipelineTask(pipeline)
task.set_event_loop(asyncio.get_event_loop())
await task.queue_frame(TextFrame(text="Hello!"))
await task.queue_frames([TextFrame(text="Bye!"), EndFrame()])
await task.run()
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
assert task.has_finished()
async def test_task_observers(self):
@@ -116,10 +116,9 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
identity = IdentityFilter()
pipeline = Pipeline([identity])
task = PipelineTask(pipeline, observers=[CustomObserver()])
task.set_event_loop(asyncio.get_event_loop())
await task.queue_frames([TextFrame(text="Hello Downstream!"), EndFrame()])
await task.run()
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
assert frame_received
async def test_task_add_observer(self):
@@ -156,8 +155,6 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
observer1 = CustomAddObserver1()
task.add_observer(observer1)
task.set_event_loop(asyncio.get_event_loop())
async def delayed_add_observer():
observer2 = CustomAddObserver2()
# Wait after the pipeline is started and add another observer.
@@ -176,7 +173,9 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
# Finally end the pipeline.
await task.queue_frame(EndFrame())
await asyncio.gather(task.run(), delayed_add_observer())
await asyncio.gather(
task.run(PipelineTaskParams(loop=asyncio.get_event_loop())), delayed_add_observer()
)
assert frame_received
assert frame_count_1 == 1
@@ -189,7 +188,6 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
identity = IdentityFilter()
pipeline = Pipeline([identity])
task = PipelineTask(pipeline)
task.set_event_loop(asyncio.get_event_loop())
@task.event_handler("on_pipeline_started")
async def on_pipeline_started(task, frame: StartFrame):
@@ -202,7 +200,7 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
end_received = True
await task.queue_frame(EndFrame())
await task.run()
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
assert start_received
assert end_received
@@ -213,7 +211,6 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
identity = IdentityFilter()
pipeline = Pipeline([identity])
task = PipelineTask(pipeline)
task.set_event_loop(asyncio.get_event_loop())
@task.event_handler("on_pipeline_stopped")
async def on_pipeline_ended(task, frame: StopFrame):
@@ -221,7 +218,7 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
stop_received = True
await task.queue_frame(StopFrame())
await task.run()
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
assert stop_received
@@ -232,7 +229,6 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
identity = IdentityFilter()
pipeline = Pipeline([identity])
task = PipelineTask(pipeline, cancel_on_idle_timeout=False)
task.set_event_loop(asyncio.get_event_loop())
task.set_reached_upstream_filter((TextFrame,))
task.set_reached_downstream_filter((TextFrame,))
@@ -254,7 +250,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
await task.queue_frame(TextFrame(text="Hello Downstream!"))
try:
await asyncio.wait_for(asyncio.shield(task.run()), timeout=1.0)
await asyncio.wait_for(
asyncio.shield(task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))),
timeout=1.0,
)
except asyncio.TimeoutError:
pass
@@ -282,13 +281,15 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
observers=[heartbeats_observer],
cancel_on_idle_timeout=False,
)
task.set_event_loop(asyncio.get_event_loop())
expected_heartbeats = 1.0 / 0.2
await task.queue_frame(TextFrame(text="Hello!"))
try:
await asyncio.wait_for(asyncio.shield(task.run()), timeout=1.0)
await asyncio.wait_for(
asyncio.shield(task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))),
timeout=1.0,
)
except asyncio.TimeoutError:
pass
assert heartbeats_counter == expected_heartbeats
@@ -297,17 +298,18 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
identity = IdentityFilter()
pipeline = Pipeline([identity])
task = PipelineTask(pipeline, idle_timeout_secs=0.2)
task.set_event_loop(asyncio.get_event_loop())
await task.run()
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
assert True
async def test_no_idle_task(self):
identity = IdentityFilter()
pipeline = Pipeline([identity])
task = PipelineTask(pipeline, idle_timeout_secs=0.2, cancel_on_idle_timeout=False)
task.set_event_loop(asyncio.get_event_loop())
try:
await asyncio.wait_for(asyncio.shield(task.run()), timeout=0.3)
await asyncio.wait_for(
asyncio.shield(task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))),
timeout=0.3,
)
except asyncio.TimeoutError:
assert True
else:
@@ -324,15 +326,13 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
),
idle_timeout_secs=0.3,
)
task.set_event_loop(asyncio.get_event_loop())
await task.run()
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
assert True
async def test_idle_task_event_handler_no_frames(self):
identity = IdentityFilter()
pipeline = Pipeline([identity])
task = PipelineTask(pipeline, idle_timeout_secs=0.2, cancel_on_idle_timeout=False)
task.set_event_loop(asyncio.get_event_loop())
idle_timeout = False
@@ -342,14 +342,13 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
idle_timeout = True
await task.cancel()
await task.run()
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
assert idle_timeout
async def test_idle_task_event_handler_quiet_user(self):
identity = IdentityFilter()
pipeline = Pipeline([identity])
task = PipelineTask(pipeline, idle_timeout_secs=0.2, cancel_on_idle_timeout=False)
task.set_event_loop(asyncio.get_event_loop())
idle_timeout = 0
@@ -373,7 +372,9 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
)
await asyncio.sleep(0.01)
await asyncio.gather(send_audio(), task.run())
await asyncio.gather(
send_audio(), task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
)
assert idle_timeout == 1
async def test_idle_task_frames(self):
@@ -387,7 +388,6 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
idle_timeout_secs=idle_timeout_secs,
idle_timeout_frames=(TextFrame,),
)
task.set_event_loop(asyncio.get_event_loop())
async def delayed_frames():
await asyncio.sleep(sleep_time_secs)
@@ -399,7 +399,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
start_time = time.time()
tasks = {asyncio.create_task(task.run()), asyncio.create_task(delayed_frames())}
tasks = [
asyncio.create_task(task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))),
asyncio.create_task(delayed_frames()),
]
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)