improve task creation and cancellation

If a FrameProcessor needs to create a task it should use
FrameProcessor.create_task() and FrameProcessor.cancel_task(). This gives
Pipecat more control over all the tasks that are created in Pipecat.

Both functions internally use the utils module: utils.create_task() and
utils.cancel_task() which should also be used outside of FrameProcessors. That
is, unless strictly necessary, we should avoid using asyncio.create_task().
This commit is contained in:
Aleix Conchillo Flaqué
2025-01-24 11:29:21 -08:00
parent 179ddbea7d
commit d1a3f404a5
31 changed files with 467 additions and 552 deletions

View File

@@ -169,8 +169,7 @@ class OutputGate(FrameProcessor):
self._gate_task = self.get_event_loop().create_task(self._gate_task_handler())
async def _stop(self):
self._gate_task.cancel()
await self._gate_task
await self.cancel_task(self._gate_task)
async def _gate_task_handler(self):
while True:

View File

@@ -101,12 +101,12 @@ HIGH PRIORITY SIGNALS:
Examples:
# Complete Wh-question
[{"role": "assistant", "content": "I can help you learn."},
[{"role": "assistant", "content": "I can help you learn."},
{"role": "user", "content": "What's the fastest way to learn Spanish"}]
Output: YES
# Complete Yes/No question despite STT error
[{"role": "assistant", "content": "I know about planets."},
[{"role": "assistant", "content": "I know about planets."},
{"role": "user", "content": "Is is Jupiter the biggest planet"}]
Output: YES
@@ -118,12 +118,12 @@ Output: YES
Examples:
# Direct instruction
[{"role": "assistant", "content": "I can explain many topics."},
[{"role": "assistant", "content": "I can explain many topics."},
{"role": "user", "content": "Tell me about black holes"}]
Output: YES
# Action demand
[{"role": "assistant", "content": "I can help with math."},
[{"role": "assistant", "content": "I can help with math."},
{"role": "user", "content": "Solve this equation x plus 5 equals 12"}]
Output: YES
@@ -134,12 +134,12 @@ Output: YES
Examples:
# Specific answer
[{"role": "assistant", "content": "What's your favorite color?"},
[{"role": "assistant", "content": "What's your favorite color?"},
{"role": "user", "content": "I really like blue"}]
Output: YES
# Option selection
[{"role": "assistant", "content": "Would you prefer morning or evening?"},
[{"role": "assistant", "content": "Would you prefer morning or evening?"},
{"role": "user", "content": "Morning"}]
Output: YES
@@ -153,17 +153,17 @@ MEDIUM PRIORITY SIGNALS:
Examples:
# Self-correction reaching completion
[{"role": "assistant", "content": "What would you like to know?"},
[{"role": "assistant", "content": "What would you like to know?"},
{"role": "user", "content": "Tell me about... no wait, explain how rainbows form"}]
Output: YES
# Topic change with complete thought
[{"role": "assistant", "content": "The weather is nice today."},
[{"role": "assistant", "content": "The weather is nice today."},
{"role": "user", "content": "Actually can you tell me who invented the telephone"}]
Output: YES
# Mid-sentence completion
[{"role": "assistant", "content": "Hello I'm ready."},
[{"role": "assistant", "content": "Hello I'm ready."},
{"role": "user", "content": "What's the capital of? France"}]
Output: YES
@@ -175,12 +175,12 @@ Output: YES
Examples:
# Acknowledgment
[{"role": "assistant", "content": "Should we talk about history?"},
[{"role": "assistant", "content": "Should we talk about history?"},
{"role": "user", "content": "Sure"}]
Output: YES
# Disagreement with completion
[{"role": "assistant", "content": "Is that what you meant?"},
[{"role": "assistant", "content": "Is that what you meant?"},
{"role": "user", "content": "No not really"}]
Output: YES
@@ -194,12 +194,12 @@ LOW PRIORITY SIGNALS:
Examples:
# Word repetition but complete
[{"role": "assistant", "content": "I can help with that."},
[{"role": "assistant", "content": "I can help with that."},
{"role": "user", "content": "What what is the time right now"}]
Output: YES
# Missing punctuation but complete
[{"role": "assistant", "content": "I can explain that."},
[{"role": "assistant", "content": "I can explain that."},
{"role": "user", "content": "Please tell me how computers work"}]
Output: YES
@@ -211,12 +211,12 @@ Output: YES
Examples:
# Filler words but complete
[{"role": "assistant", "content": "What would you like to know?"},
[{"role": "assistant", "content": "What would you like to know?"},
{"role": "user", "content": "Um uh how do airplanes fly"}]
Output: YES
# Thinking pause but incomplete
[{"role": "assistant", "content": "I can explain anything."},
[{"role": "assistant", "content": "I can explain anything."},
{"role": "user", "content": "Well um I want to know about the"}]
Output: NO
@@ -241,17 +241,17 @@ DECISION RULES:
Examples:
# Incomplete despite corrections
[{"role": "assistant", "content": "What would you like to know about?"},
[{"role": "assistant", "content": "What would you like to know about?"},
{"role": "user", "content": "Can you tell me about"}]
Output: NO
# Complete despite multiple artifacts
[{"role": "assistant", "content": "I can help you learn."},
[{"role": "assistant", "content": "I can help you learn."},
{"role": "user", "content": "How do you I mean what's the best way to learn programming"}]
Output: YES
# Trailing off incomplete
[{"role": "assistant", "content": "I can explain anything."},
[{"role": "assistant", "content": "I can explain anything."},
{"role": "user", "content": "I was wondering if you could tell me why"}]
Output: NO
"""
@@ -374,8 +374,7 @@ class OutputGate(FrameProcessor):
self._gate_task = self.get_event_loop().create_task(self._gate_task_handler())
async def _stop(self):
self._gate_task.cancel()
await self._gate_task
await cancel_task(self._gate_task)
async def _gate_task_handler(self):
while True:

View File

@@ -44,9 +44,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
)
from pipecat.processors.filters.function_filter import FunctionFilter
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.user_idle_processor import UserIdleProcessor
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.deepgram import DeepgramSTTService
from pipecat.services.google import GoogleLLMContext, GoogleLLMService
from pipecat.sync.base_notifier import BaseNotifier
from pipecat.sync.event_notifier import EventNotifier
@@ -440,11 +438,11 @@ class CompletenessCheck(FrameProcessor):
if isinstance(frame, UserStartedSpeakingFrame):
if self._idle_task:
self._idle_task.cancel()
await self.cancel_task(self._idle_task)
elif isinstance(frame, TextFrame) and frame.text.startswith("YES"):
logger.debug("Completeness check YES")
if self._idle_task:
self._idle_task.cancel()
await self.cancel_task(self._idle_task)
await self.push_frame(UserStoppedSpeakingFrame())
await self._audio_accumulator.reset()
await self._notifier.notify()
@@ -602,8 +600,7 @@ class OutputGate(FrameProcessor):
self._gate_task = self.get_event_loop().create_task(self._gate_task_handler())
async def _stop(self):
self._gate_task.cancel()
await self._gate_task
await self.cancel_task(self._gate_task)
async def _gate_task_handler(self):
while True:

View File

@@ -159,5 +159,5 @@ class SileroVADAnalyzer(VADAnalyzer):
return new_confidence
except Exception as e:
# This comes from an empty audio array
logger.exception(f"Error analyzing audio with Silero VAD: {e}")
logger.error(f"Error analyzing audio with Silero VAD: {e}")
return 0

View File

@@ -150,22 +150,18 @@ class ParallelPipeline(BasePipeline):
async def _stop(self):
# The up task doesn't receive an EndFrame, so we just cancel it.
self._up_task.cancel()
await self._up_task
# The down tasks waits for the last EndFrame send by the internal
await self.cancel_task(self._up_task)
# The down tasks waits for the last EndFrame sent by the internal
# pipelines.
await self._down_task
async def _cancel(self):
self._up_task.cancel()
await self._up_task
self._down_task.cancel()
await self._down_task
await self.cancel_task(self._up_task)
await self.cancel_task(self._down_task)
async def _create_tasks(self):
loop = self.get_event_loop()
self._up_task = loop.create_task(self._process_up_queue())
self._down_task = loop.create_task(self._process_down_queue())
self._up_task = self.create_task(self._process_up_queue())
self._down_task = self.create_task(self._process_down_queue())
async def _drain_queues(self):
while not self._up_queue.empty:
@@ -185,32 +181,26 @@ class ParallelPipeline(BasePipeline):
async def _process_up_queue(self):
while True:
try:
frame = await self._up_queue.get()
await self._parallel_push_frame(frame, FrameDirection.UPSTREAM)
self._up_queue.task_done()
except asyncio.CancelledError:
break
frame = await self._up_queue.get()
await self._parallel_push_frame(frame, FrameDirection.UPSTREAM)
self._up_queue.task_done()
async def _process_down_queue(self):
running = True
while running:
try:
frame = await self._down_queue.get()
frame = await self._down_queue.get()
endframe_counter = self._endframe_counter.get(frame.id, 0)
endframe_counter = self._endframe_counter.get(frame.id, 0)
# If we have a counter, decrement it.
if endframe_counter > 0:
self._endframe_counter[frame.id] -= 1
endframe_counter = self._endframe_counter[frame.id]
# If we have a counter, decrement it.
if endframe_counter > 0:
self._endframe_counter[frame.id] -= 1
endframe_counter = self._endframe_counter[frame.id]
# If we don't have a counter or we reached 0, push the frame.
if endframe_counter == 0:
await self._parallel_push_frame(frame, FrameDirection.DOWNSTREAM)
# If we don't have a counter or we reached 0, push the frame.
if endframe_counter == 0:
await self._parallel_push_frame(frame, FrameDirection.DOWNSTREAM)
running = not (endframe_counter == 0 and isinstance(frame, EndFrame))
running = not (endframe_counter == 0 and isinstance(frame, EndFrame))
self._down_queue.task_done()
except asyncio.CancelledError:
break
self._down_queue.task_done()

View File

@@ -30,7 +30,7 @@ from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.pipeline.base_task import BaseTask
from pipecat.pipeline.task_observer import TaskObserver
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.utils import obj_count, obj_id
from pipecat.utils.utils import cancel_task, create_task, obj_count, obj_id
HEARTBEAT_SECONDS = 1.0
HEARTBEAT_MONITOR_SECONDS = HEARTBEAT_SECONDS * 5
@@ -49,7 +49,7 @@ class PipelineParams(BaseModel):
heartbeats_period_secs: float = HEARTBEAT_SECONDS
class Source(FrameProcessor):
class PipelineTaskSource(FrameProcessor):
"""This is the source processor that is linked at the beginning of the
pipeline given to the pipeline task. It allows us to easily push frames
downstream to the pipeline and also receive upstream frames coming from the
@@ -57,8 +57,8 @@ class Source(FrameProcessor):
"""
def __init__(self, up_queue: asyncio.Queue):
super().__init__()
def __init__(self, up_queue: asyncio.Queue, **kwargs):
super().__init__(**kwargs)
self._up_queue = up_queue
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -71,15 +71,15 @@ class Source(FrameProcessor):
await self.push_frame(frame, direction)
class Sink(FrameProcessor):
class PipelineTaskSink(FrameProcessor):
"""This is the sink processor that is linked at the end of the pipeline
given to the pipeline task. It allows us to receive downstream frames and
act on them, for example, waiting to receive an EndFrame.
"""
def __init__(self, down_queue: asyncio.Queue):
super().__init__()
def __init__(self, down_queue: asyncio.Queue, **kwargs):
super().__init__(**kwargs)
self._down_queue = down_queue
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -115,10 +115,10 @@ class PipelineTask(BaseTask):
# down queue.
self._endframe_event = asyncio.Event()
self._source = Source(self._up_queue)
self._source = PipelineTaskSource(self._up_queue)
self._source.link(pipeline)
self._sink = Sink(self._down_queue)
self._sink = PipelineTaskSink(self._down_queue)
pipeline.link(self._sink)
self._observer = TaskObserver(params.observers)
@@ -148,13 +148,22 @@ class PipelineTask(BaseTask):
# we want to cancel right away.
await self._source.push_frame(CancelFrame())
await self._cancel_tasks(True)
await self._cleanup()
async def run(self):
"""
Starts running the given pipeline.
"""
tasks = self._create_tasks()
await asyncio.gather(*tasks)
try:
push_task = self._create_tasks()
await asyncio.gather(push_task)
except asyncio.CancelledError:
# We are awaiting on the push task and it might be cancelled
# (e.g. Ctrl-C). This means we will get a CancelledError here as
# well, because you get a CancelledError in every place you are
# awaiting a task.
pass
await self._cancel_tasks(False)
self._finished = True
async def queue_frame(self, frame: Frame):
@@ -175,41 +184,44 @@ class PipelineTask(BaseTask):
await self.queue_frame(frame)
def _create_tasks(self):
tasks = []
self._process_up_task = asyncio.create_task(self._process_up_queue())
self._process_down_task = asyncio.create_task(self._process_down_queue())
self._process_push_task = asyncio.create_task(self._process_push_queue())
loop = asyncio.get_running_loop()
self._process_up_task = create_task(
loop, self._process_up_queue(), f"{self}::_process_up_queue"
)
self._process_down_task = create_task(
loop, self._process_down_queue(), f"{self}::_process_down_queue"
)
self._process_push_task = create_task(
loop, self._process_push_queue(), f"{self}::_process_push_queue"
)
tasks = [self._process_up_task, self._process_down_task, self._process_push_task]
return tasks
return self._process_push_task
def _maybe_start_heartbeat_tasks(self):
if self._params.enable_heartbeats:
self._heartbeat_push_task = asyncio.create_task(self._heartbeat_push_handler())
self._heartbeat_monitor_task = asyncio.create_task(self._heartbeat_monitor_handler())
loop = asyncio.get_running_loop()
self._heartbeat_push_task = create_task(
loop, self._heartbeat_push_handler(), f"{self}::_heartbeat_push_handler"
)
self._heartbeat_monitor_task = create_task(
loop, self._heartbeat_monitor_handler(), f"{self}::_heartbeat_monitor_handler"
)
async def _cancel_tasks(self, cancel_push: bool):
await self._maybe_cancel_heartbeat_tasks()
if cancel_push:
self._process_push_task.cancel()
await self._process_push_task
await cancel_task(self._process_push_task)
self._process_up_task.cancel()
await self._process_up_task
self._process_down_task.cancel()
await self._process_down_task
await cancel_task(self._process_up_task)
await cancel_task(self._process_down_task)
await self._observer.stop()
async def _maybe_cancel_heartbeat_tasks(self):
if self._params.enable_heartbeats:
self._heartbeat_push_task.cancel()
await self._heartbeat_push_task
self._heartbeat_monitor_task.cancel()
await self._heartbeat_monitor_task
await cancel_task(self._heartbeat_push_task)
await cancel_task(self._heartbeat_monitor_task)
def _initial_metrics_frame(self) -> MetricsFrame:
processors = self._pipeline.processors_with_metrics()
@@ -223,6 +235,11 @@ class PipelineTask(BaseTask):
await self._endframe_event.wait()
self._endframe_event.clear()
async def _cleanup(self):
await self._source.cleanup()
await self._pipeline.cleanup()
await self._sink.cleanup()
async def _process_push_queue(self):
"""This is the task that runs the pipeline for the first time by sending
a StartFrame and by pushing any other frames queued by the user. It runs
@@ -249,24 +266,16 @@ class PipelineTask(BaseTask):
running = True
should_cleanup = True
while running:
try:
frame = await self._push_queue.get()
await self._source.queue_frame(frame, FrameDirection.DOWNSTREAM)
if isinstance(frame, EndFrame):
await self._wait_for_endframe()
running = not isinstance(frame, (StopTaskFrame, EndFrame))
should_cleanup = not isinstance(frame, StopTaskFrame)
self._push_queue.task_done()
except asyncio.CancelledError:
break
frame = await self._push_queue.get()
await self._source.queue_frame(frame, FrameDirection.DOWNSTREAM)
if isinstance(frame, EndFrame):
await self._wait_for_endframe()
running = not isinstance(frame, (StopTaskFrame, EndFrame))
should_cleanup = not isinstance(frame, StopTaskFrame)
self._push_queue.task_done()
# Cleanup only if we need to.
if should_cleanup:
await self._source.cleanup()
await self._pipeline.cleanup()
await self._sink.cleanup()
# Finally, cancel internal tasks. We don't cancel the push tasks because
# that's us.
await self._cancel_tasks(False)
await self._cleanup()
async def _process_up_queue(self):
"""This is the task that processes frames coming upstream from the
@@ -276,26 +285,23 @@ class PipelineTask(BaseTask):
"""
while True:
try:
frame = await self._up_queue.get()
if isinstance(frame, EndTaskFrame):
# Tell the task we should end nicely.
await self.queue_frame(EndFrame())
elif isinstance(frame, CancelTaskFrame):
# Tell the task we should end right away.
frame = await self._up_queue.get()
if isinstance(frame, EndTaskFrame):
# Tell the task we should end nicely.
await self.queue_frame(EndFrame())
elif isinstance(frame, CancelTaskFrame):
# Tell the task we should end right away.
await self.queue_frame(CancelFrame())
elif isinstance(frame, StopTaskFrame):
await self.queue_frame(StopTaskFrame())
elif isinstance(frame, ErrorFrame):
logger.error(f"Error running app: {frame}")
if frame.fatal:
# Cancel all tasks downstream.
await self.queue_frame(CancelFrame())
elif isinstance(frame, StopTaskFrame):
# Tell the task we should stop.
await self.queue_frame(StopTaskFrame())
elif isinstance(frame, ErrorFrame):
logger.error(f"Error running app: {frame}")
if frame.fatal:
# Cancel all tasks downstream.
await self.queue_frame(CancelFrame())
# Tell the task we should stop.
await self.queue_frame(StopTaskFrame())
self._up_queue.task_done()
except asyncio.CancelledError:
break
self._up_queue.task_done()
async def _process_down_queue(self):
"""This tasks process frames coming downstream from the pipeline. For
@@ -305,29 +311,23 @@ class PipelineTask(BaseTask):
"""
while True:
try:
frame = await self._down_queue.get()
if isinstance(frame, EndFrame):
self._endframe_event.set()
elif isinstance(frame, HeartbeatFrame):
await self._heartbeat_queue.put(frame)
self._down_queue.task_done()
except asyncio.CancelledError:
break
frame = await self._down_queue.get()
if isinstance(frame, EndFrame):
self._endframe_event.set()
elif isinstance(frame, HeartbeatFrame):
await self._heartbeat_queue.put(frame)
self._down_queue.task_done()
async def _heartbeat_push_handler(self):
"""
This tasks pushes a heartbeat frame every heartbeat period.
"""
while True:
try:
# Don't use `queue_frame()` because if an EndFrame is queued the
# task will just stop waiting for the pipeline to finish not
# allowing more frames to be pushed.
await self._source.queue_frame(HeartbeatFrame(timestamp=self._clock.get_time()))
await asyncio.sleep(self._params.heartbeats_period_secs)
except asyncio.CancelledError:
break
# Don't use `queue_frame()` because if an EndFrame is queued the
# task will just stop waiting for the pipeline to finish not
# allowing more frames to be pushed.
await self._source.queue_frame(HeartbeatFrame(timestamp=self._clock.get_time()))
await asyncio.sleep(self._params.heartbeats_period_secs)
async def _heartbeat_monitor_handler(self):
"""This tasks monitors heartbeat frames. If a heartbeat frame has not
@@ -347,8 +347,6 @@ class PipelineTask(BaseTask):
logger.warning(
f"{self}: heartbeat frame not received for more than {wait_time} seconds"
)
except asyncio.CancelledError:
break
def __str__(self):
return self.name

View File

@@ -12,6 +12,7 @@ from attr import dataclass
from pipecat.frames.frames import Frame
from pipecat.observers.base_observer import BaseObserver
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.utils import cancel_task, create_task, obj_count
@dataclass
@@ -54,13 +55,13 @@ class TaskObserver(BaseObserver):
"""
def __init__(self, observers: List[BaseObserver] = []):
self.name: str = f"{self.__class__.__name__}#{obj_count(self)}"
self._proxies: List[Proxy] = self._create_proxies(observers)
async def stop(self):
"""Stops all proxy observer tasks."""
for proxy in self._proxies:
proxy.task.cancel()
await proxy.task
await cancel_task(proxy.task)
async def on_push_frame(
self,
@@ -79,19 +80,24 @@ class TaskObserver(BaseObserver):
def _create_proxies(self, observers) -> List[Proxy]:
proxies = []
loop = asyncio.get_running_loop()
for observer in observers:
queue = asyncio.Queue()
task = asyncio.create_task(self._proxy_task_handler(queue, observer))
task = create_task(
loop,
self._proxy_task_handler(queue, observer),
f"{self}::{observer.__class__.__name__}",
)
proxy = Proxy(queue=queue, task=task, observer=observer)
proxies.append(proxy)
return proxies
async def _proxy_task_handler(self, queue: asyncio.Queue, observer: BaseObserver):
while True:
try:
data = await queue.get()
await observer.on_push_frame(
data.src, data.dst, data.frame, data.direction, data.timestamp
)
except asyncio.CancelledError:
break
data = await queue.get()
await observer.on_push_frame(
data.src, data.dst, data.frame, data.direction, data.timestamp
)
def __str__(self):
return self.name

View File

@@ -4,8 +4,6 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
from pipecat.frames.frames import CancelFrame, EndFrame, Frame, StartFrame
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
@@ -38,18 +36,14 @@ class GatedOpenAILLMContextAggregator(FrameProcessor):
await self.push_frame(frame, direction)
async def _start(self):
self._gate_task = self.get_event_loop().create_task(self._gate_task_handler())
self._gate_task = self.create_task(self._gate_task_handler())
async def _stop(self):
self._gate_task.cancel()
await self._gate_task
await self.cancel_task(self._gate_task)
async def _gate_task_handler(self):
while True:
try:
await self._notifier.wait()
if self._last_context_frame:
await self.push_frame(self._last_context_frame)
self._last_context_frame = None
except asyncio.CancelledError:
break
await self._notifier.wait()
if self._last_context_frame:
await self.push_frame(self._last_context_frame)
self._last_context_frame = None

View File

@@ -6,15 +6,15 @@
import asyncio
import inspect
import sys
from enum import Enum
from typing import Awaitable, Callable, Optional
from typing import Awaitable, Callable, Coroutine, Optional
from loguru import logger
from pipecat.clocks.base_clock import BaseClock
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
StartFrame,
@@ -24,7 +24,7 @@ from pipecat.frames.frames import (
)
from pipecat.metrics.metrics import LLMTokenUsage, MetricsData
from pipecat.processors.metrics.frame_processor_metrics import FrameProcessorMetrics
from pipecat.utils.utils import obj_count, obj_id
from pipecat.utils.utils import cancel_task, create_task, obj_count, obj_id
class FrameDirection(Enum):
@@ -141,6 +141,13 @@ class FrameProcessor:
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
def create_task(self, coroutine: Coroutine) -> asyncio.Task:
name = f"{self}::{coroutine.cr_code.co_name}"
return create_task(self.get_event_loop(), coroutine, name)
async def cancel_task(self, task: asyncio.Task, timeout: Optional[float] = None):
await cancel_task(task, timeout)
async def cleanup(self):
await self.__cancel_input_task()
await self.__cancel_push_task()
@@ -188,7 +195,6 @@ class FrameProcessor:
async def resume_processing_frames(self):
logger.trace(f"{self}: resuming frame processing")
self.__input_event.set()
self.__should_block_frames = False
async def process_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, StartFrame):
@@ -283,61 +289,44 @@ class FrameProcessor:
def __create_input_task(self):
self.__should_block_frames = False
self.__input_queue = asyncio.Queue()
self.__input_frame_task = self.get_event_loop().create_task(
self.__input_frame_task_handler()
)
self.__input_event = asyncio.Event()
self.__input_frame_task = self.create_task(self.__input_frame_task_handler())
async def __cancel_input_task(self):
self.__input_frame_task.cancel()
await self.__input_frame_task
await self.cancel_task(self.__input_frame_task)
async def __input_frame_task_handler(self):
while True:
try:
if self.__should_block_frames:
logger.trace(f"{self}: frame processing paused")
await self.__input_event.wait()
self.__input_event.clear()
logger.trace(f"{self}: frame processing resumed")
if self.__should_block_frames:
logger.trace(f"{self}: frame processing paused")
await self.__input_event.wait()
self.__input_event.clear()
self.__should_block_frames = False
logger.trace(f"{self}: frame processing resumed")
(frame, direction, callback) = await self.__input_queue.get()
(frame, direction, callback) = await self.__input_queue.get()
# Process the frame.
await self.process_frame(frame, direction)
# Process the frame.
await self.process_frame(frame, direction)
# If this frame has an associated callback, call it now.
if callback:
await callback(self, frame, direction)
# If this frame has an associated callback, call it now.
if callback:
await callback(self, frame, direction)
self.__input_queue.task_done()
except asyncio.CancelledError:
logger.trace(f"{self}: cancelled input task")
break
except Exception as e:
logger.exception(f"{self}: Uncaught exception {e}")
await self.push_error(ErrorFrame(str(e)))
self.__input_queue.task_done()
def __create_push_task(self):
self.__push_queue = asyncio.Queue()
self.__push_frame_task = self.get_event_loop().create_task(self.__push_frame_task_handler())
self.__push_frame_task = self.create_task(self.__push_frame_task_handler())
async def __cancel_push_task(self):
self.__push_frame_task.cancel()
await self.__push_frame_task
await self.cancel_task(self.__push_frame_task)
async def __push_frame_task_handler(self):
while True:
try:
(frame, direction) = await self.__push_queue.get()
await self.__internal_push_frame(frame, direction)
self.__push_queue.task_done()
except asyncio.CancelledError:
logger.trace(f"{self}: cancelled push task")
break
except Exception as e:
logger.exception(f"{self}: Uncaught exception {e}")
await self.push_error(ErrorFrame(str(e)))
(frame, direction) = await self.__push_queue.get()
await self.__internal_push_frame(frame, direction)
self.__push_queue.task_done()
async def _call_event_handler(self, event_name: str, *args, **kwargs):
try:

View File

@@ -764,11 +764,11 @@ class RTVIProcessor(FrameProcessor):
# A task to process incoming action frames.
self._action_queue = asyncio.Queue()
self._action_task = self.get_event_loop().create_task(self._action_task_handler())
self._action_task = self.create_task(self._action_task_handler())
# A task to process incoming transport messages.
self._message_queue = asyncio.Queue()
self._message_task = self.get_event_loop().create_task(self._message_task_handler())
self._message_task = self.create_task(self._message_task_handler())
self._register_event_handler("on_bot_started")
self._register_event_handler("on_client_ready")
@@ -873,13 +873,11 @@ class RTVIProcessor(FrameProcessor):
async def _cancel_tasks(self):
if self._action_task:
self._action_task.cancel()
await self._action_task
await self.cancel_task(self._action_task)
self._action_task = None
if self._message_task:
self._message_task.cancel()
await self._message_task
await self.cancel_task(self._message_task)
self._message_task = None
async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True):
@@ -888,21 +886,15 @@ class RTVIProcessor(FrameProcessor):
async def _action_task_handler(self):
while True:
try:
frame = await self._action_queue.get()
await self._handle_action(frame.message_id, frame.rtvi_action_run)
self._action_queue.task_done()
except asyncio.CancelledError:
break
frame = await self._action_queue.get()
await self._handle_action(frame.message_id, frame.rtvi_action_run)
self._action_queue.task_done()
async def _message_task_handler(self):
while True:
try:
message = await self._message_queue.get()
await self._handle_message(message)
self._message_queue.task_done()
except asyncio.CancelledError:
break
message = await self._message_queue.get()
await self._handle_message(message)
self._message_queue.task_done()
async def _handle_transport_message(self, frame: TransportMessageUrgentFrame):
try:

View File

@@ -49,12 +49,11 @@ class IdleFrameProcessor(FrameProcessor):
self._idle_event.set()
async def cleanup(self):
self._idle_task.cancel()
await self._idle_task
await self.cancel_task(self._idle_task)
def _create_idle_task(self):
self._idle_event = asyncio.Event()
self._idle_task = self.get_event_loop().create_task(self._idle_task_handler())
self._idle_task = self.create_task(self._idle_task_handler())
async def _idle_task_handler(self):
while True:
@@ -62,7 +61,5 @@ class IdleFrameProcessor(FrameProcessor):
await asyncio.wait_for(self._idle_event.wait(), timeout=self._timeout)
except asyncio.TimeoutError:
await self._callback(self)
except asyncio.CancelledError:
break
finally:
self._idle_event.clear()

View File

@@ -103,7 +103,7 @@ class UserIdleProcessor(FrameProcessor):
def _create_idle_task(self) -> None:
"""Creates the idle task if it hasn't been created yet."""
if self._idle_task is None:
self._idle_task = self.get_event_loop().create_task(self._idle_task_handler())
self._idle_task = self.create_task(self._idle_task_handler())
@property
def retry_count(self) -> int:
@@ -113,11 +113,7 @@ class UserIdleProcessor(FrameProcessor):
async def _stop(self) -> None:
"""Stops and cleans up the idle monitoring task."""
if self._idle_task is not None:
self._idle_task.cancel()
try:
await self._idle_task
except asyncio.CancelledError:
pass # Expected when task is cancelled
await self.cancel_task(self._idle_task)
self._idle_task = None
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
@@ -178,7 +174,5 @@ class UserIdleProcessor(FrameProcessor):
if not should_continue:
await self._stop()
break
except asyncio.CancelledError:
break
finally:
self._idle_event.clear()

View File

@@ -253,20 +253,18 @@ class TTSService(AIService):
async def start(self, frame: StartFrame):
await super().start(frame)
if self._push_stop_frames:
self._stop_frame_task = self.get_event_loop().create_task(self._stop_frame_handler())
self._stop_frame_task = self.create_task(self._stop_frame_handler())
async def stop(self, frame: EndFrame):
await super().stop(frame)
if self._stop_frame_task:
self._stop_frame_task.cancel()
await self._stop_frame_task
await self.cancel_task(self._stop_frame_task)
self._stop_frame_task = None
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
if self._stop_frame_task:
self._stop_frame_task.cancel()
await self._stop_frame_task
await self.cancel_task(self._stop_frame_task)
self._stop_frame_task = None
async def _update_settings(self, settings: Dict[str, Any]):
@@ -364,23 +362,20 @@ class TTSService(AIService):
await self.push_frame(TTSTextFrame(text))
async def _stop_frame_handler(self):
try:
has_started = False
while True:
try:
frame = await asyncio.wait_for(
self._stop_frame_queue.get(), self._stop_frame_timeout_s
)
if isinstance(frame, TTSStartedFrame):
has_started = True
elif isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
has_started = False
except asyncio.TimeoutError:
if has_started:
await self.push_frame(TTSStoppedFrame())
has_started = False
except asyncio.CancelledError:
pass
has_started = False
while True:
try:
frame = await asyncio.wait_for(
self._stop_frame_queue.get(), self._stop_frame_timeout_s
)
if isinstance(frame, TTSStartedFrame):
has_started = True
elif isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
has_started = False
except asyncio.TimeoutError:
if has_started:
await self.push_frame(TTSStoppedFrame())
has_started = False
class WordTTSService(TTSService):
@@ -388,7 +383,7 @@ class WordTTSService(TTSService):
super().__init__(**kwargs)
self._initial_word_timestamp = -1
self._words_queue = asyncio.Queue()
self._words_task = self.get_event_loop().create_task(self._words_task_handler())
self._words_task = self.create_task(self._words_task_handler())
def start_word_timestamps(self):
if self._initial_word_timestamp == -1:
@@ -421,35 +416,29 @@ class WordTTSService(TTSService):
async def _stop_words_task(self):
if self._words_task:
self._words_task.cancel()
await self._words_task
await self.cancel_task(self._words_task)
self._words_task = None
async def _words_task_handler(self):
last_pts = 0
while True:
try:
(word, timestamp) = await self._words_queue.get()
if word == "Reset" and timestamp == 0:
self.reset_word_timestamps()
frame = None
elif word == "LLMFullResponseEndFrame" and timestamp == 0:
frame = LLMFullResponseEndFrame()
frame.pts = last_pts
elif word == "TTSStoppedFrame" and timestamp == 0:
frame = TTSStoppedFrame()
frame.pts = last_pts
else:
frame = TTSTextFrame(word)
frame.pts = self._initial_word_timestamp + timestamp
if frame:
last_pts = frame.pts
await self.push_frame(frame)
self._words_queue.task_done()
except asyncio.CancelledError:
break
except Exception as e:
logger.exception(f"{self} exception: {e}")
(word, timestamp) = await self._words_queue.get()
if word == "Reset" and timestamp == 0:
self.reset_word_timestamps()
frame = None
elif word == "LLMFullResponseEndFrame" and timestamp == 0:
frame = LLMFullResponseEndFrame()
frame.pts = last_pts
elif word == "TTSStoppedFrame" and timestamp == 0:
frame = TTSStoppedFrame()
frame.pts = last_pts
else:
frame = TTSTextFrame(word)
frame.pts = self._initial_word_timestamp + timestamp
if frame:
last_pts = frame.pts
await self.push_frame(frame)
self._words_queue.task_done()
class STTService(AIService):

View File

@@ -187,16 +187,13 @@ class CartesiaTTSService(WordTTSService, WebsocketService):
async def _connect(self):
await self._connect_websocket()
self._receive_task = self.get_event_loop().create_task(
self._receive_task_handler(self.push_error)
)
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
async def _disconnect(self):
await self._disconnect_websocket()
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
await self.cancel_task(self._receive_task)
self._receive_task = None
async def _connect_websocket(self):

View File

@@ -299,20 +299,16 @@ class ElevenLabsTTSService(WordTTSService, WebsocketService):
async def _connect(self):
await self._connect_websocket()
self._receive_task = self.get_event_loop().create_task(
self._receive_task_handler(self.push_error)
)
self._keepalive_task = self.get_event_loop().create_task(self._keepalive_task_handler())
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
self._keepalive_task = self.create_task(self._keepalive_task_handler())
async def _disconnect(self):
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
await self.cancel_task(self._receive_task)
self._receive_task = None
if self._keepalive_task:
self._keepalive_task.cancel()
await self._keepalive_task
await self.cancel_task(self._keepalive_task)
self._keepalive_task = None
await self._disconnect_websocket()
@@ -383,13 +379,8 @@ class ElevenLabsTTSService(WordTTSService, WebsocketService):
async def _keepalive_task_handler(self):
while True:
try:
await asyncio.sleep(10)
await self._send_text("")
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await asyncio.sleep(10)
await self._send_text("")
async def _send_text(self, text: str):
if self._websocket:

View File

@@ -104,15 +104,12 @@ class FishAudioTTSService(TTSService, WebsocketService):
async def _connect(self):
await self._connect_websocket()
self._receive_task = self.get_event_loop().create_task(
self._receive_task_handler(self.push_error)
)
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
async def _disconnect(self):
await self._disconnect_websocket()
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
await self.cancel_task(self._receive_task)
self._receive_task = None
async def _connect_websocket(self):

View File

@@ -275,7 +275,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
)
await self.send_client_event(evt)
if self._transcribe_user_audio and self._context:
asyncio.create_task(self._handle_transcribe_user_audio(audio, self._context))
self.create_task(self._handle_transcribe_user_audio(audio, self._context))
async def _handle_transcribe_user_audio(self, audio, context):
text = await self._transcribe_audio(audio, context)
@@ -391,7 +391,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
uri = f"wss://{self.base_url}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self.api_key}"
logger.info(f"Connecting to {uri}")
self._websocket = await websockets.connect(uri=uri)
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
self._receive_task = self.create_task(self._receive_task_handler())
config = events.Config.model_validate(
{
"setup": {
@@ -441,11 +441,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
await self._websocket.close()
self._websocket = None
if self._receive_task:
self._receive_task.cancel()
try:
await asyncio.wait_for(self._receive_task, timeout=1.0)
except asyncio.TimeoutError:
logger.warning("Timed out waiting for receive task to finish")
await self.cancel_task(self._receive_task, timeout=1.0)
self._receive_task = None
self._disconnecting = False
except Exception as e:
@@ -497,6 +493,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
pass
except asyncio.CancelledError:
logger.debug("websocket receive task cancelled")
raise
except Exception as e:
logger.error(f"{self} exception: {e}")
@@ -679,7 +676,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
self._bot_text_buffer = ""
if audio and self._transcribe_model_audio and self._context:
asyncio.create_task(self._handle_transcribe_model_audio(audio, self._context))
self.create_task(self._handle_transcribe_model_audio(audio, self._context))
elif text:
await self.push_frame(LLMFullResponseEndFrame())

View File

@@ -180,7 +180,7 @@ class GladiaSTTService(STTService):
await super().start(frame)
response = await self._setup_gladia()
self._websocket = await websockets.connect(response["url"])
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
self._receive_task = self.create_task(self._receive_task_handler())
async def stop(self, frame: EndFrame):
await super().stop(frame)

View File

@@ -113,16 +113,13 @@ class LmntTTSService(TTSService, WebsocketService):
async def _connect(self):
await self._connect_websocket()
self._receive_task = self.get_event_loop().create_task(
self._receive_task_handler(self.push_error)
)
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
async def _disconnect(self):
await self._disconnect_websocket()
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
await self.cancel_task(self._receive_task)
self._receive_task = None
async def _connect_websocket(self):

View File

@@ -277,7 +277,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
"OpenAI-Beta": "realtime=v1",
},
)
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
self._receive_task = self.create_task(self._receive_task_handler())
except Exception as e:
logger.error(f"{self} initialization error: {e}")
self._websocket = None
@@ -291,11 +291,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
await self._websocket.close()
self._websocket = None
if self._receive_task:
self._receive_task.cancel()
try:
await asyncio.wait_for(self._receive_task, timeout=1.0)
except asyncio.TimeoutError:
logger.warning("Timed out waiting for receive task to finish")
await self.cancel_task(self._receive_task, timeout=1.0)
self._receive_task = None
self._disconnecting = False
except Exception as e:
@@ -332,40 +328,32 @@ class OpenAIRealtimeBetaLLMService(LLMService):
#
async def _receive_task_handler(self):
try:
async for message in self._websocket:
evt = events.parse_server_event(message)
if evt.type == "session.created":
await self._handle_evt_session_created(evt)
elif evt.type == "session.updated":
await self._handle_evt_session_updated(evt)
elif evt.type == "response.audio.delta":
await self._handle_evt_audio_delta(evt)
elif evt.type == "response.audio.done":
await self._handle_evt_audio_done(evt)
elif evt.type == "conversation.item.created":
await self._handle_evt_conversation_item_created(evt)
elif evt.type == "conversation.item.input_audio_transcription.completed":
await self.handle_evt_input_audio_transcription_completed(evt)
elif evt.type == "response.done":
await self._handle_evt_response_done(evt)
elif evt.type == "input_audio_buffer.speech_started":
await self._handle_evt_speech_started(evt)
elif evt.type == "input_audio_buffer.speech_stopped":
await self._handle_evt_speech_stopped(evt)
elif evt.type == "response.audio_transcript.delta":
await self._handle_evt_audio_transcript_delta(evt)
elif evt.type == "error":
await self._handle_evt_error(evt)
# errors are fatal, so exit the receive loop
return
else:
pass
except asyncio.CancelledError:
logger.debug("websocket receive task cancelled")
except Exception as e:
logger.error(f"{self} exception: {e}")
async for message in self._websocket:
evt = events.parse_server_event(message)
if evt.type == "session.created":
await self._handle_evt_session_created(evt)
elif evt.type == "session.updated":
await self._handle_evt_session_updated(evt)
elif evt.type == "response.audio.delta":
await self._handle_evt_audio_delta(evt)
elif evt.type == "response.audio.done":
await self._handle_evt_audio_done(evt)
elif evt.type == "conversation.item.created":
await self._handle_evt_conversation_item_created(evt)
elif evt.type == "conversation.item.input_audio_transcription.completed":
await self.handle_evt_input_audio_transcription_completed(evt)
elif evt.type == "response.done":
await self._handle_evt_response_done(evt)
elif evt.type == "input_audio_buffer.speech_started":
await self._handle_evt_speech_started(evt)
elif evt.type == "input_audio_buffer.speech_stopped":
await self._handle_evt_speech_stopped(evt)
elif evt.type == "response.audio_transcript.delta":
await self._handle_evt_audio_transcript_delta(evt)
elif evt.type == "error":
await self._handle_evt_error(evt)
# errors are fatal, so exit the receive loop
return
async def _handle_evt_session_created(self, evt):
# session.created is received right after connecting. Send a message

View File

@@ -165,16 +165,13 @@ class PlayHTTTSService(TTSService, WebsocketService):
async def _connect(self):
await self._connect_websocket()
self._receive_task = self.get_event_loop().create_task(
self._receive_task_handler(self.push_error)
)
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
async def _disconnect(self):
await self._disconnect_websocket()
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
await self.cancel_task(self._receive_task)
self._receive_task = None
async def _connect_websocket(self):

View File

@@ -202,8 +202,8 @@ class ParakeetSTTService(STTService):
async def start(self, frame: StartFrame):
await super().start(frame)
self._thread_task = self.get_event_loop().create_task(self._thread_task_handler())
self._response_task = self.get_event_loop().create_task(self._response_task_handler())
self._thread_task = self.create_task(self._thread_task_handler())
self._response_task = self.create_task(self._response_task_handler())
self._response_queue = asyncio.Queue()
async def stop(self, frame: EndFrame):
@@ -215,10 +215,8 @@ class ParakeetSTTService(STTService):
await self._stop_tasks()
async def _stop_tasks(self):
self._thread_task.cancel()
await self._thread_task
self._response_task.cancel()
await self._response_task
await self.cancel_task(self._thread_task)
await self.cancel_task(self._response_task)
def _response_handler(self):
responses = self._asr_service.streaming_response_generator(
@@ -238,7 +236,7 @@ class ParakeetSTTService(STTService):
await asyncio.to_thread(self._response_handler)
except asyncio.CancelledError:
self._thread_running = False
pass
raise
async def _handle_response(self, response):
for result in response.results:
@@ -260,11 +258,8 @@ class ParakeetSTTService(STTService):
async def _response_task_handler(self):
while True:
try:
response = await self._response_queue.get()
await self._handle_response(response)
except asyncio.CancelledError:
break
response = await self._response_queue.get()
await self._handle_response(response)
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
await self._queue.put(audio)

View File

@@ -49,45 +49,33 @@ class SimliVideoService(FrameProcessor):
async def _start_connection(self):
await self._simli_client.Initialize()
# Create task to consume and process audio and video
self._audio_task = asyncio.create_task(self._consume_and_process_audio())
self._video_task = asyncio.create_task(self._consume_and_process_video())
self._audio_task = self.create_task(self._consume_and_process_audio())
self._video_task = self.create_task(self._consume_and_process_video())
async def _consume_and_process_audio(self):
try:
await self._pipecat_resampler_event.wait()
async for audio_frame in self._simli_client.getAudioStreamIterator():
resampled_frames = self._pipecat_resampler.resample(audio_frame)
for resampled_frame in resampled_frames:
await self.push_frame(
TTSAudioRawFrame(
audio=resampled_frame.to_ndarray().tobytes(),
sample_rate=self._pipecat_resampler.rate,
num_channels=1,
),
)
except Exception as e:
logger.exception(f"{self} exception: {e}")
except asyncio.CancelledError:
pass
await self._pipecat_resampler_event.wait()
async for audio_frame in self._simli_client.getAudioStreamIterator():
resampled_frames = self._pipecat_resampler.resample(audio_frame)
for resampled_frame in resampled_frames:
await self.push_frame(
TTSAudioRawFrame(
audio=resampled_frame.to_ndarray().tobytes(),
sample_rate=self._pipecat_resampler.rate,
num_channels=1,
),
)
async def _consume_and_process_video(self):
try:
await self._pipecat_resampler_event.wait()
async for video_frame in self._simli_client.getVideoStreamIterator(
targetFormat="rgb24"
):
# Process the video frame
convertedFrame: OutputImageRawFrame = OutputImageRawFrame(
image=video_frame.to_rgb().to_image().tobytes(),
size=(video_frame.width, video_frame.height),
format="RGB",
)
convertedFrame.pts = video_frame.pts
await self.push_frame(convertedFrame)
except Exception as e:
logger.exception(f"{self} exception: {e}")
except asyncio.CancelledError:
pass
await self._pipecat_resampler_event.wait()
async for video_frame in self._simli_client.getVideoStreamIterator(targetFormat="rgb24"):
# Process the video frame
convertedFrame: OutputImageRawFrame = OutputImageRawFrame(
image=video_frame.to_rgb().to_image().tobytes(),
size=(video_frame.width, video_frame.height),
format="RGB",
)
convertedFrame.pts = video_frame.pts
await self.push_frame(convertedFrame)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -128,8 +116,6 @@ class SimliVideoService(FrameProcessor):
async def _stop(self):
await self._simli_client.stop()
if self._audio_task:
self._audio_task.cancel()
await self._audio_task
await self.cancel_task(self._audio_task)
if self._video_task:
self._video_task.cancel()
await self._video_task
await self.cancel_task(self._video_task)

View File

@@ -85,10 +85,6 @@ class WebsocketService(ABC):
await self._receive_messages()
logger.debug(f"{self} connection established successfully")
retry_count = 0 # Reset counter on successful message receive
except asyncio.CancelledError:
break
except Exception as e:
retry_count += 1
if retry_count >= MAX_RETRIES:

View File

@@ -50,13 +50,12 @@ class BaseInputTransport(FrameProcessor):
# Create audio input queue and task if needed.
if self._params.audio_in_enabled or self._params.vad_enabled:
self._audio_in_queue = asyncio.Queue()
self._audio_task = self.get_event_loop().create_task(self._audio_task_handler())
self._audio_task = self.create_task(self._audio_task_handler())
async def stop(self, frame: EndFrame):
# Cancel and wait for the audio input task to finish.
if self._audio_task and (self._params.audio_in_enabled or self._params.vad_enabled):
self._audio_task.cancel()
await self._audio_task
await self.cancel_task(self._audio_task)
self._audio_task = None
# Stop audio filter.
if self._params.audio_in_filter:
@@ -65,8 +64,7 @@ class BaseInputTransport(FrameProcessor):
async def cancel(self, frame: CancelFrame):
# Cancel and wait for the audio input task to finish.
if self._audio_task and (self._params.audio_in_enabled or self._params.vad_enabled):
self._audio_task.cancel()
await self._audio_task
await self.cancel_task(self._audio_task)
self._audio_task = None
def vad_analyzer(self) -> VADAnalyzer | None:
@@ -173,27 +171,22 @@ class BaseInputTransport(FrameProcessor):
async def _audio_task_handler(self):
vad_state: VADState = VADState.QUIET
while True:
try:
frame: InputAudioRawFrame = await self._audio_in_queue.get()
frame: InputAudioRawFrame = await self._audio_in_queue.get()
audio_passthrough = True
audio_passthrough = True
# If an audio filter is available, run it before VAD.
if self._params.audio_in_filter:
frame.audio = await self._params.audio_in_filter.filter(frame.audio)
# If an audio filter is available, run it before VAD.
if self._params.audio_in_filter:
frame.audio = await self._params.audio_in_filter.filter(frame.audio)
# Check VAD and push event if necessary. We just care about
# changes from QUIET to SPEAKING and vice versa.
if self._params.vad_enabled:
vad_state = await self._handle_vad(frame, vad_state)
audio_passthrough = self._params.vad_audio_passthrough
# Check VAD and push event if necessary. We just care about
# changes from QUIET to SPEAKING and vice versa.
if self._params.vad_enabled:
vad_state = await self._handle_vad(frame, vad_state)
audio_passthrough = self._params.vad_audio_passthrough
# Push audio downstream if passthrough.
if audio_passthrough:
await self.push_frame(frame)
# Push audio downstream if passthrough.
if audio_passthrough:
await self.push_frame(frame)
self._audio_in_queue.task_done()
except asyncio.CancelledError:
break
except Exception as e:
logger.exception(f"{self} error reading audio frames: {e}")
self._audio_in_queue.task_done()

View File

@@ -217,22 +217,19 @@ class BaseOutputTransport(FrameProcessor):
#
def _create_sink_tasks(self):
loop = self.get_event_loop()
self._sink_queue = asyncio.Queue()
self._sink_task = loop.create_task(self._sink_task_handler())
self._sink_clock_queue = asyncio.PriorityQueue()
self._sink_clock_task = loop.create_task(self._sink_clock_task_handler())
self._sink_task = self.create_task(self._sink_task_handler())
self._sink_clock_task = self.create_task(self._sink_clock_task_handler())
async def _cancel_sink_tasks(self):
# Stop sink tasks.
if self._sink_task:
self._sink_task.cancel()
await self._sink_task
await self.cancel_task(self._sink_task)
self._sink_task = None
# Stop sink clock tasks.
if self._sink_clock_task:
self._sink_clock_task.cancel()
await self._sink_clock_task
await self.cancel_task(self._sink_clock_task)
self._sink_clock_task = None
async def _sink_frame_handler(self, frame: Frame):
@@ -269,7 +266,7 @@ class BaseOutputTransport(FrameProcessor):
self._sink_clock_queue.task_done()
except asyncio.CancelledError:
break
raise
except Exception as e:
logger.exception(f"{self} error processing sink clock queue: {e}")
@@ -317,49 +314,42 @@ class BaseOutputTransport(FrameProcessor):
return without_mixer(vad_stop_secs)
async def _sink_task_handler(self):
try:
async for frame in self._next_frame():
# Notify the bot started speaking upstream if necessary and that
# it's actually speaking.
if isinstance(frame, TTSAudioRawFrame):
await self._bot_started_speaking()
await self.push_frame(BotSpeakingFrame())
await self.push_frame(BotSpeakingFrame(), FrameDirection.UPSTREAM)
async for frame in self._next_frame():
# Notify the bot started speaking upstream if necessary and that
# it's actually speaking.
if isinstance(frame, TTSAudioRawFrame):
await self._bot_started_speaking()
await self.push_frame(BotSpeakingFrame())
await self.push_frame(BotSpeakingFrame(), FrameDirection.UPSTREAM)
# No need to push EndFrame, it's pushed from process_frame().
if isinstance(frame, EndFrame):
break
# No need to push EndFrame, it's pushed from process_frame().
if isinstance(frame, EndFrame):
break
# Handle frame.
await self._sink_frame_handler(frame)
# Handle frame.
await self._sink_frame_handler(frame)
# Also, push frame downstream in case anyone else needs it.
await self.push_frame(frame)
# Also, push frame downstream in case anyone else needs it.
await self.push_frame(frame)
# Send audio.
if isinstance(frame, OutputAudioRawFrame):
await self.write_raw_audio_frames(frame.audio)
except asyncio.CancelledError:
pass
except Exception as e:
logger.exception(f"{self} error writing to microphone: {e}")
# Send audio.
if isinstance(frame, OutputAudioRawFrame):
await self.write_raw_audio_frames(frame.audio)
#
# Camera task
#
def _create_camera_task(self):
loop = self.get_event_loop()
# Create camera output queue and task if needed.
if self._params.camera_out_enabled:
self._camera_out_queue = asyncio.Queue()
self._camera_out_task = loop.create_task(self._camera_out_task_handler())
self._camera_out_task = self.create_task(self._camera_out_task_handler())
async def _cancel_camera_task(self):
# Stop camera output task.
if self._camera_out_task and self._params.camera_out_enabled:
self._camera_out_task.cancel()
await self._camera_out_task
await self.cancel_task(self._camera_out_task)
self._camera_out_task = None
async def _draw_image(self, frame: OutputImageRawFrame):
@@ -387,19 +377,14 @@ class BaseOutputTransport(FrameProcessor):
self._camera_out_frame_duration = 1 / self._params.camera_out_framerate
self._camera_out_frame_reset = self._camera_out_frame_duration * 5
while True:
try:
if self._params.camera_out_is_live:
await self._camera_out_is_live_handler()
elif self._camera_images:
image = next(self._camera_images)
await self._draw_image(image)
await asyncio.sleep(self._camera_out_frame_duration)
else:
await asyncio.sleep(self._camera_out_frame_duration)
except asyncio.CancelledError:
break
except Exception as e:
logger.exception(f"{self} error writing to camera: {e}")
if self._params.camera_out_is_live:
await self._camera_out_is_live_handler()
elif self._camera_images:
image = next(self._camera_images)
await self._draw_image(image)
await asyncio.sleep(self._camera_out_frame_duration)
else:
await asyncio.sleep(self._camera_out_frame_duration)
async def _camera_out_is_live_handler(self):
image = await self._camera_out_queue.get()

View File

@@ -68,11 +68,9 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
async def start(self, frame: StartFrame):
await super().start(frame)
if self._params.session_timeout:
self._monitor_websocket_task = self.get_event_loop().create_task(
self._monitor_websocket()
)
self._monitor_websocket_task = self.create_task(self._monitor_websocket())
await self._callbacks.on_client_connected(self._websocket)
self._receive_task = self.get_event_loop().create_task(self._receive_messages())
self._receive_task = self.create_task(self._receive_messages())
def _iter_data(self) -> typing.AsyncIterator[bytes | str]:
if self._params.serializer.type == FrameSerializerType.BINARY:
@@ -96,11 +94,8 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
async def _monitor_websocket(self):
"""Wait for self._params.session_timeout seconds, if the websocket is still open, trigger timeout event."""
try:
await asyncio.sleep(self._params.session_timeout)
await self._callbacks.on_session_timeout(self._websocket)
except asyncio.CancelledError:
logger.info(f"Monitoring task cancelled for: {self._websocket}")
await asyncio.sleep(self._params.session_timeout)
await self._callbacks.on_session_timeout(self._websocket)
class FastAPIWebsocketOutputTransport(BaseOutputTransport):

View File

@@ -71,7 +71,7 @@ class WebsocketServerInputTransport(BaseInputTransport):
async def start(self, frame: StartFrame):
await super().start(frame)
self._server_task = self.get_event_loop().create_task(self._server_task_handler())
self._server_task = self.create_task(self._server_task_handler())
async def stop(self, frame: EndFrame):
await super().stop(frame)
@@ -131,6 +131,7 @@ class WebsocketServerInputTransport(BaseInputTransport):
await self._callbacks.on_session_timeout(websocket)
except asyncio.CancelledError:
logger.info(f"Monitoring task cancelled for: {websocket.remote_address}")
raise
class WebsocketServerOutputTransport(BaseOutputTransport):

View File

@@ -46,6 +46,7 @@ from pipecat.transcriptions.language import Language
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.utils.utils import cancel_task, create_task
try:
from daily import CallClient, Daily, EventHandler
@@ -218,7 +219,9 @@ class DailyTransportClient(EventHandler):
# future) we will deadlock because completions use event handlers (which
# are holding the GIL).
self._callback_queue = asyncio.Queue()
self._callback_task = self._loop.create_task(self._callback_task_handler())
self._callback_task = create_task(
self._loop, self._callback_task_handler(), "DailyTransportClient::callback_task"
)
self._camera: VirtualCameraDevice | None = None
if self._params.camera_out_enabled:
@@ -469,8 +472,7 @@ class DailyTransportClient(EventHandler):
return await asyncio.wait_for(future, timeout=10)
async def cleanup(self):
self._callback_task.cancel()
await self._callback_task
await cancel_task(self._callback_task)
# Make sure we don't block the event loop in case `client.release()`
# takes extra time.
await self._loop.run_in_executor(self._executor, self._cleanup)
@@ -687,11 +689,8 @@ class DailyTransportClient(EventHandler):
async def _callback_task_handler(self):
while True:
try:
(callback, *args) = await self._callback_queue.get()
await callback(*args)
except asyncio.CancelledError:
break
(callback, *args) = await self._callback_queue.get()
await callback(*args)
class DailyInputTransport(BaseInputTransport):
@@ -721,7 +720,7 @@ class DailyInputTransport(BaseInputTransport):
# Create audio task. It reads audio frames from Daily and push them
# internally for VAD processing.
if self._params.audio_in_enabled or self._params.vad_enabled:
self._audio_in_task = self.get_event_loop().create_task(self._audio_in_task_handler())
self._audio_in_task = self.create_task(self._audio_in_task_handler())
async def stop(self, frame: EndFrame):
# Parent stop.
@@ -730,8 +729,7 @@ class DailyInputTransport(BaseInputTransport):
await self._client.leave()
# Stop audio thread.
if self._audio_in_task and (self._params.audio_in_enabled or self._params.vad_enabled):
self._audio_in_task.cancel()
await self._audio_in_task
await self.cancel_task(self._audio_in_task)
self._audio_in_task = None
async def cancel(self, frame: CancelFrame):
@@ -741,8 +739,7 @@ class DailyInputTransport(BaseInputTransport):
await self._client.leave()
# Stop audio thread.
if self._audio_in_task and (self._params.audio_in_enabled or self._params.vad_enabled):
self._audio_in_task.cancel()
await self._audio_in_task
await self.cancel_task(self._audio_in_task)
self._audio_in_task = None
async def cleanup(self):
@@ -779,12 +776,9 @@ class DailyInputTransport(BaseInputTransport):
async def _audio_in_task_handler(self):
while True:
try:
frame = await self._client.read_next_audio_frame()
if frame:
await self.push_audio_frame(frame)
except asyncio.CancelledError:
break
frame = await self._client.read_next_audio_frame()
if frame:
await self.push_audio_frame(frame)
#
# Camera in

View File

@@ -28,6 +28,7 @@ from pipecat.processors.frame_processor import FrameDirection
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.utils.utils import create_task
try:
from livekit import rtc
@@ -215,10 +216,18 @@ class LiveKitTransportClient:
# Wrapper methods for event handlers
def _on_participant_connected_wrapper(self, participant: rtc.RemoteParticipant):
asyncio.create_task(self._async_on_participant_connected(participant))
create_task(
self._loop,
self._async_on_participant_connected(participant),
"LiveKitTransportClient::_async_on_participant_connected",
)
def _on_participant_disconnected_wrapper(self, participant: rtc.RemoteParticipant):
asyncio.create_task(self._async_on_participant_disconnected(participant))
create_task(
self._loop,
self._async_on_participant_disconnected(participant),
"LiveKitTransportClient::_async_on_participant_disconnected",
)
def _on_track_subscribed_wrapper(
self,
@@ -226,7 +235,11 @@ class LiveKitTransportClient:
publication: rtc.RemoteTrackPublication,
participant: rtc.RemoteParticipant,
):
asyncio.create_task(self._async_on_track_subscribed(track, publication, participant))
create_task(
self._loop,
self._async_on_track_subscribed(track, publication, participant),
"LiveKitTransportClient::_async_on_track_subscribed",
)
def _on_track_unsubscribed_wrapper(
self,
@@ -234,16 +247,30 @@ class LiveKitTransportClient:
publication: rtc.RemoteTrackPublication,
participant: rtc.RemoteParticipant,
):
asyncio.create_task(self._async_on_track_unsubscribed(track, publication, participant))
create_task(
self._loop,
self._async_on_track_unsubscribed(track, publication, participant),
"LiveKitTransportClient::_async_on_track_unsubscribed",
)
def _on_data_received_wrapper(self, data: rtc.DataPacket):
asyncio.create_task(self._async_on_data_received(data))
create_task(
self._loop,
self._async_on_data_received(data),
"LiveKitTransportClient::_async_on_data_received",
)
def _on_connected_wrapper(self):
asyncio.create_task(self._async_on_connected())
create_task(
self._loop, self._async_on_connected(), "LiveKitTransportClient::_async_on_connected"
)
def _on_disconnected_wrapper(self):
asyncio.create_task(self._async_on_disconnected())
create_task(
self._loop,
self._async_on_disconnected(),
"LiveKitTransportClient::_async_on_disconnected",
)
# Async methods for event handling
async def _async_on_participant_connected(self, participant: rtc.RemoteParticipant):
@@ -269,7 +296,11 @@ class LiveKitTransportClient:
logger.info(f"Audio track subscribed: {track.sid} from participant {participant.sid}")
self._audio_tracks[participant.sid] = track
audio_stream = rtc.AudioStream(track)
asyncio.create_task(self._process_audio_stream(audio_stream, participant.sid))
create_task(
self._loop,
self._process_audio_stream(audio_stream, participant.sid),
"LiveKitTransportClient::_process_audio_stream",
)
async def _async_on_track_unsubscribed(
self,
@@ -319,23 +350,21 @@ class LiveKitInputTransport(BaseInputTransport):
await super().start(frame)
await self._client.connect()
if self._params.audio_in_enabled or self._params.vad_enabled:
self._audio_in_task = asyncio.create_task(self._audio_in_task_handler())
self._audio_in_task = self.create_task(self._audio_in_task_handler())
logger.info("LiveKitInputTransport started")
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._client.disconnect()
if self._audio_in_task:
self._audio_in_task.cancel()
await self._audio_in_task
await self.cancel_task(self._audio_in_task)
logger.info("LiveKitInputTransport stopped")
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._client.disconnect()
if self._audio_in_task and (self._params.audio_in_enabled or self._params.vad_enabled):
self._audio_in_task.cancel()
await self._audio_in_task
await self.cancel_task(self._audio_in_task)
def vad_analyzer(self) -> VADAnalyzer | None:
return self._vad_analyzer
@@ -347,22 +376,16 @@ class LiveKitInputTransport(BaseInputTransport):
async def _audio_in_task_handler(self):
logger.info("Audio input task started")
while True:
try:
audio_data = await self._client.get_next_audio_frame()
if audio_data:
audio_frame_event, participant_id = audio_data
pipecat_audio_frame = self._convert_livekit_audio_to_pipecat(audio_frame_event)
input_audio_frame = InputAudioRawFrame(
audio=pipecat_audio_frame.audio,
sample_rate=pipecat_audio_frame.sample_rate,
num_channels=pipecat_audio_frame.num_channels,
)
await self.push_audio_frame(input_audio_frame)
except asyncio.CancelledError:
logger.info("Audio input task cancelled")
break
except Exception as e:
logger.error(f"Error in audio input task: {e}")
audio_data = await self._client.get_next_audio_frame()
if audio_data:
audio_frame_event, participant_id = audio_data
pipecat_audio_frame = self._convert_livekit_audio_to_pipecat(audio_frame_event)
input_audio_frame = InputAudioRawFrame(
audio=pipecat_audio_frame.audio,
sample_rate=pipecat_audio_frame.sample_rate,
num_channels=pipecat_audio_frame.num_channels,
)
await self.push_audio_frame(input_audio_frame)
def _convert_livekit_audio_to_pipecat(
self, audio_frame_event: rtc.AudioFrameEvent

View File

@@ -3,8 +3,13 @@
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import collections
import itertools
from typing import Coroutine, Optional
from loguru import logger
_COUNTS = collections.defaultdict(itertools.count)
_ID = itertools.count()
@@ -35,3 +40,37 @@ def obj_count(obj) -> int:
0
"""
return next(_COUNTS[obj.__class__.__name__])
def create_task(loop: asyncio.AbstractEventLoop, coroutine: Coroutine, name: str) -> asyncio.Task:
async def run_coroutine():
try:
await coroutine
except asyncio.CancelledError:
logger.trace(f"{name}: cancelling task")
# Re-raise the exception to ensure the task is cancelled.
raise
except Exception as e:
logger.exception(f"{name}: unexpected exception: {e}")
task = loop.create_task(run_coroutine())
task.set_name(name)
logger.trace(f"{name}: task created")
return task
async def cancel_task(task: asyncio.Task, timeout: Optional[float] = None):
name = task.get_name()
task.cancel()
try:
if timeout:
await asyncio.wait_for(task, timeout=timeout)
else:
await task
except asyncio.TimeoutError:
logger.warning(f"{name}: timed out waiting for task to finish")
except asyncio.CancelledError:
# Here are sure the task is cancelled properly.
logger.trace(f"{name}: task cancelled")
except Exception as e:
logger.exception(f"{name}: unexpected exception while cancelling task: {e}")