introduce new BaseTaskManager

This commit is contained in:
Aleix Conchillo Flaqué
2025-02-24 23:19:57 -08:00
parent fb7fe540f5
commit d2f006682c
8 changed files with 83 additions and 18 deletions

View File

@@ -23,7 +23,7 @@ from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.clocks.base_clock import BaseClock
from pipecat.metrics.metrics import MetricsData
from pipecat.transcriptions.language import Language
from pipecat.utils.asyncio import TaskManager
from pipecat.utils.asyncio import BaseTaskManager
from pipecat.utils.time import nanoseconds_to_str
from pipecat.utils.utils import obj_count, obj_id
@@ -438,7 +438,7 @@ class StartFrame(SystemFrame):
"""This is the first frame that should be pushed down a pipeline."""
clock: BaseClock
task_manager: TaskManager
task_manager: BaseTaskManager
audio_in_sample_rate: int = 16000
audio_out_sample_rate: int = 24000
allow_interruptions: bool = False

View File

@@ -5,7 +5,7 @@
#
import asyncio
from typing import Any, AsyncIterable, Dict, Iterable, List
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional
from loguru import logger
from pydantic import BaseModel, ConfigDict
@@ -31,7 +31,7 @@ from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.pipeline.base_task import BaseTask
from pipecat.pipeline.task_observer import TaskObserver
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.asyncio import TaskManager
from pipecat.utils.asyncio import BaseTaskManager, TaskManager
from pipecat.utils.utils import obj_count, obj_id
HEARTBEAT_SECONDS = 1.0
@@ -134,6 +134,7 @@ class PipelineTask(BaseTask):
params: PipelineParams = PipelineParams(),
observers: List[BaseObserver] = [],
clock: BaseClock = SystemClock(),
task_manager: Optional[BaseTaskManager] = None,
check_dangling_tasks: bool = True,
):
self._id: int = obj_id()
@@ -174,7 +175,7 @@ class PipelineTask(BaseTask):
self._sink = PipelineTaskSink(self._down_queue)
pipeline.link(self._sink)
self._task_manager = TaskManager()
self._task_manager = task_manager or TaskManager()
self._observer = TaskObserver(observers=observers, task_manager=self._task_manager)

View File

