353 lines
13 KiB
Python
353 lines
13 KiB
Python
"""DashScope realtime TTS service.
|
|
|
|
Implements DashScope's Qwen realtime TTS protocol via the official SDK.
|
|
"""
|
|
|
|
import asyncio
|
|
import audioop
|
|
import base64
|
|
import json
|
|
import os
|
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple
|
|
|
|
from loguru import logger
|
|
|
|
from services.base import BaseTTSService, ServiceState, TTSChunk
|
|
|
|
# Optional dependency: the DashScope SDK. When it cannot be imported the module
# still loads, with placeholder names and DASHSCOPE_SDK_AVAILABLE set to False so
# callers can detect the missing SDK instead of failing at import time.
try:
    import dashscope
    from dashscope.audio.qwen_tts_realtime import AudioFormat, QwenTtsRealtime, QwenTtsRealtimeCallback
except ImportError:
    dashscope = None  # type: ignore[assignment]
    AudioFormat = None  # type: ignore[assignment]
    QwenTtsRealtime = None  # type: ignore[assignment]
    DASHSCOPE_SDK_AVAILABLE = False

    class QwenTtsRealtimeCallback:  # type: ignore[no-redef]
        """Stand-in base class used only when the DashScope SDK is missing."""
else:
    DASHSCOPE_SDK_AVAILABLE = True
