From 5ae592f38e37e275dfaef46d5701432aff1df1fb Mon Sep 17 00:00:00 2001 From: Ashot Date: Wed, 7 Jan 2026 15:55:35 +0400 Subject: [PATCH] Improve Async TTS interruption handling by using AudioContextTTSService class and add changelog fragments --- CHANGELOG.md | 9 --- changelog/3287.changed.md | 1 + changelog/3287.fixed.md | 1 + src/pipecat/services/asyncai/tts.py | 120 +--------------------------- 4 files changed, 5 insertions(+), 126 deletions(-) create mode 100644 changelog/3287.changed.md create mode 100644 changelog/3287.fixed.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 63ef32ed1..3d583b4e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,15 +6,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] - -### Changed - -- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management. - -### Fixed - -- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`. ## [0.0.99] - 2026-01-13 diff --git a/changelog/3287.changed.md b/changelog/3287.changed.md new file mode 100644 index 000000000..f0df82966 --- /dev/null +++ b/changelog/3287.changed.md @@ -0,0 +1 @@ +- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management. \ No newline at end of file diff --git a/changelog/3287.fixed.md b/changelog/3287.fixed.md new file mode 100644 index 000000000..30ce0b13b --- /dev/null +++ b/changelog/3287.fixed.md @@ -0,0 +1 @@ +- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`. \ No newline at end of file diff --git a/src/pipecat/services/asyncai/tts.py b/src/pipecat/services/asyncai/tts.py index c49b95153..05ba18654 100644 --- a/src/pipecat/services/asyncai/tts.py +++ b/src/pipecat/services/asyncai/tts.py @@ -28,7 +28,7 @@ from pipecat.frames.frames import ( TTSStoppedFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.tts_service import WebsocketTTSService, TTSService +from pipecat.services.tts_service import AudioContextTTSService, WebsocketTTSService, TTSService from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.tracing.service_decorators import traced_tts @@ -73,7 +73,7 @@ def language_to_async_language(language: Language) -> Optional[str]: return resolve_language(language, LANGUAGE_MAP, use_base_code=True) -class AsyncAITTSService(WebsocketTTSService): +class AsyncAITTSService(AudioContextTTSService, WebsocketTTSService): """Async TTS service with WebSocket streaming. Provides text-to-speech using Async's streaming WebSocket API. @@ -154,55 +154,6 @@ class AsyncAITTSService(WebsocketTTSService): self._keepalive_task = None self._started = False - async def create_audio_context(self, context_id: str): - """Create a new audio context for grouping related audio. - - Args: - context_id: Unique identifier for the audio context. - """ - await self._contexts_queue.put(context_id) - self._contexts[context_id] = asyncio.Queue() - logger.trace(f"{self} created audio context {context_id}") - - async def append_to_audio_context(self, context_id: str, frame: TTSAudioRawFrame): - """Append audio to an existing context. - - Args: - context_id: The context to append audio to. - frame: The audio frame to append. - """ - if self.audio_context_available(context_id): - logger.trace(f"{self} appending audio {frame} to audio context {context_id}") - await self._contexts[context_id].put(frame) - else: - logger.warning(f"{self} unable to append audio to context {context_id}") - - async def remove_audio_context(self, context_id: str): - """Remove an existing audio context. - - Args: - context_id: The context to remove. - """ - if self.audio_context_available(context_id): - # We just mark the audio context for deletion by appending - # None. Once we reach None while handling audio we know we can - # safely remove the context. - logger.trace(f"{self} marking audio context {context_id} for deletion") - await self._contexts[context_id].put(None) - else: - logger.warning(f"{self} unable to remove context {context_id}") - - def audio_context_available(self, context_id: str) -> bool: - """Check whether the given audio context is registered. - - Args: - context_id: The context ID to check. - - Returns: - True if the context exists and is available. - """ - return context_id in self._contexts - async def start(self, frame: StartFrame): """Start the Async TTS service. @@ -210,7 +161,6 @@ class AsyncAITTSService(WebsocketTTSService): frame: The start frame containing initialization parameters. """ await super().start(frame) - self._create_audio_context_task() self._settings["output_format"]["sample_rate"] = self.sample_rate await self._connect() @@ -221,12 +171,6 @@ class AsyncAITTSService(WebsocketTTSService): frame: The end frame. """ await super().stop(frame) - if self._audio_context_task: - # Indicate no more audio contexts are available. this will end the - # task cleanly after all contexts have been processed. - await self._contexts_queue.put(None) - await self._audio_context_task - self._audio_context_task = None await self._disconnect() async def cancel(self, frame: CancelFrame): @@ -236,65 +180,7 @@ class AsyncAITTSService(WebsocketTTSService): frame: The cancel frame. """ await super().cancel(frame) - await self._stop_audio_context_task() - await self._disconnect() - - async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection): - await super()._handle_interruption(frame, direction) - await self._stop_audio_context_task() - self._create_audio_context_task() - - def _create_audio_context_task(self): - if not self._audio_context_task: - self._contexts_queue = asyncio.Queue() - self._contexts: Dict[str, asyncio.Queue] = {} - self._audio_context_task = self.create_task(self._audio_context_task_handler()) - - async def _stop_audio_context_task(self): - if self._audio_context_task: - await self.cancel_task(self._audio_context_task) - self._audio_context_task = None - - async def _audio_context_task_handler(self): - """In this task we process audio contexts in order.""" - running = True - while running: - context_id = await self._contexts_queue.get() - - if context_id: - # Process the audio context until the context doesn't have more - # audio available (i.e. we find None). - await self._handle_audio_context(context_id) - - # We just finished processing the context, so we can safely remove it. - del self._contexts[context_id] - - # Append some silence between sentences. - silence = b"\x00" * self.sample_rate - frame = TTSAudioRawFrame( - audio=silence, sample_rate=self.sample_rate, num_channels=1 - ) - await self.push_frame(frame) - else: - running = False - - self._contexts_queue.task_done() - - async def _handle_audio_context(self, context_id: str): - # If we don't receive any audio during this time, we consider the context finished. - AUDIO_CONTEXT_TIMEOUT = 3.0 - queue = self._contexts[context_id] - running = True - while running: - try: - frame = await asyncio.wait_for(queue.get(), timeout=AUDIO_CONTEXT_TIMEOUT) - if frame: - await self.push_frame(frame) - running = frame is not None - except asyncio.TimeoutError: - # We didn't get audio, so let's consider this context finished. - logger.trace(f"{self} time out on audio context {context_id}") - break + await self._disconnect() def can_generate_metrics(self) -> bool: """Check if this service can generate processing metrics.