Adjustments of Async TTS for multicontext websocket support
This commit is contained in:
@@ -9,9 +9,9 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from typing import AsyncGenerator, Optional, Dict
|
||||
|
||||
import uuid
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
@@ -127,10 +127,6 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._contexts: Dict[str, asyncio.Queue] = {}
|
||||
self._audio_context_task = None
|
||||
self._context_id = None
|
||||
|
||||
params = params or AsyncAITTSService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
@@ -153,6 +149,30 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
self._receive_task = None
|
||||
self._keepalive_task = None
|
||||
self._started = False
|
||||
self._context_id = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Async service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Async language format.
|
||||
|
||||
Args:
|
||||
language: The language to convert.
|
||||
|
||||
Returns:
|
||||
The Async-specific language code, or None if not supported.
|
||||
"""
|
||||
return language_to_async_language(language)
|
||||
|
||||
def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str:
|
||||
msg = {"transcript": text, "context_id": context_id, "force": force}
|
||||
return json.dumps(msg)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Async TTS service.
|
||||
@@ -182,29 +202,6 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Async service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Async language format.
|
||||
|
||||
Args:
|
||||
language: The language to convert.
|
||||
|
||||
Returns:
|
||||
The Async-specific language code, or None if not supported.
|
||||
"""
|
||||
return language_to_async_language(language)
|
||||
|
||||
def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str:
|
||||
msg = {"transcript": text, "context_id": context_id, "force": force}
|
||||
return json.dumps(msg)
|
||||
|
||||
async def _connect(self):
|
||||
await super()._connect()
|
||||
|
||||
@@ -264,7 +261,7 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
await self._websocket.close()
|
||||
logger.debug("Disconnected from Async")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
self._websocket = None
|
||||
self._context_id = None
|
||||
@@ -338,7 +335,7 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
if self._context_id:
|
||||
keepalive_message = {
|
||||
"transcript": " ",
|
||||
"transcript": " ",
|
||||
"context_id": self._context_id,
|
||||
}
|
||||
logger.trace("Sending keepalive message")
|
||||
@@ -397,24 +394,22 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
if not self.audio_context_available(self._context_id):
|
||||
await self.create_audio_context(self._context_id)
|
||||
|
||||
msg = self._build_msg(text=" ", context_id=self._context_id)
|
||||
await self._get_websocket().send(msg)
|
||||
msg = self._build_msg(text=text, force=True, context_id=self._context_id)
|
||||
await self._get_websocket().send(msg)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
else:
|
||||
if self._websocket and self._context_id:
|
||||
msg = self._build_msg(text=text, force=True, context_id=self._context_id)
|
||||
await self._get_websocket().send(msg)
|
||||
await self._get_websocket().send(msg)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
self._started = False
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
|
||||
|
||||
class AsyncAIHttpTTSService(TTSService):
|
||||
@@ -526,9 +521,9 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
first_byte_seen = False
|
||||
try:
|
||||
voice_config = {"mode": "id", "id": self._voice_id}
|
||||
await self.start_ttfb_metrics()
|
||||
payload = {
|
||||
"model_id": self._model_name,
|
||||
"transcript": text,
|
||||
@@ -536,6 +531,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
"output_format": self._settings["output_format"],
|
||||
"language": self._settings["language"],
|
||||
}
|
||||
yield TTSStartedFrame()
|
||||
headers = {
|
||||
"version": self._api_version,
|
||||
"x-api-key": self._api_key,
|
||||
@@ -543,8 +539,6 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
}
|
||||
url = f"{self._base_url}/text_to_speech/streaming"
|
||||
|
||||
yield TTSStartedFrame()
|
||||
await self.start_ttfb_metrics()
|
||||
async with self._session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
@@ -556,23 +550,22 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
async for chunk in response.content.iter_chunked(64 * 1024):
|
||||
if not chunk:
|
||||
continue
|
||||
if not first_byte_seen:
|
||||
first_byte_seen = True
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
buffer.extend(chunk)
|
||||
audio_data = bytes(buffer)
|
||||
|
||||
yield TTSAudioRawFrame(
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=audio_data,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
if not first_byte_seen:
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
Reference in New Issue
Block a user