|
|
|
|
|
|
class _RealtimeEventCallback(QwenTtsRealtimeCallback):
    """Bridge SDK callback events into an asyncio queue.

    Events are handed to the event loop with ``call_soon_threadsafe`` rather
    than put on the queue directly, so these callbacks may be invoked from a
    thread other than the loop's (presumably an SDK-owned websocket thread).
    Every item placed on the queue is a dict.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop, queue: "asyncio.Queue[Dict[str, Any]]"):
        """Bind the callback to the event loop and queue that consume its events."""
        super().__init__()
        self._loop = loop
        self._queue = queue

    def _push(self, event: Dict[str, Any]) -> None:
        """Schedule ``event`` onto the queue; silently dropped if the loop is closed."""
        try:
            self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
        except RuntimeError:
            # call_soon_threadsafe raises RuntimeError once the loop is closed;
            # events arriving after shutdown are intentionally discarded.
            return

    def on_open(self) -> None:
        """Report that the websocket opened, as a synthetic ``session.open`` event."""
        self._push({"type": "session.open"})

    def on_close(self, code: int, reason: str) -> None:
        """Report websocket closure as an internal ``__close__`` event."""
        self._push({"type": "__close__", "code": code, "reason": reason})

    def on_error(self, message: str) -> None:
        """Report an SDK error in the same shape as a server ``error`` event."""
        self._push({"type": "error", "error": {"message": str(message)}})

    def on_event(self, event: Any) -> None:
        """Forward a server event, decoding JSON text when needed.

        Dicts pass through unchanged. Strings are parsed as JSON. Anything
        that does not end up as a dict (undecodable text, JSON arrays or
        scalars, other objects) is wrapped as ``{"type": "raw", ...}``.
        """
        if isinstance(event, str):
            try:
                payload: Any = json.loads(event)
            except json.JSONDecodeError:
                payload = {"type": "raw", "message": event}
        else:
            payload = event

        if not isinstance(payload, dict):
            # Keep the queue's Dict[str, Any] contract: non-object JSON
            # (e.g. a list) used to be pushed as-is.
            payload = {"type": "raw", "message": str(payload)}
        self._push(payload)

    def on_data(self, data: bytes) -> None:
        """Forward raw PCM bytes as a ``response.audio.delta.raw`` event."""
        # Some SDK versions provide audio via on_data directly.
        self._push({"type": "response.audio.delta.raw", "audio": data})
|
|
|
|
|
|
class DashScopeTTSService(BaseTTSService):
    """DashScope realtime TTS service using Qwen Realtime protocol.

    The ``QwenTtsRealtime`` SDK client is called through ``asyncio.to_thread``.
    SDK callbacks reach this class via ``_RealtimeEventCallback``, which feeds
    ``self._event_queue``. The session requests 24 kHz 16-bit mono PCM, which
    is resampled to ``sample_rate`` when the two differ.
    """

    DEFAULT_WS_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
    DEFAULT_MODEL = "qwen3-tts-flash-realtime"
    # Sample rate of the PCM format requested in connect() (PCM_24000HZ_MONO_16BIT).
    PROVIDER_SAMPLE_RATE = 24000

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        voice: str = "Cherry",
        model: Optional[str] = None,
        mode: str = "server_commit",
        sample_rate: int = 16000,
        speed: float = 1.0,
    ):
        """Store configuration; no network activity happens until connect().

        Args:
            api_key: DashScope API key. Falls back to the ``DASHSCOPE_API_KEY``
                then ``TTS_API_KEY`` environment variables.
            api_url: Realtime websocket URL. Falls back to
                ``DASHSCOPE_TTS_API_URL``, ``TTS_API_URL``, then DEFAULT_WS_URL.
            voice: Voice name. It is passed to BaseTTSService, presumably stored
                as ``self.voice``, which connect() sends in the session update.
            model: Model name. Falls back to ``DASHSCOPE_TTS_MODEL``, then
                DEFAULT_MODEL.
            mode: ``"server_commit"`` or ``"commit"`` (case-insensitive). Any
                other value logs a warning and uses ``"server_commit"``.
            sample_rate: Sample rate of the PCM chunks this service yields.
            speed: Passed through to BaseTTSService; not otherwise used here.
        """
        super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
        self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY") or os.getenv("TTS_API_KEY")
        self.api_url = (
            api_url
            or os.getenv("DASHSCOPE_TTS_API_URL")
            or os.getenv("TTS_API_URL")
            or self.DEFAULT_WS_URL
        )
        self.model = model or os.getenv("DASHSCOPE_TTS_MODEL") or self.DEFAULT_MODEL

        normalized_mode = str(mode or "").strip().lower()
        if normalized_mode not in {"server_commit", "commit"}:
            logger.warning(f"Unknown DashScope mode '{mode}', fallback to server_commit")
            normalized_mode = "server_commit"
        self.mode = normalized_mode

        self._client: Optional[Any] = None
        self._event_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
        self._callback: Optional[_RealtimeEventCallback] = None
        # Set by cancel()/disconnect(); cleared at the start of each synthesis.
        self._cancel_event = asyncio.Event()
        # Serializes synthesize_stream() calls over the single SDK client.
        self._synthesis_lock = asyncio.Lock()

    async def connect(self) -> None:
        """Open the realtime websocket and configure the synthesis session.

        If anything fails after the SDK client has been created, the client is
        closed and ``self._client`` is reset before the error propagates. A
        failed connect therefore never leaves a half-open session behind.

        Raises:
            RuntimeError: If the ``dashscope`` package is not installed, or the
                server reports an error while the session is being set up.
            ValueError: If no API key is configured.
        """
        if not DASHSCOPE_SDK_AVAILABLE:
            raise RuntimeError("dashscope package not installed; install with `pip install dashscope`")
        if not self.api_key:
            raise ValueError("DashScope API key not provided. Configure agent.tts.api_key in YAML.")

        loop = asyncio.get_running_loop()
        # Discard leftovers from a previous session so they cannot be mistaken
        # for this session's readiness signal.
        self._drain_event_queue()
        self._callback = _RealtimeEventCallback(loop=loop, queue=self._event_queue)
        # The official Python SDK docs set key via global `dashscope.api_key`;
        # some SDK versions do not accept `api_key=` in QwenTtsRealtime ctor.
        if dashscope is not None:
            dashscope.api_key = self.api_key
        client = self._create_realtime_client(self._callback)
        self._client = client

        try:
            await asyncio.to_thread(client.connect)
            await asyncio.to_thread(
                client.update_session,
                voice=self.voice,
                response_format=AudioFormat.PCM_24000HZ_MONO_16BIT,
                mode=self.mode,
            )
            await self._wait_for_session_ready()
        except BaseException:
            # Don't leak a half-open websocket; the caller still sees the
            # original error (including cancellation).
            self._client = None
            self._callback = None
            try:
                await self._close_sdk_client(client)
            except Exception as close_exc:
                logger.debug(f"DashScope client close after failed connect failed: {close_exc}")
            raise

        self.state = ServiceState.CONNECTED
        logger.info(
            "DashScope realtime TTS service ready: "
            f"voice={self.voice}, model={self.model}, mode={self.mode}"
        )

    def _create_realtime_client(self, callback: _RealtimeEventCallback) -> Any:
        """Build the SDK client, passing ``api_key`` to the constructor when supported.

        If the constructor rejects ``api_key`` with a TypeError that mentions it,
        the client is built without that argument. Authentication then relies on
        the global ``dashscope.api_key`` set in connect(). Any other TypeError
        is re-raised.
        """
        init_kwargs = {
            "model": self.model,
            "callback": callback,
            "url": self.api_url,
        }
        try:
            return QwenTtsRealtime(  # type: ignore[misc]
                api_key=self.api_key,
                **init_kwargs,
            )
        except TypeError as exc:
            if "api_key" not in str(exc):
                raise
            logger.debug(
                "QwenTtsRealtime does not support `api_key` ctor arg; "
                "falling back to global dashscope.api_key auth"
            )
            return QwenTtsRealtime(**init_kwargs)  # type: ignore[misc]

    async def disconnect(self) -> None:
        """Close the SDK client and mark the service disconnected.

        The cancel flag is set first so an in-flight synthesize_stream() stops
        yielding audio. The client reference, the event queue and ``state`` are
        reset even if the SDK ``close`` call raises; that error still propagates.
        """
        self._cancel_event.set()
        client, self._client = self._client, None
        try:
            if client:
                await self._close_sdk_client(client)
        finally:
            self._drain_event_queue()
            self.state = ServiceState.DISCONNECTED
        logger.info("DashScope realtime TTS service disconnected")

    @staticmethod
    async def _close_sdk_client(client: Any) -> None:
        """Call ``client.close()`` in a worker thread, if the client has one."""
        close_fn = getattr(client, "close", None)
        if callable(close_fn):
            await asyncio.to_thread(close_fn)

    async def synthesize(self, text: str) -> bytes:
        """Synthesize ``text`` and return all audio as one PCM byte string.

        The bytes are the concatenation of the chunks from synthesize_stream()
        (``self.sample_rate``, 16-bit mono).
        """
        # Collect then join once; repeated bytes `+=` is quadratic.
        parts = [chunk.audio async for chunk in self.synthesize_stream(text)]
        return b"".join(parts)

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """Stream synthesized audio for ``text`` in ~100 ms PCM chunks.

        Chunks are 16-bit mono PCM at ``self.sample_rate``. When the response
        completes normally, the last chunk has ``is_final=True``. A cancelled
        stream ends without a final chunk. Concurrent calls are serialized by
        an internal lock. Whitespace-only text yields nothing.

        Raises:
            RuntimeError: If not connected, the server reports an error, or the
                websocket closes before ``response.done``.
            asyncio.TimeoutError: If no event arrives within 20 s while the
                stream has not been cancelled.
        """
        if not self._client:
            raise RuntimeError("DashScope TTS service not connected")
        if not text.strip():
            return

        async with self._synthesis_lock:
            self._cancel_event.clear()
            # Drop events left over from an earlier response.
            self._drain_event_queue()

            await self._clear_appended_text()
            await asyncio.to_thread(self._client.append_text, text)
            if self.mode == "commit":
                await asyncio.to_thread(self._client.commit)

            chunk_size = max(1, self.sample_rate * 2 // 10)  # 100ms of 16-bit mono
            buffer = b""
            # One full chunk is held back so the last chunk emitted can be
            # flagged is_final once response.done arrives.
            pending_chunk: Optional[bytes] = None
            resample_state: Any = None

            while True:
                # After cancel(), wait a shorter time for the server to wind down.
                timeout = 8.0 if self._cancel_event.is_set() else 20.0
                try:
                    event = await self._next_event(timeout=timeout)
                except asyncio.TimeoutError:
                    if self._cancel_event.is_set():
                        # A cancelled stream ends quietly even if the server
                        # never sends response.done.
                        logger.debug("DashScope TTS cancelled; no response.done before timeout")
                        return
                    raise
                event_type = str(event.get("type") or "").strip()

                if event_type in {"response.audio.delta", "response.audio.delta.raw"}:
                    if self._cancel_event.is_set():
                        # Discard audio after cancellation but keep draining
                        # until the response ends.
                        continue

                    pcm = self._decode_audio_event(event)
                    if not pcm:
                        continue
                    pcm, resample_state = self._resample_if_needed(pcm, resample_state)
                    if not pcm:
                        continue

                    buffer += pcm
                    while len(buffer) >= chunk_size:
                        audio_chunk, buffer = buffer[:chunk_size], buffer[chunk_size:]
                        if pending_chunk is not None:
                            yield TTSChunk(
                                audio=pending_chunk,
                                sample_rate=self.sample_rate,
                                is_final=False,
                            )
                        pending_chunk = audio_chunk
                    continue

                if event_type == "response.done":
                    break

                if event_type == "error":
                    raise RuntimeError(self._format_error_event(event))

                if event_type == "__close__":
                    reason = str(event.get("reason") or "unknown")
                    raise RuntimeError(f"DashScope TTS websocket closed unexpectedly: {reason}")

                # Any other event type (session.*, raw, ...) is ignored.

            if self._cancel_event.is_set():
                return

            if pending_chunk is not None:
                # The held-back chunk is final only when no partial tail remains.
                yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=not buffer)
            if buffer:
                yield TTSChunk(audio=buffer, sample_rate=self.sample_rate, is_final=True)

    async def cancel(self) -> None:
        """Cancel the in-progress synthesis, if any.

        Sets the cancel flag so synthesize_stream() discards further audio. In
        ``commit`` mode this only clears appended-but-uncommitted text.
        Otherwise it calls the client's ``cancel_response`` (or ``cancel``)
        method when one exists, logging and swallowing any error it raises.
        """
        self._cancel_event.set()
        if self.mode == "commit":
            # NOTE(review): a response that was already committed is not
            # cancelled server-side here; confirm this is intended.
            await self._clear_appended_text()
            return

        if not self._client:
            return
        cancel_fn = (
            getattr(self._client, "cancel_response", None)
            or getattr(self._client, "cancel", None)
        )
        if callable(cancel_fn):
            try:
                await asyncio.to_thread(cancel_fn)
            except Exception as exc:
                logger.debug(f"DashScope cancel failed: {exc}")

    async def _wait_for_session_ready(self) -> None:
        """Wait up to 8 s per event for a session-ready event.

        Returns on ``session.updated`` or ``session.open``. Raises RuntimeError
        on an ``error`` event. On timeout it logs at debug level and returns,
        leaving the websocket in use.
        """
        # NOTE(review): on_open queues "session.open" while client.connect runs,
        # i.e. before update_session is sent, so this likely returns before the
        # server acknowledges the update. Confirm whether waiting specifically
        # for "session.updated" is intended.
        try:
            while True:
                event = await self._next_event(timeout=8.0)
                event_type = str(event.get("type") or "").strip()
                if event_type in {"session.updated", "session.open"}:
                    return
                if event_type == "error":
                    raise RuntimeError(self._format_error_event(event))
        except asyncio.TimeoutError:
            logger.debug("DashScope session update event timeout; continuing with active websocket")

    async def _clear_appended_text(self) -> None:
        """In ``commit`` mode, discard text appended to the server-side buffer.

        This is a no-op in other modes, when disconnected, or when the client
        has no ``clear_appended_text`` method. Errors are logged and swallowed.
        """
        if self.mode != "commit":
            return
        if not self._client:
            return
        clear_fn = getattr(self._client, "clear_appended_text", None)
        if callable(clear_fn):
            try:
                await asyncio.to_thread(clear_fn)
            except Exception as exc:
                logger.debug(f"DashScope clear_appended_text failed: {exc}")

    async def _next_event(self, timeout: float) -> Dict[str, Any]:
        """Return the next queued event, wrapping non-dict items as ``raw`` events.

        Raises:
            asyncio.TimeoutError: If nothing arrives within ``timeout`` seconds.
        """
        event = await asyncio.wait_for(self._event_queue.get(), timeout=timeout)
        if isinstance(event, dict):
            return event
        return {"type": "raw", "message": str(event)}

    def _drain_event_queue(self) -> None:
        """Discard every event currently in the queue without waiting."""
        while True:
            try:
                self._event_queue.get_nowait()
            except asyncio.QueueEmpty:
                break

    def _decode_audio_event(self, event: Dict[str, Any]) -> bytes:
        """Extract PCM bytes from an audio event; return ``b""`` if none.

        ``response.audio.delta.raw`` events carry bytes under ``"audio"``.
        Other events carry a base64 string (or raw bytes) under ``"delta"``.
        """
        event_type = str(event.get("type") or "")
        if event_type == "response.audio.delta.raw":
            audio = event.get("audio")
            if isinstance(audio, (bytes, bytearray)):
                return bytes(audio)
            return b""

        delta = event.get("delta")
        if isinstance(delta, str):
            try:
                return base64.b64decode(delta)
            except Exception as exc:
                logger.warning(f"Failed to decode DashScope audio delta: {exc}")
                return b""
        if isinstance(delta, (bytes, bytearray)):
            return bytes(delta)
        return b""

    def _resample_if_needed(self, pcm: bytes, state: Any) -> Tuple[bytes, Any]:
        """Convert 16-bit mono PCM from PROVIDER_SAMPLE_RATE to ``self.sample_rate``.

        ``state`` is the ``audioop.ratecv`` state carried between calls so the
        deltas of one response resample continuously. On failure, a warning is
        logged and the input is returned unconverted, still at the provider rate.
        """
        if self.sample_rate == self.PROVIDER_SAMPLE_RATE:
            return pcm, state
        # NOTE(review): audioop rejects fragments that are not a whole number of
        # frames, so an odd-length delta would take the fallback path below.
        # audioop is deprecated since 3.11 and removed in 3.13 (PEP 594).
        try:
            converted, next_state = audioop.ratecv(
                pcm,
                2,  # 16-bit PCM
                1,  # mono
                self.PROVIDER_SAMPLE_RATE,
                self.sample_rate,
                state,
            )
            return converted, next_state
        except Exception as exc:
            logger.warning(f"DashScope audio resample failed: {exc}; returning original sample rate data")
            return pcm, state

    @staticmethod
    def _format_error_event(event: Dict[str, Any]) -> str:
        """Render an ``error`` event as ``"code: message"`` when both are present."""
        err = event.get("error")
        if isinstance(err, dict):
            code = str(err.get("code") or "").strip()
            message = str(err.get("message") or "").strip()
            if code and message:
                return f"{code}: {message}"
            return message or str(err)
        return str(err or "DashScope realtime TTS error")
|