Adjust Async TTS for multi-context WebSocket support

This commit is contained in:
Ashot
2026-01-14 16:33:30 +04:00
parent 15067c678d
commit c4ae4025f3

View File

@@ -9,9 +9,9 @@
import asyncio
import base64
import json
from typing import AsyncGenerator, Optional, Dict
import uuid
from typing import AsyncGenerator, Optional
import aiohttp
from loguru import logger
from pydantic import BaseModel
@@ -127,10 +127,6 @@ class AsyncAITTSService(AudioContextTTSService):
**kwargs,
)
self._contexts: Dict[str, asyncio.Queue] = {}
self._audio_context_task = None
self._context_id = None
params = params or AsyncAITTSService.InputParams()
self._api_key = api_key
@@ -153,6 +149,30 @@ class AsyncAITTSService(AudioContextTTSService):
self._receive_task = None
self._keepalive_task = None
self._started = False
self._context_id = None
def can_generate_metrics(self) -> bool:
    """Report whether this service can produce processing metrics.

    Returns:
        Always True: the Async service supports metrics generation.
    """
    return True
def language_to_service_language(self, language: Language) -> Optional[str]:
    """Translate a Language enum value into the Async language format.

    Delegates to ``language_to_async_language`` and returns its result
    unchanged.

    Args:
        language: The language to convert.

    Returns:
        The Async-specific language code, or None if not supported.
    """
    async_language = language_to_async_language(language)
    return async_language
def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str:
msg = {"transcript": text, "context_id": context_id, "force": force}
return json.dumps(msg)
async def start(self, frame: StartFrame):
"""Start the Async TTS service.
@@ -182,29 +202,6 @@ class AsyncAITTSService(AudioContextTTSService):
await super().cancel(frame)
await self._disconnect()
def can_generate_metrics(self) -> bool:
    """Report whether this service can produce processing metrics.

    Returns:
        Always True: the Async service supports metrics generation.
    """
    return True
def language_to_service_language(self, language: Language) -> Optional[str]:
    """Translate a Language enum value into the Async language format.

    Delegates to ``language_to_async_language`` and returns its result
    unchanged.

    Args:
        language: The language to convert.

    Returns:
        The Async-specific language code, or None if not supported.
    """
    async_language = language_to_async_language(language)
    return async_language
def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str:
msg = {"transcript": text, "context_id": context_id, "force": force}
return json.dumps(msg)
async def _connect(self):
await super()._connect()
@@ -264,7 +261,7 @@ class AsyncAITTSService(AudioContextTTSService):
await self._websocket.close()
logger.debug("Disconnected from Async")
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._websocket = None
self._context_id = None
@@ -338,7 +335,7 @@ class AsyncAITTSService(AudioContextTTSService):
if self._websocket and self._websocket.state is State.OPEN:
if self._context_id:
keepalive_message = {
"transcript": " ",
"transcript": " ",
"context_id": self._context_id,
}
logger.trace("Sending keepalive message")
@@ -397,24 +394,22 @@ class AsyncAITTSService(AudioContextTTSService):
if not self.audio_context_available(self._context_id):
await self.create_audio_context(self._context_id)
msg = self._build_msg(text=" ", context_id=self._context_id)
await self._get_websocket().send(msg)
msg = self._build_msg(text=text, force=True, context_id=self._context_id)
await self._get_websocket().send(msg)
await self.start_tts_usage_metrics(text)
else:
if self._websocket and self._context_id:
msg = self._build_msg(text=text, force=True, context_id=self._context_id)
await self._get_websocket().send(msg)
await self._get_websocket().send(msg)
except Exception as e:
logger.error(f"{self} error sending message: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame()
self._started = False
return
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"Unknown error occurred: {e}")
class AsyncAIHttpTTSService(TTSService):
@@ -526,9 +521,9 @@ class AsyncAIHttpTTSService(TTSService):
"""
logger.debug(f"{self}: Generating TTS [{text}]")
first_byte_seen = False
try:
voice_config = {"mode": "id", "id": self._voice_id}
await self.start_ttfb_metrics()
payload = {
"model_id": self._model_name,
"transcript": text,
@@ -536,6 +531,7 @@ class AsyncAIHttpTTSService(TTSService):
"output_format": self._settings["output_format"],
"language": self._settings["language"],
}
yield TTSStartedFrame()
headers = {
"version": self._api_version,
"x-api-key": self._api_key,
@@ -543,8 +539,6 @@ class AsyncAIHttpTTSService(TTSService):
}
url = f"{self._base_url}/text_to_speech/streaming"
yield TTSStartedFrame()
await self.start_ttfb_metrics()
async with self._session.post(url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
@@ -556,23 +550,22 @@ class AsyncAIHttpTTSService(TTSService):
async for chunk in response.content.iter_chunked(64 * 1024):
if not chunk:
continue
if not first_byte_seen:
first_byte_seen = True
await self.stop_ttfb_metrics()
await self.start_tts_usage_metrics(text)
await self.stop_ttfb_metrics()
buffer.extend(chunk)
audio_data = bytes(buffer)
yield TTSAudioRawFrame(
await self.start_tts_usage_metrics(text)
frame = TTSAudioRawFrame(
audio=audio_data,
sample_rate=self.sample_rate,
num_channels=1,
)
yield frame
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
if not first_byte_seen:
await self.stop_ttfb_metrics()
await self.stop_ttfb_metrics()
yield TTSStoppedFrame()