Merge branch 'main' into filipi/pipeline_freeze
# Conflicts: # CHANGELOG.md # src/pipecat/pipeline/task.py # src/pipecat/processors/frame_processor.py # src/pipecat/transports/base_input.py
This commit is contained in:
17
CHANGELOG.md
17
CHANGELOG.md
@@ -12,6 +12,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Added logging and improved error handling to help diagnose and prevent potential
|
||||
Pipeline freezes.
|
||||
|
||||
- Introduce task watchdog timers. Watchdog timers are used to detect if a
|
||||
Pipecat task is taking longer than expected (by default 5 seconds). It is
|
||||
possible to change the default watchdog timer timeout by using the
|
||||
`watchdog_timeout` constructor argument when creating a `PipelineTask`. With
|
||||
watchdog timers it is also possible to log how long each processing step is
|
||||
taking (e.g. processing an element from a queue inside a task). This is done
|
||||
with the `enable_watchdog_logging` constructor argument when creating a
|
||||
`PipelineTask.` It is also possible to control these two values per each frame
|
||||
processor. That is, you can set set `enable_watchdog_logging` and
|
||||
`watchdog_timeout` when creating any frame processor through their constructor
|
||||
arguments. Finally, you can also set these values per task. So, if you are
|
||||
writing a frame processor that creates multiple tasks and you only want to
|
||||
enable logging for one of them, you can do so by passing the same argument
|
||||
names to the `FrameProcessor.create_task()` function. Note that watchdog
|
||||
timers only work with Pipecat tasks but not if you use `asycio.create_task()`
|
||||
or similar.
|
||||
|
||||
- Added `lexicon_names` parameter to `AWSPollyTTSService.InputParams`.
|
||||
|
||||
- Added reconnection logic and audio buffer management to `GladiaSTTService`.
|
||||
|
||||
@@ -6,18 +6,21 @@
|
||||
|
||||
import asyncio
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncIterable, Iterable
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
|
||||
class BaseTask(BaseObject):
|
||||
@abstractmethod
|
||||
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
|
||||
"""Sets the event loop that this task will run on."""
|
||||
pass
|
||||
@dataclass
|
||||
class PipelineTaskParams:
|
||||
"""Specific configuration for the pipeline task."""
|
||||
|
||||
loop: asyncio.AbstractEventLoop
|
||||
|
||||
|
||||
class BasePipelineTask(BaseObject):
|
||||
@abstractmethod
|
||||
def has_finished(self) -> bool:
|
||||
"""Indicates whether the tasks has finished. That is, all processors
|
||||
@@ -40,7 +43,7 @@ class BaseTask(BaseObject):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def run(self):
|
||||
async def run(self, params: PipelineTaskParams):
|
||||
"""Starts running the given pipeline."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -202,14 +202,18 @@ class ParallelPipeline(BasePipeline):
|
||||
async def _process_up_queue(self):
|
||||
while True:
|
||||
frame = await self._up_queue.get()
|
||||
self.start_watchdog()
|
||||
await self._parallel_push_frame(frame, FrameDirection.UPSTREAM)
|
||||
self._up_queue.task_done()
|
||||
self.reset_watchdog()
|
||||
|
||||
async def _process_down_queue(self):
|
||||
running = True
|
||||
while running:
|
||||
frame = await self._down_queue.get()
|
||||
|
||||
self.start_watchdog()
|
||||
|
||||
endframe_counter = self._endframe_counter.get(frame.id, 0)
|
||||
|
||||
# If we have a counter, decrement it.
|
||||
@@ -224,3 +228,5 @@ class ParallelPipeline(BasePipeline):
|
||||
running = not (endframe_counter == 0 and isinstance(frame, EndFrame))
|
||||
|
||||
self._down_queue.task_done()
|
||||
|
||||
self.reset_watchdog()
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.pipeline.base_task import PipelineTaskParams
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
@@ -37,8 +38,8 @@ class PipelineRunner(BaseObject):
|
||||
async def run(self, task: PipelineTask):
|
||||
logger.debug(f"Runner {self} started running {task}")
|
||||
self._tasks[task.name] = task
|
||||
task.set_event_loop(self._loop)
|
||||
await task.run()
|
||||
params = PipelineTaskParams(loop=self._loop)
|
||||
await task.run(params)
|
||||
del self._tasks[task.name]
|
||||
|
||||
# Cleanup base object.
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any, AsyncIterable, Deque, Dict, Iterable, List, Optional, Sequence, Tuple, Type
|
||||
from typing import Any, AsyncIterable, Deque, Dict, Iterable, List, Optional, Tuple, Type
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@@ -35,10 +35,10 @@ from pipecat.metrics.metrics import ProcessingMetricsData, TTFBMetricsData
|
||||
from pipecat.observers.base_observer import BaseObserver
|
||||
from pipecat.observers.turn_tracking_observer import TurnTrackingObserver
|
||||
from pipecat.pipeline.base_pipeline import BasePipeline
|
||||
from pipecat.pipeline.base_task import BaseTask
|
||||
from pipecat.pipeline.base_task import BasePipelineTask, PipelineTaskParams
|
||||
from pipecat.pipeline.task_observer import TaskObserver
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
|
||||
from pipecat.utils.asyncio import BaseTaskManager, TaskManager
|
||||
from pipecat.utils.asyncio import WATCHDOG_TIMEOUT, BaseTaskManager, TaskManager, TaskManagerParams
|
||||
from pipecat.utils.tracing.setup import is_tracing_available
|
||||
from pipecat.utils.tracing.turn_trace_observer import TurnTraceObserver
|
||||
|
||||
@@ -47,7 +47,10 @@ HEARTBEAT_MONITOR_SECONDS = HEARTBEAT_SECONDS * 10
|
||||
|
||||
|
||||
class PipelineParams(BaseModel):
|
||||
"""Configuration parameters for pipeline execution.
|
||||
"""Configuration parameters for pipeline execution. These parameters are
|
||||
usually passed to all frame processors using through `StartFrame`. For other
|
||||
generic pipeline task parameters use `PipelineTask` constructor arguments
|
||||
instead.
|
||||
|
||||
Attributes:
|
||||
allow_interruptions: Whether to allow pipeline interruptions.
|
||||
@@ -62,6 +65,7 @@ class PipelineParams(BaseModel):
|
||||
send_initial_empty_metrics: Whether to send initial empty metrics.
|
||||
start_metadata: Additional metadata for pipeline start.
|
||||
interruption_strategies: Strategies for bot interruption behavior.
|
||||
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
@@ -73,11 +77,11 @@ class PipelineParams(BaseModel):
|
||||
enable_metrics: bool = False
|
||||
enable_usage_metrics: bool = False
|
||||
heartbeats_period_secs: float = HEARTBEAT_SECONDS
|
||||
interruption_strategies: List[BaseInterruptionStrategy] = Field(default_factory=list)
|
||||
observers: List[BaseObserver] = Field(default_factory=list)
|
||||
report_only_initial_ttfb: bool = False
|
||||
send_initial_empty_metrics: bool = True
|
||||
start_metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
interruption_strategies: List[BaseInterruptionStrategy] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PipelineTaskSource(FrameProcessor):
|
||||
@@ -127,7 +131,7 @@ class PipelineTaskSink(FrameProcessor):
|
||||
await self._down_queue.put(frame)
|
||||
|
||||
|
||||
class PipelineTask(BaseTask):
|
||||
class PipelineTask(BasePipelineTask):
|
||||
"""Manages the execution of a pipeline, handling frame processing and task lifecycle.
|
||||
|
||||
It has a couple of event handlers `on_frame_reached_upstream` and
|
||||
@@ -174,21 +178,24 @@ class PipelineTask(BaseTask):
|
||||
Args:
|
||||
pipeline: The pipeline to execute.
|
||||
params: Configuration parameters for the pipeline.
|
||||
observers: List of observers for monitoring pipeline execution.
|
||||
clock: Clock implementation for timing operations.
|
||||
additional_span_attributes: Optional dictionary of attributes to propagate as
|
||||
OpenTelemetry conversation span attributes.
|
||||
cancel_on_idle_timeout: Whether the pipeline task should be cancelled if
|
||||
the idle timeout is reached.
|
||||
check_dangling_tasks: Whether to check for processors' tasks finishing properly.
|
||||
clock: Clock implementation for timing operations.
|
||||
conversation_id: Optional custom ID for the conversation.
|
||||
enable_tracing: Whether to enable tracing.
|
||||
enable_turn_tracking: Whether to enable turn tracking.
|
||||
enable_watchdog_logging: Whether to print task processing times.
|
||||
idle_timeout_frames: A tuple with the frames that should trigger an idle
|
||||
timeout if not received withing `idle_timeout_seconds`.
|
||||
idle_timeout_secs: Timeout (in seconds) to consider pipeline idle or
|
||||
None. If a pipeline is idle the pipeline task will be cancelled
|
||||
automatically.
|
||||
idle_timeout_frames: A tuple with the frames that should trigger an idle
|
||||
timeout if not received withing `idle_timeout_seconds`.
|
||||
cancel_on_idle_timeout: Whether the pipeline task should be cancelled if
|
||||
the idle timeout is reached.
|
||||
enable_turn_tracking: Whether to enable turn tracking.
|
||||
enable_turn_tracing: Whether to enable turn tracing.
|
||||
conversation_id: Optional custom ID for the conversation.
|
||||
additional_span_attributes: Optional dictionary of attributes to propagate as
|
||||
OpenTelemetry conversation span attributes.
|
||||
observers: List of observers for monitoring pipeline execution.
|
||||
watchdog_timeout_secs: Watchdog timer timeout (in seconds). A warning
|
||||
will be logged if the watchdog timer is not reset before this timeout.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -196,33 +203,37 @@ class PipelineTask(BaseTask):
|
||||
pipeline: BasePipeline,
|
||||
*,
|
||||
params: Optional[PipelineParams] = None,
|
||||
observers: Optional[List[BaseObserver]] = None,
|
||||
clock: Optional[BaseClock] = None,
|
||||
task_manager: Optional[BaseTaskManager] = None,
|
||||
additional_span_attributes: Optional[dict] = None,
|
||||
cancel_on_idle_timeout: bool = True,
|
||||
check_dangling_tasks: bool = True,
|
||||
idle_timeout_secs: Optional[float] = 300,
|
||||
clock: Optional[BaseClock] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
enable_tracing: bool = False,
|
||||
enable_turn_tracking: bool = True,
|
||||
enable_watchdog_logging: bool = False,
|
||||
idle_timeout_frames: Tuple[Type[Frame], ...] = (
|
||||
BotSpeakingFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
),
|
||||
cancel_on_idle_timeout: bool = True,
|
||||
enable_turn_tracking: bool = True,
|
||||
enable_tracing: bool = False,
|
||||
conversation_id: Optional[str] = None,
|
||||
additional_span_attributes: Optional[dict] = None,
|
||||
idle_timeout_secs: Optional[float] = 300,
|
||||
observers: Optional[List[BaseObserver]] = None,
|
||||
task_manager: Optional[BaseTaskManager] = None,
|
||||
watchdog_timeout_secs: float = WATCHDOG_TIMEOUT,
|
||||
):
|
||||
super().__init__()
|
||||
self._pipeline = pipeline
|
||||
self._clock = clock or SystemClock()
|
||||
self._params = params or PipelineParams()
|
||||
self._check_dangling_tasks = check_dangling_tasks
|
||||
self._idle_timeout_secs = idle_timeout_secs
|
||||
self._idle_timeout_frames = idle_timeout_frames
|
||||
self._cancel_on_idle_timeout = cancel_on_idle_timeout
|
||||
self._enable_turn_tracking = enable_turn_tracking
|
||||
self._enable_tracing = enable_tracing and is_tracing_available()
|
||||
self._conversation_id = conversation_id
|
||||
self._additional_span_attributes = additional_span_attributes or {}
|
||||
self._cancel_on_idle_timeout = cancel_on_idle_timeout
|
||||
self._check_dangling_tasks = check_dangling_tasks
|
||||
self._clock = clock or SystemClock()
|
||||
self._conversation_id = conversation_id
|
||||
self._enable_tracing = enable_tracing and is_tracing_available()
|
||||
self._enable_turn_tracking = enable_turn_tracking
|
||||
self._enable_watchdog_logging = enable_watchdog_logging
|
||||
self._idle_timeout_frames = idle_timeout_frames
|
||||
self._idle_timeout_secs = idle_timeout_secs
|
||||
self._watchdog_timeout_secs = watchdog_timeout_secs
|
||||
if self._params.observers:
|
||||
import warnings
|
||||
|
||||
@@ -324,9 +335,6 @@ class PipelineTask(BaseTask):
|
||||
async def remove_observer(self, observer: BaseObserver):
|
||||
await self._observer.remove_observer(observer)
|
||||
|
||||
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
|
||||
self._task_manager.set_event_loop(loop)
|
||||
|
||||
def set_reached_upstream_filter(self, types: Tuple[Type[Frame], ...]):
|
||||
"""Sets which frames will be checked before calling the
|
||||
on_frame_reached_upstream event handler.
|
||||
@@ -360,14 +368,14 @@ class PipelineTask(BaseTask):
|
||||
"""Stops the running pipeline immediately."""
|
||||
await self._cancel()
|
||||
|
||||
async def run(self):
|
||||
async def run(self, params: PipelineTaskParams):
|
||||
"""Starts and manages the pipeline execution until completion or cancellation."""
|
||||
if self.has_finished():
|
||||
return
|
||||
cleanup_pipeline = True
|
||||
try:
|
||||
# Setup processors.
|
||||
await self._setup()
|
||||
await self._setup(params)
|
||||
|
||||
# Create all main tasks and wait of the main push task. This is the
|
||||
# task that pushes frames to the very beginning of our pipeline (our
|
||||
@@ -487,7 +495,14 @@ class PipelineTask(BaseTask):
|
||||
await self._pipeline_end_event.wait()
|
||||
self._pipeline_end_event.clear()
|
||||
|
||||
async def _setup(self):
|
||||
async def _setup(self, params: PipelineTaskParams):
|
||||
mgr_params = TaskManagerParams(
|
||||
loop=params.loop,
|
||||
enable_watchdog_logging=self._enable_watchdog_logging,
|
||||
watchdog_timeout=self._watchdog_timeout_secs,
|
||||
)
|
||||
self._task_manager.setup(mgr_params)
|
||||
|
||||
setup = FrameProcessorSetup(
|
||||
clock=self._clock,
|
||||
task_manager=self._task_manager,
|
||||
@@ -511,6 +526,8 @@ class PipelineTask(BaseTask):
|
||||
await self._pipeline.cleanup()
|
||||
await self._sink.cleanup()
|
||||
|
||||
await self._task_manager.cleanup()
|
||||
|
||||
async def _process_push_queue(self):
|
||||
"""This is the task that runs the pipeline for the first time by sending
|
||||
a StartFrame and by pushing any other frames queued by the user. It runs
|
||||
|
||||
@@ -61,5 +61,7 @@ class ConsumerProcessor(FrameProcessor):
|
||||
async def _consumer_task_handler(self):
|
||||
while True:
|
||||
frame = await self._queue.get()
|
||||
self.start_watchdog()
|
||||
new_frame = await self._transformer(frame)
|
||||
await self.push_frame(new_frame, self._direction)
|
||||
self.reset_watchdog()
|
||||
|
||||
@@ -51,6 +51,8 @@ class FrameProcessor(BaseObject):
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
metrics: Optional[FrameProcessorMetrics] = None,
|
||||
enable_watchdog_logging: Optional[bool] = None,
|
||||
watchdog_timeout: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(name=name)
|
||||
@@ -58,6 +60,12 @@ class FrameProcessor(BaseObject):
|
||||
self._prev: Optional["FrameProcessor"] = None
|
||||
self._next: Optional["FrameProcessor"] = None
|
||||
|
||||
# Enable watchdog logging for all tasks created by this frame processor.
|
||||
self._enable_watchdog_logging = enable_watchdog_logging
|
||||
|
||||
# Allow this frame processor to control their tasks timeout.
|
||||
self._watchdog_timeout = watchdog_timeout
|
||||
|
||||
# Clock
|
||||
self._clock: Optional[BaseClock] = None
|
||||
|
||||
@@ -171,24 +179,40 @@ class FrameProcessor(BaseObject):
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
def create_task(self, coroutine: Coroutine, name: Optional[str] = None) -> asyncio.Task:
|
||||
if not self._task_manager:
|
||||
raise Exception(f"{self} TaskManager is still not initialized.")
|
||||
def create_task(
|
||||
self,
|
||||
coroutine: Coroutine,
|
||||
name: Optional[str] = None,
|
||||
*,
|
||||
enable_watchdog_logging: Optional[bool] = None,
|
||||
watchdog_timeout: Optional[float] = None,
|
||||
) -> asyncio.Task:
|
||||
if name:
|
||||
name = f"{self}::{name}"
|
||||
else:
|
||||
name = f"{self}::{coroutine.cr_code.co_name}"
|
||||
return self._task_manager.create_task(coroutine, name)
|
||||
return self.get_task_manager().create_task(
|
||||
coroutine,
|
||||
name,
|
||||
enable_watchdog_logging=(
|
||||
enable_watchdog_logging
|
||||
if enable_watchdog_logging
|
||||
else self._enable_watchdog_logging
|
||||
),
|
||||
watchdog_timeout=watchdog_timeout if watchdog_timeout else self._watchdog_timeout,
|
||||
)
|
||||
|
||||
async def cancel_task(self, task: asyncio.Task, timeout: Optional[float] = None):
|
||||
if not self._task_manager:
|
||||
raise Exception(f"{self} TaskManager is still not initialized.")
|
||||
await self._task_manager.cancel_task(task, timeout)
|
||||
await self.get_task_manager().cancel_task(task, timeout)
|
||||
|
||||
async def wait_for_task(self, task: asyncio.Task, timeout: Optional[float] = None):
|
||||
if not self._task_manager:
|
||||
raise Exception(f"{self} TaskManager is still not initialized.")
|
||||
await self._task_manager.wait_for_task(task, timeout)
|
||||
await self.get_task_manager().wait_for_task(task, timeout)
|
||||
|
||||
def start_watchdog(self):
|
||||
self.get_task_manager().start_watchdog(asyncio.current_task())
|
||||
|
||||
def reset_watchdog(self):
|
||||
self.get_task_manager().reset_watchdog(asyncio.current_task())
|
||||
|
||||
async def setup(self, setup: FrameProcessorSetup):
|
||||
self._clock = setup.clock
|
||||
@@ -206,9 +230,7 @@ class FrameProcessor(BaseObject):
|
||||
logger.debug(f"Linking {self} -> {self._next}")
|
||||
|
||||
def get_event_loop(self) -> asyncio.AbstractEventLoop:
|
||||
if not self._task_manager:
|
||||
raise Exception(f"{self} TaskManager is still not initialized.")
|
||||
return self._task_manager.get_event_loop()
|
||||
return self.get_task_manager().get_event_loop()
|
||||
|
||||
def set_parent(self, parent: "FrameProcessor"):
|
||||
self._parent = parent
|
||||
@@ -388,6 +410,7 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
(frame, direction, callback) = await self.__input_queue.get()
|
||||
try:
|
||||
self.start_watchdog()
|
||||
# Process the frame.
|
||||
await self.process_frame(frame, direction)
|
||||
# If this frame has an associated callback, call it now.
|
||||
@@ -398,6 +421,7 @@ class FrameProcessor(BaseObject):
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
finally:
|
||||
self.__input_queue.task_done()
|
||||
self.reset_watchdog()
|
||||
|
||||
def __create_push_task(self):
|
||||
if not self.__push_frame_task:
|
||||
@@ -412,5 +436,7 @@ class FrameProcessor(BaseObject):
|
||||
async def __push_frame_task_handler(self):
|
||||
while True:
|
||||
(frame, direction) = await self.__push_queue.get()
|
||||
self.start_watchdog()
|
||||
await self.__internal_push_frame(frame, direction)
|
||||
self.__push_queue.task_done()
|
||||
self.reset_watchdog()
|
||||
|
||||
@@ -783,14 +783,18 @@ class RTVIProcessor(FrameProcessor):
|
||||
async def _action_task_handler(self):
|
||||
while True:
|
||||
frame = await self._action_queue.get()
|
||||
self.start_watchdog()
|
||||
await self._handle_action(frame.message_id, frame.rtvi_action_run)
|
||||
self._action_queue.task_done()
|
||||
self.reset_watchdog()
|
||||
|
||||
async def _message_task_handler(self):
|
||||
while True:
|
||||
message = await self._message_queue.get()
|
||||
self.start_watchdog()
|
||||
await self._handle_message(message)
|
||||
self._message_queue.task_done()
|
||||
self.reset_watchdog()
|
||||
|
||||
async def _handle_transport_message(self, frame: TransportMessageUrgentFrame):
|
||||
try:
|
||||
|
||||
@@ -190,6 +190,7 @@ class AssemblyAISTTService(STTService):
|
||||
while self._connected:
|
||||
try:
|
||||
message = await self._websocket.recv()
|
||||
self.start_watchdog()
|
||||
data = json.loads(message)
|
||||
await self._handle_message(data)
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
@@ -197,6 +198,8 @@ class AssemblyAISTTService(STTService):
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing WebSocket message: {e}")
|
||||
break
|
||||
finally:
|
||||
self.reset_watchdog()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error in receive handler: {e}")
|
||||
|
||||
@@ -285,6 +285,9 @@ class AWSTranscribeSTTService(STTService):
|
||||
|
||||
try:
|
||||
response = await self._ws_client.recv()
|
||||
|
||||
self.start_watchdog()
|
||||
|
||||
headers, payload = decode_event(response)
|
||||
|
||||
if headers.get(":message-type") == "event":
|
||||
@@ -342,3 +345,5 @@ class AWSTranscribeSTTService(STTService):
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Unexpected error in receive loop: {e}")
|
||||
break
|
||||
finally:
|
||||
self.reset_watchdog()
|
||||
|
||||
@@ -699,6 +699,8 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
output = await self._stream.await_output()
|
||||
result = await output[1].receive()
|
||||
|
||||
self.start_watchdog()
|
||||
|
||||
if result.value and result.value.bytes_:
|
||||
response_data = result.value.bytes_.decode("utf-8")
|
||||
json_data = json.loads(response_data)
|
||||
@@ -731,6 +733,8 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
logger.error(f"{self} error processing responses: {e}")
|
||||
if self._wants_connection:
|
||||
await self.reset_conversation()
|
||||
finally:
|
||||
self.reset_watchdog()
|
||||
|
||||
async def _handle_completion_start_event(self, event_json):
|
||||
pass
|
||||
|
||||
@@ -687,6 +687,8 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
async for message in self._websocket:
|
||||
self.start_watchdog()
|
||||
|
||||
evt = events.parse_server_event(message)
|
||||
# logger.debug(f"Received event: {message[:500]}")
|
||||
# logger.debug(f"Received event: {evt}")
|
||||
@@ -708,8 +710,8 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
await self._handle_evt_error(evt)
|
||||
# errors are fatal, so exit the receive loop
|
||||
return
|
||||
else:
|
||||
pass
|
||||
|
||||
self.reset_watchdog()
|
||||
|
||||
#
|
||||
#
|
||||
|
||||
@@ -502,6 +502,8 @@ class GladiaSTTService(STTService):
|
||||
async def _receive_task_handler(self):
|
||||
try:
|
||||
async for message in self._websocket:
|
||||
self.start_watchdog()
|
||||
|
||||
content = json.loads(message)
|
||||
|
||||
# Handle audio chunk acknowledgments
|
||||
@@ -559,11 +561,15 @@ class GladiaSTTService(STTService):
|
||||
translation, "", time_now_iso8601(), translated_language
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_watchdog()
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
# Expected when closing the connection
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Gladia WebSocket handler: {e}")
|
||||
finally:
|
||||
self.reset_watchdog()
|
||||
|
||||
async def _maybe_reconnect(self) -> bool:
|
||||
"""Handle exponential backoff reconnection logic."""
|
||||
|
||||
@@ -747,9 +747,12 @@ class GoogleSTTService(STTService):
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
self.start_watchdog()
|
||||
|
||||
if self._request_queue.empty():
|
||||
# wait for 10ms in case we don't have audio
|
||||
await asyncio.sleep(0.01)
|
||||
self.reset_watchdog()
|
||||
continue
|
||||
|
||||
# Start bi-directional streaming
|
||||
@@ -760,12 +763,13 @@ class GoogleSTTService(STTService):
|
||||
# Process responses
|
||||
await self._process_responses(streaming_recognize)
|
||||
|
||||
self.reset_watchdog()
|
||||
|
||||
# If we're here, check if we need to reconnect
|
||||
if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
|
||||
logger.debug("Reconnecting stream after timeout")
|
||||
# Reset stream start time
|
||||
self._stream_start_time = int(time.time() * 1000)
|
||||
continue
|
||||
else:
|
||||
# Normal stream end
|
||||
break
|
||||
@@ -775,7 +779,8 @@ class GoogleSTTService(STTService):
|
||||
|
||||
await asyncio.sleep(1) # Brief delay before reconnecting
|
||||
self._stream_start_time = int(time.time() * 1000)
|
||||
continue
|
||||
finally:
|
||||
self.reset_watchdog()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in streaming task: {e}")
|
||||
@@ -800,12 +805,16 @@ class GoogleSTTService(STTService):
|
||||
"""Process streaming recognition responses."""
|
||||
try:
|
||||
async for response in streaming_recognize:
|
||||
self.start_watchdog()
|
||||
|
||||
# Check streaming limit
|
||||
if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
|
||||
logger.debug("Stream timeout reached in response processing")
|
||||
self.reset_watchdog()
|
||||
break
|
||||
|
||||
if not response.results:
|
||||
self.reset_watchdog()
|
||||
continue
|
||||
|
||||
for result in response.results:
|
||||
@@ -848,8 +857,10 @@ class GoogleSTTService(STTService):
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_watchdog()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Google STT responses: {e}")
|
||||
|
||||
# Re-raise the exception to let it propagate (e.g. in the case of a timeout, propagate to _stream_audio to reconnect)
|
||||
self.reset_watchdog()
|
||||
# Re-raise the exception to let it propagate (e.g. in the case of a
|
||||
# timeout, propagate to _stream_audio to reconnect)
|
||||
raise
|
||||
|
||||
@@ -203,12 +203,11 @@ class ResponseCancelEvent(ClientEvent):
|
||||
|
||||
|
||||
class ServerEvent(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
event_id: str
|
||||
type: str
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class SessionCreatedEvent(ServerEvent):
|
||||
type: Literal["session.created"]
|
||||
|
||||
@@ -370,6 +370,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
async for message in self._websocket:
|
||||
self.start_watchdog()
|
||||
evt = events.parse_server_event(message)
|
||||
if evt.type == "session.created":
|
||||
await self._handle_evt_session_created(evt)
|
||||
@@ -400,6 +401,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
await self._handle_evt_error(evt)
|
||||
# errors are fatal, so exit the receive loop
|
||||
return
|
||||
self.reset_watchdog()
|
||||
|
||||
@traced_openai_realtime(operation="llm_setup")
|
||||
async def _handle_evt_session_created(self, evt):
|
||||
|
||||
@@ -224,11 +224,13 @@ class RivaSTTService(STTService):
|
||||
streaming_config=self._config,
|
||||
)
|
||||
for response in responses:
|
||||
self.start_watchdog()
|
||||
if not response.results:
|
||||
continue
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._response_queue.put(response), self.get_event_loop()
|
||||
)
|
||||
self.reset_watchdog()
|
||||
|
||||
async def _thread_task_handler(self):
|
||||
try:
|
||||
@@ -283,7 +285,9 @@ class RivaSTTService(STTService):
|
||||
async def _response_task_handler(self):
|
||||
while True:
|
||||
response = await self._response_queue.get()
|
||||
self.start_watchdog()
|
||||
await self._handle_response(response)
|
||||
self.reset_watchdog()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
@@ -62,6 +62,7 @@ class SimliVideoService(FrameProcessor):
|
||||
async def _consume_and_process_audio(self):
|
||||
await self._pipecat_resampler_event.wait()
|
||||
async for audio_frame in self._simli_client.getAudioStreamIterator():
|
||||
self.start_watchdog()
|
||||
resampled_frames = self._pipecat_resampler.resample(audio_frame)
|
||||
for resampled_frame in resampled_frames:
|
||||
audio_array = resampled_frame.to_ndarray()
|
||||
@@ -74,10 +75,12 @@ class SimliVideoService(FrameProcessor):
|
||||
num_channels=1,
|
||||
),
|
||||
)
|
||||
self.reset_watchdog()
|
||||
|
||||
async def _consume_and_process_video(self):
|
||||
await self._pipecat_resampler_event.wait()
|
||||
async for video_frame in self._simli_client.getVideoStreamIterator(targetFormat="rgb24"):
|
||||
self.start_watchdog()
|
||||
# Process the video frame
|
||||
convertedFrame: OutputImageRawFrame = OutputImageRawFrame(
|
||||
image=video_frame.to_rgb().to_image().tobytes(),
|
||||
@@ -86,6 +89,7 @@ class SimliVideoService(FrameProcessor):
|
||||
)
|
||||
convertedFrame.pts = video_frame.pts
|
||||
await self.push_frame(convertedFrame)
|
||||
self.reset_watchdog()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -217,5 +217,7 @@ class TavusVideoService(AIService):
|
||||
async def _send_task_handler(self):
|
||||
while True:
|
||||
frame = await self._queue.get()
|
||||
if isinstance(frame, OutputAudioRawFrame):
|
||||
self.start_watchdog()
|
||||
if isinstance(frame, OutputAudioRawFrame) and self._client:
|
||||
await self._client.write_audio_frame(frame)
|
||||
self.reset_watchdog()
|
||||
|
||||
@@ -368,6 +368,8 @@ class BaseInputTransport(FrameProcessor):
|
||||
self._audio_in_queue.get(), timeout=AUDIO_INPUT_TIMEOUT_SECS
|
||||
)
|
||||
|
||||
self.start_watchdog()
|
||||
|
||||
# If an audio filter is available, run it before VAD.
|
||||
if self._params.audio_in_filter:
|
||||
frame.audio = await self._params.audio_in_filter.filter(frame.audio)
|
||||
@@ -396,6 +398,8 @@ class BaseInputTransport(FrameProcessor):
|
||||
self._params.turn_analyzer.clear()
|
||||
await self._handle_user_interruption(UserStoppedSpeakingFrame())
|
||||
|
||||
self.reset_watchdog()
|
||||
|
||||
async def _handle_prediction_result(self, result: MetricsData):
|
||||
"""Handle a prediction result event from the turn analyzer.
|
||||
|
||||
|
||||
@@ -182,6 +182,8 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
if not self._params.serializer:
|
||||
continue
|
||||
|
||||
self.start_watchdog()
|
||||
|
||||
frame = await self._params.serializer.deserialize(message)
|
||||
|
||||
if not frame:
|
||||
@@ -191,9 +193,13 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
await self.push_audio_frame(frame)
|
||||
else:
|
||||
await self.push_frame(frame)
|
||||
|
||||
self.reset_watchdog()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
|
||||
|
||||
self.reset_watchdog()
|
||||
|
||||
await self._client.trigger_client_disconnected()
|
||||
|
||||
async def _monitor_websocket(self):
|
||||
|
||||
@@ -423,8 +423,10 @@ class SmallWebRTCInputTransport(BaseInputTransport):
|
||||
async def _receive_audio(self):
|
||||
try:
|
||||
async for audio_frame in self._client.read_audio_frame():
|
||||
self.start_watchdog()
|
||||
if audio_frame:
|
||||
await self.push_audio_frame(audio_frame)
|
||||
self.reset_watchdog()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
|
||||
@@ -432,6 +434,7 @@ class SmallWebRTCInputTransport(BaseInputTransport):
|
||||
async def _receive_video(self):
|
||||
try:
|
||||
async for video_frame in self._client.read_video_frame():
|
||||
self.start_watchdog()
|
||||
if video_frame:
|
||||
await self.push_video_frame(video_frame)
|
||||
|
||||
@@ -450,6 +453,7 @@ class SmallWebRTCInputTransport(BaseInputTransport):
|
||||
await self.push_video_frame(image_frame)
|
||||
# Remove from pending requests
|
||||
del self._image_requests[req_id]
|
||||
self.reset_watchdog()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
|
||||
|
||||
@@ -415,6 +415,7 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
logger.info("Audio input task started")
|
||||
while True:
|
||||
audio_data = await self._client.get_next_audio_frame()
|
||||
self.start_watchdog()
|
||||
if audio_data:
|
||||
audio_frame_event, participant_id = audio_data
|
||||
pipecat_audio_frame = await self._convert_livekit_audio_to_pipecat(
|
||||
@@ -427,6 +428,7 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
num_channels=pipecat_audio_frame.num_channels,
|
||||
)
|
||||
await self.push_audio_frame(input_audio_frame)
|
||||
self.reset_watchdog()
|
||||
|
||||
async def _convert_livekit_audio_to_pipecat(
|
||||
self, audio_frame_event: rtc.AudioFrameEvent
|
||||
|
||||
@@ -5,15 +5,30 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Coroutine, Dict, Optional, Sequence, Set
|
||||
from dataclasses import dataclass
|
||||
from typing import Coroutine, Dict, List, Optional, Sequence
|
||||
|
||||
from loguru import logger
|
||||
|
||||
WATCHDOG_TIMEOUT = 5.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskManagerParams:
|
||||
loop: asyncio.AbstractEventLoop
|
||||
enable_watchdog_logging: bool = False
|
||||
watchdog_timeout: float = WATCHDOG_TIMEOUT
|
||||
|
||||
|
||||
class BaseTaskManager(ABC):
|
||||
@abstractmethod
|
||||
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
|
||||
def setup(self, params: TaskManagerParams):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -21,7 +36,14 @@ class BaseTaskManager(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_task(self, coroutine: Coroutine, name: str) -> asyncio.Task:
|
||||
def create_task(
|
||||
self,
|
||||
coroutine: Coroutine,
|
||||
name: str,
|
||||
*,
|
||||
enable_watchdog_logging: Optional[bool] = None,
|
||||
watchdog_timeout: Optional[float] = None,
|
||||
) -> asyncio.Task:
|
||||
"""
|
||||
Creates and schedules a new asyncio Task that runs the given coroutine.
|
||||
|
||||
@@ -31,6 +53,8 @@ class BaseTaskManager(ABC):
|
||||
loop (asyncio.AbstractEventLoop): The event loop to use for creating the task.
|
||||
coroutine (Coroutine): The coroutine to be executed within the task.
|
||||
name (str): The name to assign to the task for identification.
|
||||
enable_watchdog_logging(bool): whether this task should log watchdog processing times.
|
||||
watchdog_timeout(float): watchdog timer timeout for this task.
|
||||
|
||||
Returns:
|
||||
asyncio.Task: The created task object.
|
||||
@@ -73,21 +97,64 @@ class BaseTaskManager(ABC):
|
||||
"""Returns the list of currently created/registered tasks."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_watchdog(self, task: asyncio.Task):
|
||||
"""Starts the given task watchdog timer. If not reset, a warning will be
|
||||
logged indicating the task is stalling.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_watchdog(self, task: asyncio.Task):
|
||||
"""Resets the given task watchdog timer. If not reset, a warning will be
|
||||
logged indicating the task is stalling.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskData:
|
||||
task: asyncio.Task
|
||||
watchdog_start: asyncio.Event
|
||||
watchdog_timer: asyncio.Event
|
||||
enable_watchdog_logging: bool
|
||||
watchdog_timeout: float
|
||||
|
||||
|
||||
class TaskManager(BaseTaskManager):
|
||||
def __init__(self) -> None:
|
||||
self._tasks: Dict[str, asyncio.Task] = {}
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._tasks: Dict[str, TaskData] = {}
|
||||
self._params: Optional[TaskManagerParams] = None
|
||||
self._watchdog_tasks: List[asyncio.Task] = []
|
||||
|
||||
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
|
||||
self._loop = loop
|
||||
def setup(self, params: TaskManagerParams):
|
||||
if not self._params:
|
||||
self._params = params
|
||||
|
||||
async def cleanup(self):
|
||||
for task in self._watchdog_tasks:
|
||||
try:
|
||||
task.cancel()
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
# This is expected, no need to re-raise.
|
||||
pass
|
||||
|
||||
def get_event_loop(self) -> asyncio.AbstractEventLoop:
|
||||
if not self._loop:
|
||||
raise Exception("TaskManager missing event loop, use TaskManager.set_event_loop().")
|
||||
return self._loop
|
||||
if not self._params:
|
||||
raise Exception("TaskManager is not setup: unable to get event loop")
|
||||
return self._params.loop
|
||||
|
||||
def create_task(self, coroutine: Coroutine, name: str) -> asyncio.Task:
|
||||
def create_task(
|
||||
self,
|
||||
coroutine: Coroutine,
|
||||
name: str,
|
||||
*,
|
||||
enable_watchdog_logging: Optional[bool] = None,
|
||||
watchdog_timeout: Optional[float] = None,
|
||||
) -> asyncio.Task:
|
||||
"""
|
||||
Creates and schedules a new asyncio Task that runs the given coroutine.
|
||||
|
||||
@@ -97,6 +164,8 @@ class TaskManager(BaseTaskManager):
|
||||
loop (asyncio.AbstractEventLoop): The event loop to use for creating the task.
|
||||
coroutine (Coroutine): The coroutine to be executed within the task.
|
||||
name (str): The name to assign to the task for identification.
|
||||
enable_watchdog_logging(bool): whether this task should log watchdog processing time.
|
||||
watchdog_timeout(float): watchdog timer timeout for this task.
|
||||
|
||||
Returns:
|
||||
asyncio.Task: The created task object.
|
||||
@@ -112,12 +181,26 @@ class TaskManager(BaseTaskManager):
|
||||
except Exception as e:
|
||||
logger.exception(f"{name}: unexpected exception: {e}")
|
||||
|
||||
if not self._loop:
|
||||
raise Exception("TaskManager missing event loop, use TaskManager.set_event_loop().")
|
||||
if not self._params:
|
||||
raise Exception("TaskManager is not setup: unable to get event loop")
|
||||
|
||||
task = self._loop.create_task(run_coroutine())
|
||||
task = self._params.loop.create_task(run_coroutine())
|
||||
task.set_name(name)
|
||||
self._add_task(task)
|
||||
self._add_task(
|
||||
TaskData(
|
||||
task=task,
|
||||
watchdog_start=asyncio.Event(),
|
||||
watchdog_timer=asyncio.Event(),
|
||||
enable_watchdog_logging=(
|
||||
enable_watchdog_logging
|
||||
if enable_watchdog_logging
|
||||
else self._params.enable_watchdog_logging
|
||||
),
|
||||
watchdog_timeout=(
|
||||
watchdog_timeout if watchdog_timeout else self._params.watchdog_timeout
|
||||
),
|
||||
)
|
||||
)
|
||||
logger.trace(f"{name}: task created")
|
||||
return task
|
||||
|
||||
@@ -165,6 +248,8 @@ class TaskManager(BaseTaskManager):
|
||||
name = task.get_name()
|
||||
task.cancel()
|
||||
try:
|
||||
# Make sure to reset watchdog if a task is cancelled.
|
||||
self.reset_watchdog(task)
|
||||
if timeout:
|
||||
await asyncio.wait_for(task, timeout=timeout)
|
||||
else:
|
||||
@@ -184,11 +269,43 @@ class TaskManager(BaseTaskManager):
|
||||
|
||||
def current_tasks(self) -> Sequence[asyncio.Task]:
|
||||
"""Returns the list of currently created/registered tasks."""
|
||||
return list(self._tasks.values())
|
||||
return [data.task for data in self._tasks.values()]
|
||||
|
||||
def _add_task(self, task: asyncio.Task):
|
||||
def start_watchdog(self, task: asyncio.Task):
|
||||
"""Starts the given task watchdog timer. If not reset, a warning will be
|
||||
logged indicating the task is stalling. If the timer was already started
|
||||
a warning will be logged.
|
||||
|
||||
"""
|
||||
name = task.get_name()
|
||||
self._tasks[name] = task
|
||||
if name in self._tasks:
|
||||
if self._tasks[name].watchdog_start.is_set():
|
||||
logger.warning(f"Watchdog timer for task {name} already started")
|
||||
else:
|
||||
self._tasks[name].watchdog_timer.clear()
|
||||
self._tasks[name].watchdog_start.set()
|
||||
else:
|
||||
logger.warning(f"Unable to start watchdog timer: task {name} does not exist")
|
||||
|
||||
def reset_watchdog(self, task: asyncio.Task):
|
||||
"""Resets the given task watchdog timer. If not reset, a warning will be
|
||||
logged indicating the task is stalling.
|
||||
|
||||
"""
|
||||
name = task.get_name()
|
||||
if name in self._tasks:
|
||||
self._tasks[name].watchdog_start.clear()
|
||||
self._tasks[name].watchdog_timer.set()
|
||||
else:
|
||||
logger.warning(f"Unable to reset watchdog timer: task {name} does not exist")
|
||||
|
||||
def _add_task(self, task_data: TaskData):
|
||||
name = task_data.task.get_name()
|
||||
self._tasks[name] = task_data
|
||||
watchdog_task = self.get_event_loop().create_task(
|
||||
self._watchdog_task_handler(self._tasks[name])
|
||||
)
|
||||
self._watchdog_tasks.append(watchdog_task)
|
||||
|
||||
def _remove_task(self, task: asyncio.Task):
|
||||
name = task.get_name()
|
||||
@@ -196,3 +313,33 @@ class TaskManager(BaseTaskManager):
|
||||
del self._tasks[name]
|
||||
except KeyError as e:
|
||||
logger.trace(f"{name}: unable to remove task (already removed?): {e}")
|
||||
|
||||
async def _watchdog_task_handler(self, task_data: TaskData):
|
||||
name = task_data.task.get_name()
|
||||
start = task_data.watchdog_start
|
||||
timer = task_data.watchdog_timer
|
||||
enable_watchdog_logging = task_data.enable_watchdog_logging
|
||||
watchdog_timeout = task_data.watchdog_timeout
|
||||
|
||||
async def wait_for_reset():
|
||||
waiting = True
|
||||
while waiting:
|
||||
try:
|
||||
start_time = time.time()
|
||||
await asyncio.wait_for(timer.wait(), timeout=watchdog_timeout)
|
||||
total_time = time.time() - start_time
|
||||
if enable_watchdog_logging:
|
||||
logger.debug(f"{name} task processing time: {total_time:.20f}")
|
||||
waiting = False
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"{name}: task is taking too long {WATCHDOG_TIMEOUT} second(s) (forgot to reset watchdog?)"
|
||||
)
|
||||
finally:
|
||||
timer.clear()
|
||||
|
||||
while True:
|
||||
# Wait for the user to start the watchdog timer.
|
||||
await start.wait()
|
||||
# Now, waiting for the task to finish.
|
||||
await wait_for_reset()
|
||||
|
||||
@@ -17,6 +17,7 @@ from pipecat.frames.frames import (
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.pipeline.base_task import PipelineTaskParams
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -96,11 +97,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_task_single(self):
|
||||
pipeline = Pipeline([IdentityFilter()])
|
||||
task = PipelineTask(pipeline)
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
|
||||
await task.queue_frame(TextFrame(text="Hello!"))
|
||||
await task.queue_frames([TextFrame(text="Bye!"), EndFrame()])
|
||||
await task.run()
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
assert task.has_finished()
|
||||
|
||||
async def test_task_observers(self):
|
||||
@@ -116,10 +116,9 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
identity = IdentityFilter()
|
||||
pipeline = Pipeline([identity])
|
||||
task = PipelineTask(pipeline, observers=[CustomObserver()])
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
|
||||
await task.queue_frames([TextFrame(text="Hello Downstream!"), EndFrame()])
|
||||
await task.run()
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
assert frame_received
|
||||
|
||||
async def test_task_add_observer(self):
|
||||
@@ -156,8 +155,6 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
observer1 = CustomAddObserver1()
|
||||
task.add_observer(observer1)
|
||||
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
|
||||
async def delayed_add_observer():
|
||||
observer2 = CustomAddObserver2()
|
||||
# Wait after the pipeline is started and add another observer.
|
||||
@@ -176,7 +173,9 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
# Finally end the pipeline.
|
||||
await task.queue_frame(EndFrame())
|
||||
|
||||
await asyncio.gather(task.run(), delayed_add_observer())
|
||||
await asyncio.gather(
|
||||
task.run(PipelineTaskParams(loop=asyncio.get_event_loop())), delayed_add_observer()
|
||||
)
|
||||
|
||||
assert frame_received
|
||||
assert frame_count_1 == 1
|
||||
@@ -189,7 +188,6 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
identity = IdentityFilter()
|
||||
pipeline = Pipeline([identity])
|
||||
task = PipelineTask(pipeline)
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
|
||||
@task.event_handler("on_pipeline_started")
|
||||
async def on_pipeline_started(task, frame: StartFrame):
|
||||
@@ -202,7 +200,7 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
end_received = True
|
||||
|
||||
await task.queue_frame(EndFrame())
|
||||
await task.run()
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
|
||||
assert start_received
|
||||
assert end_received
|
||||
@@ -213,7 +211,6 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
identity = IdentityFilter()
|
||||
pipeline = Pipeline([identity])
|
||||
task = PipelineTask(pipeline)
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
|
||||
@task.event_handler("on_pipeline_stopped")
|
||||
async def on_pipeline_ended(task, frame: StopFrame):
|
||||
@@ -221,7 +218,7 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
stop_received = True
|
||||
|
||||
await task.queue_frame(StopFrame())
|
||||
await task.run()
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
|
||||
assert stop_received
|
||||
|
||||
@@ -232,7 +229,6 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
identity = IdentityFilter()
|
||||
pipeline = Pipeline([identity])
|
||||
task = PipelineTask(pipeline, cancel_on_idle_timeout=False)
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
task.set_reached_upstream_filter((TextFrame,))
|
||||
task.set_reached_downstream_filter((TextFrame,))
|
||||
|
||||
@@ -254,7 +250,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
await task.queue_frame(TextFrame(text="Hello Downstream!"))
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(task.run()), timeout=1.0)
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))),
|
||||
timeout=1.0,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
@@ -282,13 +281,15 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
observers=[heartbeats_observer],
|
||||
cancel_on_idle_timeout=False,
|
||||
)
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
|
||||
expected_heartbeats = 1.0 / 0.2
|
||||
|
||||
await task.queue_frame(TextFrame(text="Hello!"))
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(task.run()), timeout=1.0)
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))),
|
||||
timeout=1.0,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
assert heartbeats_counter == expected_heartbeats
|
||||
@@ -297,17 +298,18 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
identity = IdentityFilter()
|
||||
pipeline = Pipeline([identity])
|
||||
task = PipelineTask(pipeline, idle_timeout_secs=0.2)
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
await task.run()
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
assert True
|
||||
|
||||
async def test_no_idle_task(self):
|
||||
identity = IdentityFilter()
|
||||
pipeline = Pipeline([identity])
|
||||
task = PipelineTask(pipeline, idle_timeout_secs=0.2, cancel_on_idle_timeout=False)
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(task.run()), timeout=0.3)
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))),
|
||||
timeout=0.3,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
assert True
|
||||
else:
|
||||
@@ -324,15 +326,13 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
),
|
||||
idle_timeout_secs=0.3,
|
||||
)
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
await task.run()
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
assert True
|
||||
|
||||
async def test_idle_task_event_handler_no_frames(self):
|
||||
identity = IdentityFilter()
|
||||
pipeline = Pipeline([identity])
|
||||
task = PipelineTask(pipeline, idle_timeout_secs=0.2, cancel_on_idle_timeout=False)
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
|
||||
idle_timeout = False
|
||||
|
||||
@@ -342,14 +342,13 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
idle_timeout = True
|
||||
await task.cancel()
|
||||
|
||||
await task.run()
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
assert idle_timeout
|
||||
|
||||
async def test_idle_task_event_handler_quiet_user(self):
|
||||
identity = IdentityFilter()
|
||||
pipeline = Pipeline([identity])
|
||||
task = PipelineTask(pipeline, idle_timeout_secs=0.2, cancel_on_idle_timeout=False)
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
|
||||
idle_timeout = 0
|
||||
|
||||
@@ -373,7 +372,9 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await asyncio.gather(send_audio(), task.run())
|
||||
await asyncio.gather(
|
||||
send_audio(), task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
)
|
||||
assert idle_timeout == 1
|
||||
|
||||
async def test_idle_task_frames(self):
|
||||
@@ -387,7 +388,6 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
idle_timeout_secs=idle_timeout_secs,
|
||||
idle_timeout_frames=(TextFrame,),
|
||||
)
|
||||
task.set_event_loop(asyncio.get_event_loop())
|
||||
|
||||
async def delayed_frames():
|
||||
await asyncio.sleep(sleep_time_secs)
|
||||
@@ -399,7 +399,10 @@ class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
tasks = {asyncio.create_task(task.run()), asyncio.create_task(delayed_frames())}
|
||||
tasks = [
|
||||
asyncio.create_task(task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))),
|
||||
asyncio.create_task(delayed_frames()),
|
||||
]
|
||||
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user