@@ -12,7 +12,7 @@ from attr import dataclass
from pipecat.frames.frames import Frame
from pipecat.observers.base_observer import BaseObserver
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.asyncio import TaskManager
from pipecat.utils.asyncio import BaseTaskManager
from pipecat.utils.utils import obj_count, obj_id
@@ -55,7 +55,7 @@ class TaskObserver(BaseObserver):
"""
def __init__(self, *, observers: List[BaseObserver] = [], task_manager: TaskManager):
def __init__(self, *, observers: List[BaseObserver] = [], task_manager: BaseTaskManager):
self._id: int = obj_id()
self._name: str = f"{self.__class__.__name__}#{obj_count(self)}"
self._observers = observers

View File

@@ -23,7 +23,7 @@ from pipecat.frames.frames import (
)
from pipecat.metrics.metrics import LLMTokenUsage, MetricsData
from pipecat.processors.metrics.frame_processor_metrics import FrameProcessorMetrics
from pipecat.utils.asyncio import TaskManager
from pipecat.utils.asyncio import BaseTaskManager
from pipecat.utils.utils import obj_count, obj_id
@@ -52,7 +52,7 @@ class FrameProcessor:
self._clock: Optional[BaseClock] = None
# Task Manager
self._task_manager: Optional[TaskManager] = None
self._task_manager: Optional[BaseTaskManager] = None
# Other properties
self._allow_interruptions = False
@@ -192,7 +192,7 @@ class FrameProcessor:
raise Exception(f"{self} Clock is still not initialized.")
return self._clock
def get_task_manager(self) -> TaskManager:
def get_task_manager(self) -> BaseTaskManager:
if not self._task_manager:
raise Exception(f"{self} TaskManager is still not initialized.")
return self._task_manager

View File

@@ -30,7 +30,7 @@ from pipecat.serializers.protobuf import ProtobufFrameSerializer
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.utils.asyncio import TaskManager
from pipecat.utils.asyncio import BaseTaskManager
class WebsocketClientParams(TransportParams):
@@ -57,12 +57,12 @@ class WebsocketClientSession:
self._callbacks = callbacks
self._transport_name = transport_name
self._task_manager: Optional[TaskManager] = None
self._task_manager: Optional[BaseTaskManager] = None
self._websocket: Optional[websockets.WebSocketClientProtocol] = None
@property
def task_manager(self) -> TaskManager:
def task_manager(self) -> BaseTaskManager:
if not self._task_manager:
raise Exception(
f"{self._transport_name}::WebsocketClientSession: TaskManager not initialized (pipeline not started?)"

View File

@@ -43,7 +43,7 @@ from pipecat.transcriptions.language import Language
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.utils.asyncio import TaskManager
from pipecat.utils.asyncio import BaseTaskManager
try:
from daily import CallClient, Daily, EventHandler
@@ -293,7 +293,7 @@ class DailyTransportClient(EventHandler):
self._joined_event = asyncio.Event()
self._leave_counter = 0
self._task_manager: Optional[TaskManager] = None
self._task_manager: Optional[BaseTaskManager] = None
# We use the executor to cleanup the client. We just do it from one
# place, so only one thread is really needed.

View File

@@ -27,7 +27,7 @@ from pipecat.processors.frame_processor import FrameDirection
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.utils.asyncio import TaskManager
from pipecat.utils.asyncio import BaseTaskManager
try:
from livekit import rtc
@@ -88,7 +88,7 @@ class LiveKitTransportClient:
self._audio_tracks = {}
self._audio_queue = asyncio.Queue()
self._other_participant_has_joined = False
self._task_manager: Optional[TaskManager] = None
self._task_manager: Optional[BaseTaskManager] = None
@property
def participant_id(self) -> str:

View File

@@ -5,12 +5,76 @@
#
import asyncio
from abc import ABC, abstractmethod
from typing import Coroutine, Optional, Set
from loguru import logger
class TaskManager:
class BaseTaskManager(ABC):
@abstractmethod
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
pass
@abstractmethod
def get_event_loop(self) -> asyncio.AbstractEventLoop:
pass
@abstractmethod
def create_task(self, coroutine: Coroutine, name: str) -> asyncio.Task:
"""
Creates and schedules a new asyncio Task that runs the given coroutine.
The task is added to a global set of created tasks.
Args:
loop (asyncio.AbstractEventLoop): The event loop to use for creating the task.
coroutine (Coroutine): The coroutine to be executed within the task.
name (str): The name to assign to the task for identification.
Returns:
asyncio.Task: The created task object.
"""
pass
@abstractmethod
async def wait_for_task(self, task: asyncio.Task, timeout: Optional[float] = None):
"""Wait for an asyncio.Task to complete with optional timeout handling.
This function awaits the specified asyncio.Task and handles scenarios for
timeouts, cancellations, and other exceptions. It also ensures that the task
is removed from the set of registered tasks upon completion or failure.
Args:
task (asyncio.Task): The asyncio Task to wait for.
timeout (Optional[float], optional): The maximum number of seconds
to wait for the task to complete. If None, waits indefinitely.
Defaults to None.
"""
pass
@abstractmethod
async def cancel_task(self, task: asyncio.Task, timeout: Optional[float] = None):
"""Cancels the given asyncio Task and awaits its completion with an
optional timeout.
This function removes the task from the set of registered tasks upon
completion or failure.
Args:
task (asyncio.Task): The task to be cancelled.
timeout (Optional[float]): The optional timeout in seconds to wait for the task to cancel.
"""
pass
@abstractmethod
def current_tasks(self) -> Set[asyncio.Task]:
"""Returns the list of currently created/registered tasks."""
pass
class TaskManager(BaseTaskManager):
def __init__(self) -> None:
self._tasks: Set[asyncio.Task] = set()
self._loop: Optional[asyncio.AbstractEventLoop] = None