From 9cdbc56be3b671bda192a4f5529e95eeafb484c4 Mon Sep 17 00:00:00 2001 From: Ashot Date: Tue, 23 Dec 2025 16:35:45 +0400 Subject: [PATCH] Fix TTFB metric and add multi-context WebSocket support for Async TTS --- CHANGELOG.md | 9 + src/pipecat/services/asyncai/tts.py | 323 +++++++++++++++++++++------- 2 files changed, 259 insertions(+), 73 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d583b4e1..63ef32ed1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Changed + +- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management. + +### Fixed + +- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`. ## [0.0.99] - 2026-01-13 diff --git a/src/pipecat/services/asyncai/tts.py b/src/pipecat/services/asyncai/tts.py index 303369205..c49b95153 100644 --- a/src/pipecat/services/asyncai/tts.py +++ b/src/pipecat/services/asyncai/tts.py @@ -9,8 +9,9 @@ import asyncio import base64 import json -from typing import AsyncGenerator, Optional +from typing import AsyncGenerator, Optional, Dict +import uuid import aiohttp from loguru import logger from pydantic import BaseModel @@ -27,7 +28,7 @@ from pipecat.frames.frames import ( TTSStoppedFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.tts_service import InterruptibleTTSService, TTSService +from pipecat.services.tts_service import WebsocketTTSService, TTSService from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.tracing.service_decorators import traced_tts @@ -72,7 +73,7 @@ def language_to_async_language(language: Language) -> Optional[str]: return resolve_language(language, LANGUAGE_MAP, use_base_code=True) -class AsyncAITTSService(InterruptibleTTSService): +class AsyncAITTSService(WebsocketTTSService): """Async TTS service with WebSocket streaming. Provides text-to-speech using Async's streaming WebSocket API. @@ -126,6 +127,10 @@ class AsyncAITTSService(InterruptibleTTSService): **kwargs, ) + self._contexts: Dict[str, asyncio.Queue] = {} + self._audio_context_task = None + self._context_id = None + params = params or AsyncAITTSService.InputParams() self._api_key = api_key @@ -148,6 +153,148 @@ class AsyncAITTSService(InterruptibleTTSService): self._receive_task = None self._keepalive_task = None self._started = False + + async def create_audio_context(self, context_id: str): + """Create a new audio context for grouping related audio. + + Args: + context_id: Unique identifier for the audio context. + """ + await self._contexts_queue.put(context_id) + self._contexts[context_id] = asyncio.Queue() + logger.trace(f"{self} created audio context {context_id}") + + async def append_to_audio_context(self, context_id: str, frame: TTSAudioRawFrame): + """Append audio to an existing context. + + Args: + context_id: The context to append audio to. + frame: The audio frame to append. + """ + if self.audio_context_available(context_id): + logger.trace(f"{self} appending audio {frame} to audio context {context_id}") + await self._contexts[context_id].put(frame) + else: + logger.warning(f"{self} unable to append audio to context {context_id}") + + async def remove_audio_context(self, context_id: str): + """Remove an existing audio context. + + Args: + context_id: The context to remove. + """ + if self.audio_context_available(context_id): + # We just mark the audio context for deletion by appending + # None. Once we reach None while handling audio we know we can + # safely remove the context. + logger.trace(f"{self} marking audio context {context_id} for deletion") + await self._contexts[context_id].put(None) + else: + logger.warning(f"{self} unable to remove context {context_id}") + + def audio_context_available(self, context_id: str) -> bool: + """Check whether the given audio context is registered. + + Args: + context_id: The context ID to check. + + Returns: + True if the context exists and is available. + """ + return context_id in self._contexts + + async def start(self, frame: StartFrame): + """Start the Async TTS service. + + Args: + frame: The start frame containing initialization parameters. + """ + await super().start(frame) + self._create_audio_context_task() + self._settings["output_format"]["sample_rate"] = self.sample_rate + await self._connect() + + async def stop(self, frame: EndFrame): + """Stop the Async TTS service. + + Args: + frame: The end frame. + """ + await super().stop(frame) + if self._audio_context_task: + # Indicate no more audio contexts are available. this will end the + # task cleanly after all contexts have been processed. + await self._contexts_queue.put(None) + await self._audio_context_task + self._audio_context_task = None + await self._disconnect() + + async def cancel(self, frame: CancelFrame): + """Cancel the Async TTS service. + + Args: + frame: The cancel frame. + """ + await super().cancel(frame) + await self._stop_audio_context_task() + await self._disconnect() + + async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection): + await super()._handle_interruption(frame, direction) + await self._stop_audio_context_task() + self._create_audio_context_task() + + def _create_audio_context_task(self): + if not self._audio_context_task: + self._contexts_queue = asyncio.Queue() + self._contexts: Dict[str, asyncio.Queue] = {} + self._audio_context_task = self.create_task(self._audio_context_task_handler()) + + async def _stop_audio_context_task(self): + if self._audio_context_task: + await self.cancel_task(self._audio_context_task) + self._audio_context_task = None + + async def _audio_context_task_handler(self): + """In this task we process audio contexts in order.""" + running = True + while running: + context_id = await self._contexts_queue.get() + + if context_id: + # Process the audio context until the context doesn't have more + # audio available (i.e. we find None). + await self._handle_audio_context(context_id) + + # We just finished processing the context, so we can safely remove it. + del self._contexts[context_id] + + # Append some silence between sentences. + silence = b"\x00" * self.sample_rate + frame = TTSAudioRawFrame( + audio=silence, sample_rate=self.sample_rate, num_channels=1 + ) + await self.push_frame(frame) + else: + running = False + + self._contexts_queue.task_done() + + async def _handle_audio_context(self, context_id: str): + # If we don't receive any audio during this time, we consider the context finished. + AUDIO_CONTEXT_TIMEOUT = 3.0 + queue = self._contexts[context_id] + running = True + while running: + try: + frame = await asyncio.wait_for(queue.get(), timeout=AUDIO_CONTEXT_TIMEOUT) + if frame: + await self.push_frame(frame) + running = frame is not None + except asyncio.TimeoutError: + # We didn't get audio, so let's consider this context finished. + logger.trace(f"{self} time out on audio context {context_id}") + break def can_generate_metrics(self) -> bool: """Check if this service can generate processing metrics. @@ -168,38 +315,10 @@ class AsyncAITTSService(InterruptibleTTSService): """ return language_to_async_language(language) - def _build_msg(self, text: str = "", force: bool = False) -> str: - msg = {"transcript": text, "force": force} + def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str: + msg = {"transcript": text, "context_id": context_id, "force": force} return json.dumps(msg) - async def start(self, frame: StartFrame): - """Start the Async TTS service. - - Args: - frame: The start frame containing initialization parameters. - """ - await super().start(frame) - self._settings["output_format"]["sample_rate"] = self.sample_rate - await self._connect() - - async def stop(self, frame: EndFrame): - """Stop the Async TTS service. - - Args: - frame: The end frame. - """ - await super().stop(frame) - await self._disconnect() - - async def cancel(self, frame: CancelFrame): - """Cancel the Async TTS service. - - Args: - frame: The cancel frame. - """ - await super().cancel(frame) - await self._disconnect() - async def _connect(self): await super()._connect() @@ -253,11 +372,16 @@ class AsyncAITTSService(InterruptibleTTSService): if self._websocket: logger.debug("Disconnecting from Async") + # Close all contexts and the socket + if self._context_id: + await self._websocket.send(json.dumps({"terminate": True})) await self._websocket.close() + logger.debug("Disconnected from Async") except Exception as e: - await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) + logger.error(f"{self} error closing websocket: {e}") finally: self._websocket = None + self._context_id = None self._started = False await self._call_event_handler("on_disconnected") @@ -268,10 +392,10 @@ class AsyncAITTSService(InterruptibleTTSService): async def flush_audio(self): """Flush any pending audio.""" - if not self._websocket: + if not self._context_id or not self._websocket: return logger.trace(f"{self}: flushing audio") - msg = self._build_msg(text=" ", force=True) + msg = self._build_msg(text=" ", context_id=self._context_id, force=True) await self._websocket.send(msg) async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): @@ -291,35 +415,70 @@ class AsyncAITTSService(InterruptibleTTSService): if not msg: continue - elif msg.get("audio"): + received_ctx_id = msg.get("context_id") + # Handle final messages first, regardless of context availability + # At the moment, this message is received AFTER the close_context message is + # sent, so it doesn't serve any functional purpose. For now, we'll just log it. + if msg.get("final") is True: + logger.trace(f"Received final message for context {received_ctx_id}") + continue + + # Check if this message belongs to the current context. + if not self.audio_context_available(received_ctx_id): + if self._context_id == received_ctx_id: + logger.debug( + f"Received a delayed message, recreating the context: {self._context_id}" + ) + await self.create_audio_context(self._context_id) + else: + # This can happen if a message is received _after_ we have closed a context + # due to user interruption but _before_ the `isFinal` message for the context + # is received. + logger.debug(f"Ignoring message from unavailable context: {received_ctx_id}") + continue + + if msg.get("audio"): await self.stop_ttfb_metrics() - frame = TTSAudioRawFrame( - audio=base64.b64decode(msg["audio"]), - sample_rate=self.sample_rate, - num_channels=1, - ) - await self.push_frame(frame) - elif msg.get("error_code"): - await self.push_frame(TTSStoppedFrame()) - await self.stop_all_metrics() - await self.push_error(error_msg=f"Error: {msg['message']}") - else: - await self.push_error(error_msg=f"Unknown message type: {msg}") + audio = base64.b64decode(msg["audio"]) + frame = TTSAudioRawFrame(audio, self.sample_rate, 1) + await self.append_to_audio_context(received_ctx_id, frame) async def _keepalive_task_handler(self): """Send periodic keepalive messages to maintain WebSocket connection.""" - KEEPALIVE_SLEEP = 3 + KEEPALIVE_SLEEP = 10 while True: await asyncio.sleep(KEEPALIVE_SLEEP) try: if self._websocket and self._websocket.state is State.OPEN: - keepalive_message = {"transcript": " "} - logger.trace("Sending keepalive message") + if self._context_id: + keepalive_message = {"transcript": " ", "context_id": self._context_id,} + logger.trace("Sending keepalive message") + else: + # It's possible to have a user interruption which clears the context + # without generating a new TTS response. In this case, we'll just send + # an empty message to keep the connection alive. + keepalive_message = {"transcript": " "} + logger.trace("Sending keepalive without context") await self._websocket.send(json.dumps(keepalive_message)) except websockets.ConnectionClosed as e: logger.warning(f"{self} keepalive error: {e}") break + async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection): + """Handle interruption by closing the current context.""" + await super()._handle_interruption(frame, direction) + + # Close the current context when interrupted without closing the websocket + if self._context_id and self._websocket: + try: + await self._websocket.send( + json.dumps({"context_id": self._context_id, "close_context": True, "transcript": ""}) + ) + except Exception as e: + logger.error(f"Error closing context on interruption: {e}") + self._context_id = None + self._started = False + @traced_tts async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: """Generate speech from text using Async API websocket endpoint. @@ -336,26 +495,35 @@ class AsyncAITTSService(InterruptibleTTSService): if not self._websocket or self._websocket.state is State.CLOSED: await self._connect() - if not self._started: - await self.start_ttfb_metrics() - yield TTSStartedFrame() - self._started = True + try: + if not self._started: + await self.start_ttfb_metrics() + yield TTSStartedFrame() + self._started = True - msg = self._build_msg(text=text, force=True) + if not self._context_id: + self._context_id = str(uuid.uuid4()) + if not self.audio_context_available(self._context_id): + await self.create_audio_context(self._context_id) - try: - await self._get_websocket().send(msg) - await self.start_tts_usage_metrics(text) + msg = self._build_msg(text=" ", context_id=self._context_id) + await self._get_websocket().send(msg) + msg = self._build_msg(text=text, force=True, context_id=self._context_id) + await self._get_websocket().send(msg) + await self.start_tts_usage_metrics(text) + else: + if self._websocket and self._context_id: + msg = self._build_msg(text=text, force=True, context_id=self._context_id) + await self._get_websocket().send(msg) + except Exception as e: - yield ErrorFrame(error=f"Unknown error occurred: {e}") + logger.error(f"{self} error sending message: {e}") yield TTSStoppedFrame() - await self._disconnect() - await self._connect() + self._started = False return yield None except Exception as e: - yield ErrorFrame(error=f"Unknown error occurred: {e}") - + logger.error(f"{self} exception: {e}") class AsyncAIHttpTTSService(TTSService): """HTTP-based Async TTS service. @@ -466,9 +634,9 @@ class AsyncAIHttpTTSService(TTSService): """ logger.debug(f"{self}: Generating TTS [{text}]") + first_byte_seen = False try: voice_config = {"mode": "id", "id": self._voice_id} - await self.start_ttfb_metrics() payload = { "model_id": self._model_name, "transcript": text, @@ -476,7 +644,6 @@ class AsyncAIHttpTTSService(TTSService): "output_format": self._settings["output_format"], "language": self._settings["language"], } - yield TTSStartedFrame() headers = { "version": self._api_version, "x-api-key": self._api_key, @@ -484,26 +651,36 @@ class AsyncAIHttpTTSService(TTSService): } url = f"{self._base_url}/text_to_speech/streaming" + yield TTSStartedFrame() + await self.start_ttfb_metrics() async with self._session.post(url, json=payload, headers=headers) as response: if response.status != 200: error_text = await response.text() await self.push_error(error_msg=f"Async API error: {error_text}") raise Exception(f"Async API returned status {response.status}: {error_text}") - audio_data = await response.read() + # Read streaming bytes; stop TTFB on the *first* received chunk + buffer = bytearray() + async for chunk in response.content.iter_chunked(64 * 1024): + if not chunk: + continue + if not first_byte_seen: + first_byte_seen = True + await self.stop_ttfb_metrics() + await self.start_tts_usage_metrics(text) - await self.start_tts_usage_metrics(text) + buffer.extend(chunk) + audio_data = bytes(buffer) - frame = TTSAudioRawFrame( + yield TTSAudioRawFrame( audio=audio_data, sample_rate=self.sample_rate, num_channels=1, ) - yield frame - except Exception as e: await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) finally: - await self.stop_ttfb_metrics() + if not first_byte_seen: + await self.stop_ttfb_metrics() yield TTSStoppedFrame()