From c4ae4025f3bc28d8a65e8ffc118fda7104500f77 Mon Sep 17 00:00:00 2001 From: Ashot Date: Wed, 14 Jan 2026 16:33:30 +0400 Subject: [PATCH] Adjustments of Async TTS for multicontext websocket support --- src/pipecat/services/asyncai/tts.py | 87 +++++++++++++---------------- 1 file changed, 40 insertions(+), 47 deletions(-) diff --git a/src/pipecat/services/asyncai/tts.py b/src/pipecat/services/asyncai/tts.py index 04a847955..fbc760562 100644 --- a/src/pipecat/services/asyncai/tts.py +++ b/src/pipecat/services/asyncai/tts.py @@ -9,9 +9,9 @@ import asyncio import base64 import json -from typing import AsyncGenerator, Optional, Dict - import uuid +from typing import AsyncGenerator, Optional + import aiohttp from loguru import logger from pydantic import BaseModel @@ -127,10 +127,6 @@ class AsyncAITTSService(AudioContextTTSService): **kwargs, ) - self._contexts: Dict[str, asyncio.Queue] = {} - self._audio_context_task = None - self._context_id = None - params = params or AsyncAITTSService.InputParams() self._api_key = api_key @@ -153,6 +149,30 @@ class AsyncAITTSService(AudioContextTTSService): self._receive_task = None self._keepalive_task = None self._started = False + self._context_id = None + + def can_generate_metrics(self) -> bool: + """Check if this service can generate processing metrics. + + Returns: + True, as Async service supports metrics generation. + """ + return True + + def language_to_service_language(self, language: Language) -> Optional[str]: + """Convert a Language enum to Async language format. + + Args: + language: The language to convert. + + Returns: + The Async-specific language code, or None if not supported. + """ + return language_to_async_language(language) + + def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str: + msg = {"transcript": text, "context_id": context_id, "force": force} + return json.dumps(msg) async def start(self, frame: StartFrame): """Start the Async TTS service. @@ -182,29 +202,6 @@ class AsyncAITTSService(AudioContextTTSService): await super().cancel(frame) await self._disconnect() - def can_generate_metrics(self) -> bool: - """Check if this service can generate processing metrics. - - Returns: - True, as Async service supports metrics generation. - """ - return True - - def language_to_service_language(self, language: Language) -> Optional[str]: - """Convert a Language enum to Async language format. - - Args: - language: The language to convert. - - Returns: - The Async-specific language code, or None if not supported. - """ - return language_to_async_language(language) - - def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str: - msg = {"transcript": text, "context_id": context_id, "force": force} - return json.dumps(msg) - async def _connect(self): await super()._connect() @@ -264,7 +261,7 @@ class AsyncAITTSService(AudioContextTTSService): await self._websocket.close() logger.debug("Disconnected from Async") except Exception as e: - logger.error(f"{self} error closing websocket: {e}") + await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) finally: self._websocket = None self._context_id = None @@ -338,7 +335,7 @@ class AsyncAITTSService(AudioContextTTSService): if self._websocket and self._websocket.state is State.OPEN: if self._context_id: keepalive_message = { - "transcript": " ", + "transcript": " ", "context_id": self._context_id, } logger.trace("Sending keepalive message") @@ -397,24 +394,22 @@ class AsyncAITTSService(AudioContextTTSService): if not self.audio_context_available(self._context_id): await self.create_audio_context(self._context_id) - msg = self._build_msg(text=" ", context_id=self._context_id) - await self._get_websocket().send(msg) msg = self._build_msg(text=text, force=True, context_id=self._context_id) await self._get_websocket().send(msg) await self.start_tts_usage_metrics(text) else: if self._websocket and self._context_id: msg = self._build_msg(text=text, force=True, context_id=self._context_id) - await self._get_websocket().send(msg) + await self._get_websocket().send(msg) except Exception as e: - logger.error(f"{self} error sending message: {e}") + yield ErrorFrame(error=f"Unknown error occurred: {e}") yield TTSStoppedFrame() self._started = False return yield None except Exception as e: - logger.error(f"{self} exception: {e}") + yield ErrorFrame(error=f"Unknown error occurred: {e}") class AsyncAIHttpTTSService(TTSService): @@ -526,9 +521,9 @@ class AsyncAIHttpTTSService(TTSService): """ logger.debug(f"{self}: Generating TTS [{text}]") - first_byte_seen = False try: voice_config = {"mode": "id", "id": self._voice_id} + await self.start_ttfb_metrics() payload = { "model_id": self._model_name, "transcript": text, @@ -536,6 +531,7 @@ class AsyncAIHttpTTSService(TTSService): "output_format": self._settings["output_format"], "language": self._settings["language"], } + yield TTSStartedFrame() headers = { "version": self._api_version, "x-api-key": self._api_key, @@ -543,8 +539,6 @@ class AsyncAIHttpTTSService(TTSService): } url = f"{self._base_url}/text_to_speech/streaming" - yield TTSStartedFrame() - await self.start_ttfb_metrics() async with self._session.post(url, json=payload, headers=headers) as response: if response.status != 200: error_text = await response.text() @@ -556,23 +550,22 @@ class AsyncAIHttpTTSService(TTSService): async for chunk in response.content.iter_chunked(64 * 1024): if not chunk: continue - if not first_byte_seen: - first_byte_seen = True - await self.stop_ttfb_metrics() - await self.start_tts_usage_metrics(text) - + await self.stop_ttfb_metrics() buffer.extend(chunk) audio_data = bytes(buffer) - yield TTSAudioRawFrame( + await self.start_tts_usage_metrics(text) + + frame = TTSAudioRawFrame( audio=audio_data, sample_rate=self.sample_rate, num_channels=1, ) + yield frame + except Exception as e: await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) finally: - if not first_byte_seen: - await self.stop_ttfb_metrics() + await self.stop_ttfb_metrics() yield TTSStoppedFrame()