diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 3e11a7a19..b43aaaeae 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -23,7 +23,7 @@ from pipecat.audio.vad.vad_analyzer import VADParams from pipecat.clocks.base_clock import BaseClock from pipecat.metrics.metrics import MetricsData from pipecat.transcriptions.language import Language -from pipecat.utils.asyncio import TaskManager +from pipecat.utils.asyncio import BaseTaskManager from pipecat.utils.time import nanoseconds_to_str from pipecat.utils.utils import obj_count, obj_id @@ -438,7 +438,7 @@ class StartFrame(SystemFrame): """This is the first frame that should be pushed down a pipeline.""" clock: BaseClock - task_manager: TaskManager + task_manager: BaseTaskManager audio_in_sample_rate: int = 16000 audio_out_sample_rate: int = 24000 allow_interruptions: bool = False diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index a2aad1e93..da18cca5d 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -5,7 +5,7 @@ # import asyncio -from typing import Any, AsyncIterable, Dict, Iterable, List +from typing import Any, AsyncIterable, Dict, Iterable, List, Optional from loguru import logger from pydantic import BaseModel, ConfigDict @@ -31,7 +31,7 @@ from pipecat.pipeline.base_pipeline import BasePipeline from pipecat.pipeline.base_task import BaseTask from pipecat.pipeline.task_observer import TaskObserver from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.utils.asyncio import TaskManager +from pipecat.utils.asyncio import BaseTaskManager, TaskManager from pipecat.utils.utils import obj_count, obj_id HEARTBEAT_SECONDS = 1.0 @@ -134,6 +134,7 @@ class PipelineTask(BaseTask): params: PipelineParams = PipelineParams(), observers: List[BaseObserver] = [], clock: BaseClock = SystemClock(), + task_manager: Optional[BaseTaskManager] = None, check_dangling_tasks: bool = True, ): self._id: int = obj_id() @@ -174,7 +175,7 @@ class PipelineTask(BaseTask): self._sink = PipelineTaskSink(self._down_queue) pipeline.link(self._sink) - self._task_manager = TaskManager() + self._task_manager = task_manager or TaskManager() self._observer = TaskObserver(observers=observers, task_manager=self._task_manager) diff --git a/src/pipecat/pipeline/task_observer.py b/src/pipecat/pipeline/task_observer.py index abcbf513b..122038386 100644 --- a/src/pipecat/pipeline/task_observer.py +++ b/src/pipecat/pipeline/task_observer.py @@ -12,7 +12,7 @@ from attr import dataclass from pipecat.frames.frames import Frame from pipecat.observers.base_observer import BaseObserver from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.utils.asyncio import TaskManager +from pipecat.utils.asyncio import BaseTaskManager from pipecat.utils.utils import obj_count, obj_id @@ -55,7 +55,7 @@ class TaskObserver(BaseObserver): """ - def __init__(self, *, observers: List[BaseObserver] = [], task_manager: TaskManager): + def __init__(self, *, observers: List[BaseObserver] = [], task_manager: BaseTaskManager): self._id: int = obj_id() self._name: str = f"{self.__class__.__name__}#{obj_count(self)}" self._observers = observers diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index 20a5071f4..335585ec3 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -23,7 +23,7 @@ from pipecat.frames.frames import ( ) from pipecat.metrics.metrics import LLMTokenUsage, MetricsData from pipecat.processors.metrics.frame_processor_metrics import FrameProcessorMetrics -from pipecat.utils.asyncio import TaskManager +from pipecat.utils.asyncio import BaseTaskManager from pipecat.utils.utils import obj_count, obj_id @@ -52,7 +52,7 @@ class FrameProcessor: self._clock: Optional[BaseClock] = None # Task Manager - self._task_manager: Optional[TaskManager] = None + self._task_manager: Optional[BaseTaskManager] = None # Other properties self._allow_interruptions = False @@ -192,7 +192,7 @@ class FrameProcessor: raise Exception(f"{self} Clock is still not initialized.") return self._clock - def get_task_manager(self) -> TaskManager: + def get_task_manager(self) -> BaseTaskManager: if not self._task_manager: raise Exception(f"{self} TaskManager is still not initialized.") return self._task_manager diff --git a/src/pipecat/transports/network/websocket_client.py b/src/pipecat/transports/network/websocket_client.py index d7db9c4bd..eb2b5cfb8 100644 --- a/src/pipecat/transports/network/websocket_client.py +++ b/src/pipecat/transports/network/websocket_client.py @@ -30,7 +30,7 @@ from pipecat.serializers.protobuf import ProtobufFrameSerializer from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import BaseOutputTransport from pipecat.transports.base_transport import BaseTransport, TransportParams -from pipecat.utils.asyncio import TaskManager +from pipecat.utils.asyncio import BaseTaskManager class WebsocketClientParams(TransportParams): @@ -57,12 +57,12 @@ class WebsocketClientSession: self._callbacks = callbacks self._transport_name = transport_name - self._task_manager: Optional[TaskManager] = None + self._task_manager: Optional[BaseTaskManager] = None self._websocket: Optional[websockets.WebSocketClientProtocol] = None @property - def task_manager(self) -> TaskManager: + def task_manager(self) -> BaseTaskManager: if not self._task_manager: raise Exception( f"{self._transport_name}::WebsocketClientSession: TaskManager not initialized (pipeline not started?)" diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py index 0bb0dfd6d..1d31f348b 100644 --- a/src/pipecat/transports/services/daily.py +++ b/src/pipecat/transports/services/daily.py @@ -43,7 +43,7 @@ from pipecat.transcriptions.language import Language from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import BaseOutputTransport from pipecat.transports.base_transport import BaseTransport, TransportParams -from pipecat.utils.asyncio import TaskManager +from pipecat.utils.asyncio import BaseTaskManager try: from daily import CallClient, Daily, EventHandler @@ -293,7 +293,7 @@ class DailyTransportClient(EventHandler): self._joined_event = asyncio.Event() self._leave_counter = 0 - self._task_manager: Optional[TaskManager] = None + self._task_manager: Optional[BaseTaskManager] = None # We use the executor to cleanup the client. We just do it from one # place, so only one thread is really needed. diff --git a/src/pipecat/transports/services/livekit.py b/src/pipecat/transports/services/livekit.py index 9d26e1bfc..7018ea520 100644 --- a/src/pipecat/transports/services/livekit.py +++ b/src/pipecat/transports/services/livekit.py @@ -27,7 +27,7 @@ from pipecat.processors.frame_processor import FrameDirection from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import BaseOutputTransport from pipecat.transports.base_transport import BaseTransport, TransportParams -from pipecat.utils.asyncio import TaskManager +from pipecat.utils.asyncio import BaseTaskManager try: from livekit import rtc @@ -88,7 +88,7 @@ class LiveKitTransportClient: self._audio_tracks = {} self._audio_queue = asyncio.Queue() self._other_participant_has_joined = False - self._task_manager: Optional[TaskManager] = None + self._task_manager: Optional[BaseTaskManager] = None @property def participant_id(self) -> str: diff --git a/src/pipecat/utils/asyncio.py b/src/pipecat/utils/asyncio.py index 073ab0e50..acc4acec8 100644 --- a/src/pipecat/utils/asyncio.py +++ b/src/pipecat/utils/asyncio.py @@ -5,12 +5,76 @@ # import asyncio +from abc import ABC, abstractmethod from typing import Coroutine, Optional, Set from loguru import logger -class TaskManager: +class BaseTaskManager(ABC): + @abstractmethod + def set_event_loop(self, loop: asyncio.AbstractEventLoop): + pass + + @abstractmethod + def get_event_loop(self) -> asyncio.AbstractEventLoop: + pass + + @abstractmethod + def create_task(self, coroutine: Coroutine, name: str) -> asyncio.Task: + """ + Creates and schedules a new asyncio Task that runs the given coroutine. + + The task is added to a global set of created tasks. + + Args: + loop (asyncio.AbstractEventLoop): The event loop to use for creating the task. + coroutine (Coroutine): The coroutine to be executed within the task. + name (str): The name to assign to the task for identification. + + Returns: + asyncio.Task: The created task object. + """ + pass + + @abstractmethod + async def wait_for_task(self, task: asyncio.Task, timeout: Optional[float] = None): + """Wait for an asyncio.Task to complete with optional timeout handling. + + This function awaits the specified asyncio.Task and handles scenarios for + timeouts, cancellations, and other exceptions. It also ensures that the task + is removed from the set of registered tasks upon completion or failure. + + Args: + task (asyncio.Task): The asyncio Task to wait for. + timeout (Optional[float], optional): The maximum number of seconds + to wait for the task to complete. If None, waits indefinitely. + Defaults to None. + """ + pass + + @abstractmethod + async def cancel_task(self, task: asyncio.Task, timeout: Optional[float] = None): + """Cancels the given asyncio Task and awaits its completion with an + optional timeout. + + This function removes the task from the set of registered tasks upon + completion or failure. + + Args: + task (asyncio.Task): The task to be cancelled. + timeout (Optional[float]): The optional timeout in seconds to wait for the task to cancel. + + """ + pass + + @abstractmethod + def current_tasks(self) -> Set[asyncio.Task]: + """Returns the list of currently created/registered tasks.""" + pass + + +class TaskManager(BaseTaskManager): def __init__(self) -> None: self._tasks: Set[asyncio.Task] = set() self._loop: Optional[asyncio.AbstractEventLoop] = None