From af90b8b4fa573bdf556bd9cec82db61707dd6579 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 24 Jan 2025 23:39:18 -0800 Subject: [PATCH] utils: add wait_for_task() --- src/pipecat/pipeline/task.py | 4 ++-- src/pipecat/transports/base_output.py | 5 +++-- src/pipecat/utils/utils.py | 31 ++++++++++++++++++++++----- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index 8d659bb90..3a814e15d 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -30,7 +30,7 @@ from pipecat.pipeline.base_pipeline import BasePipeline from pipecat.pipeline.base_task import BaseTask from pipecat.pipeline.task_observer import TaskObserver from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.utils.utils import cancel_task, create_task, obj_count, obj_id +from pipecat.utils.utils import cancel_task, create_task, obj_count, obj_id, wait_for_task HEARTBEAT_SECONDS = 1.0 HEARTBEAT_MONITOR_SECONDS = HEARTBEAT_SECONDS * 5 @@ -165,7 +165,7 @@ class PipelineTask(BaseTask): """ try: push_task = self._create_tasks() - await asyncio.gather(push_task) + await wait_for_task(push_task) except asyncio.CancelledError: # We are awaiting on the push task and it might be cancelled # (e.g. Ctrl-C). This means we will get a CancelledError here as diff --git a/src/pipecat/transports/base_output.py b/src/pipecat/transports/base_output.py index 5181914ed..1ec5ea7fa 100644 --- a/src/pipecat/transports/base_output.py +++ b/src/pipecat/transports/base_output.py @@ -36,6 +36,7 @@ from pipecat.frames.frames import ( from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.transports.base_transport import TransportParams from pipecat.utils.time import nanoseconds_to_seconds +from pipecat.utils.utils import wait_for_task class BaseOutputTransport(FrameProcessor): @@ -87,9 +88,9 @@ class BaseOutputTransport(FrameProcessor): # for these tasks before cancelling the camera and audio tasks below # because they might be still rendering. if self._sink_task: - await self._sink_task + await wait_for_task(self._sink_task) if self._sink_clock_task: - await self._sink_clock_task + await wait_for_task(self._sink_clock_task) # We can now cancel the camera task. await self._cancel_camera_task() diff --git a/src/pipecat/utils/utils.py b/src/pipecat/utils/utils.py index 045308c57..12703c37c 100644 --- a/src/pipecat/utils/utils.py +++ b/src/pipecat/utils/utils.py @@ -48,7 +48,7 @@ def create_task(loop: asyncio.AbstractEventLoop, coroutine: Coroutine, name: str try: await coroutine except asyncio.CancelledError: - logger.trace(f"{name}: cancelling task") + logger.trace(f"{name}: task cancelled") # Re-raise the exception to ensure the task is cancelled. raise except Exception as e: @@ -61,6 +61,26 @@ def create_task(loop: asyncio.AbstractEventLoop, coroutine: Coroutine, name: str return task +async def wait_for_task(task: asyncio.Task, timeout: Optional[float] = None): + name = task.get_name() + try: + if timeout: + await asyncio.wait_for(task, timeout=timeout) + else: + await task + except asyncio.TimeoutError: + logger.warning(f"{name}: timed out waiting for task to finish") + except asyncio.CancelledError: + logger.error(f"{name}: unexpected task cancellation") + except Exception as e: + logger.exception(f"{name}: unexpected exception while stopping task: {e}") + finally: + try: + _TASKS.remove(task) + except KeyError as e: + logger.error(f"{name}: error removing task (already removed?): {e}") + + async def cancel_task(task: asyncio.Task, timeout: Optional[float] = None): name = task.get_name() task.cancel() @@ -70,16 +90,17 @@ async def cancel_task(task: asyncio.Task, timeout: Optional[float] = None): else: await task except asyncio.TimeoutError: - logger.warning(f"{name}: timed out waiting for task to finish") + logger.warning(f"{name}: timed out waiting for task to cancel") except asyncio.CancelledError: # Here are sure the task is cancelled properly. - logger.trace(f"{name}: task cancelled") + pass + except Exception as e: + logger.exception(f"{name}: unexpected exception while cancelling task: {e}") + finally: try: _TASKS.remove(task) except KeyError as e: logger.error(f"{name}: error removing task (already removed?): {e}") - except Exception as e: - logger.exception(f"{name}: unexpected exception while cancelling task: {e}") def current_tasks() -> Set[asyncio.Task]: