Files
AI-VideoAssistant/engine/services/dashscope_tts.py
Xin Wang 935f2fbd1f Refactor assistant configuration management and update documentation
- Removed legacy agent profile settings from the .env.example and README, streamlining the configuration process.
- Introduced a new local YAML configuration adapter for assistant settings, allowing for easier management of assistant profiles.
- Updated backend integration documentation to clarify the behavior of assistant config sourcing based on backend URL settings.
- Adjusted various service implementations to directly utilize API keys from the new configuration structure.
- Enhanced test coverage for the new local YAML adapter and its integration with backend services.
2026-03-05 21:24:15 +08:00

353 lines
13 KiB
Python

"""DashScope realtime TTS service.
Implements DashScope's Qwen realtime TTS protocol via the official SDK.
"""
import asyncio
import audioop
import base64
import json
import os
from typing import Any, AsyncIterator, Dict, Optional, Tuple
from loguru import logger
from services.base import BaseTTSService, ServiceState, TTSChunk
# The DashScope SDK is optional at import time: when it is missing the module
# still imports, and `DASHSCOPE_SDK_AVAILABLE` lets callers fail with a clear
# message instead of an ImportError.
try:
    import dashscope
    from dashscope.audio.qwen_tts_realtime import AudioFormat, QwenTtsRealtime, QwenTtsRealtimeCallback

    DASHSCOPE_SDK_AVAILABLE = True
except ImportError:
    dashscope = None  # type: ignore[assignment]
    AudioFormat = None  # type: ignore[assignment]
    QwenTtsRealtime = None  # type: ignore[assignment]
    DASHSCOPE_SDK_AVAILABLE = False

    class QwenTtsRealtimeCallback:  # type: ignore[no-redef]
        """Fallback callback base when DashScope SDK is unavailable."""


class _RealtimeEventCallback(QwenTtsRealtimeCallback):
    """Bridge SDK callback events into an asyncio queue.

    Every callback is converted into a dict carrying a ``"type"`` key and is
    handed to the event loop with ``call_soon_threadsafe``, so the callbacks
    may be invoked from a thread other than the one running the loop.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop, queue: "asyncio.Queue[Dict[str, Any]]"):
        """Remember the target loop and the queue that receives the events."""
        super().__init__()
        self._loop = loop
        self._queue = queue

    def _push(self, event: Dict[str, Any]) -> None:
        """Schedule ``event`` to be enqueued on the loop; best-effort."""
        try:
            self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
        except RuntimeError:
            # Raised by call_soon_threadsafe once the loop has been closed;
            # there is nobody left to consume the event, so drop it.
            return

    def on_open(self) -> None:
        """Report that the connection was opened."""
        self._push({"type": "session.open"})

    def on_close(self, code: int, reason: str) -> None:
        """Report that the connection was closed, with its code and reason."""
        self._push({"type": "__close__", "code": code, "reason": reason})

    def on_error(self, message: str) -> None:
        """Report an SDK-level error in the same shape as protocol errors."""
        self._push({"type": "error", "error": {"message": str(message)}})

    def on_event(self, event: Any) -> None:
        """Normalise a protocol event to a dict and enqueue it.

        Dicts are forwarded unchanged and strings are parsed as JSON. Anything
        that does not end up as a dict (undecodable text, JSON that is not an
        object, or any other value) is wrapped as
        ``{"type": "raw", "message": ...}`` so that consumers of the queue
        always receive a dict.
        """
        payload: Any = event
        if isinstance(event, str):
            try:
                payload = json.loads(event)
            except json.JSONDecodeError:
                payload = None
        if not isinstance(payload, dict):
            # Valid JSON that is not an object (e.g. "[1, 2]") must not reach
            # the queue as a bare list/number; keep the original text instead.
            message = event if isinstance(event, str) else str(event)
            payload = {"type": "raw", "message": message}
        self._push(payload)

    def on_data(self, data: bytes) -> None:
        """Forward raw audio bytes delivered outside of a JSON event."""
        # Some SDK versions provide audio via on_data directly.
        self._push({"type": "response.audio.delta.raw", "audio": data})
class DashScopeTTSService(BaseTTSService):
    """DashScope realtime TTS service using Qwen Realtime protocol.

    SDK client calls are executed in worker threads via ``asyncio.to_thread``.
    SDK callbacks are bridged into ``_event_queue`` by
    ``_RealtimeEventCallback`` and consumed on the event loop.
    """

    DEFAULT_WS_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
    DEFAULT_MODEL = "qwen3-tts-flash-realtime"
    # Rate of the PCM requested from the provider in `connect`; audio is
    # converted from this rate to `self.sample_rate` when the two differ.
    PROVIDER_SAMPLE_RATE = 24000

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        voice: str = "Cherry",
        model: Optional[str] = None,
        mode: str = "commit",
        sample_rate: int = 16000,
        speed: float = 1.0,
    ):
        """Store the configuration; no connection is opened here.

        Args:
            api_key: DashScope API key. Required by ``connect``.
            api_url: Websocket URL. Falls back to the ``DASHSCOPE_TTS_API_URL``
                and ``TTS_API_URL`` environment variables, then to
                ``DEFAULT_WS_URL``.
            voice: Voice name sent with the session update.
            model: Model name. Falls back to ``DASHSCOPE_TTS_MODEL``, then to
                ``DEFAULT_MODEL``.
            mode: ``"commit"`` or ``"server_commit"`` (case-insensitive). Any
                other value is replaced by ``"server_commit"`` with a warning.
            sample_rate: Sample rate of the audio yielded to callers.
            speed: Forwarded to the base class; not otherwise used here.
        """
        super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
        self.api_key = api_key
        self.api_url = (
            api_url
            or os.getenv("DASHSCOPE_TTS_API_URL")
            or os.getenv("TTS_API_URL")
            or self.DEFAULT_WS_URL
        )
        self.model = model or os.getenv("DASHSCOPE_TTS_MODEL") or self.DEFAULT_MODEL

        normalized_mode = str(mode or "").strip().lower()
        if normalized_mode not in {"server_commit", "commit"}:
            logger.warning(f"Unknown DashScope mode '{mode}', fallback to server_commit")
            normalized_mode = "server_commit"
        self.mode = normalized_mode

        self._client: Optional[Any] = None
        self._event_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
        self._callback: Optional[_RealtimeEventCallback] = None
        # Set by cancel()/disconnect(); cleared when a new synthesis starts.
        self._cancel_event = asyncio.Event()
        # Serialises syntheses: they share one client and one event queue.
        self._synthesis_lock = asyncio.Lock()

    async def connect(self) -> None:
        """Open the realtime session and mark the service connected.

        If any step after the client is created fails, the client is closed
        (best-effort) and ``self._client`` is reset so that the service is not
        left holding a half-initialised connection.

        Raises:
            RuntimeError: If the DashScope SDK is not installed, or an error
                event is received while waiting for the session.
            ValueError: If no API key was configured.
        """
        if not DASHSCOPE_SDK_AVAILABLE:
            raise RuntimeError("dashscope package not installed; install with `pip install dashscope`")
        if not self.api_key:
            raise ValueError("DashScope API key not provided. Configure agent.tts.api_key in YAML.")

        loop = asyncio.get_running_loop()
        self._callback = _RealtimeEventCallback(loop=loop, queue=self._event_queue)
        # The official Python SDK docs set key via global `dashscope.api_key`;
        # some SDK versions do not accept `api_key=` in QwenTtsRealtime ctor.
        if dashscope is not None:
            dashscope.api_key = self.api_key

        client = self._create_realtime_client(self._callback)
        self._client = client
        try:
            await asyncio.to_thread(client.connect)
            await asyncio.to_thread(
                client.update_session,
                voice=self.voice,
                response_format=AudioFormat.PCM_24000HZ_MONO_16BIT,
                mode=self.mode,
            )
            await self._wait_for_session_ready()
        except Exception:
            # Do not keep (or leak) a client whose setup failed; otherwise
            # synthesize_stream() would later treat it as a live connection.
            self._client = None
            close_fn = getattr(client, "close", None)
            if callable(close_fn):
                try:
                    await asyncio.to_thread(close_fn)
                except Exception as exc:
                    logger.debug(f"DashScope client cleanup after failed connect raised: {exc}")
            raise

        self.state = ServiceState.CONNECTED
        logger.info(
            "DashScope realtime TTS service ready: "
            f"voice={self.voice}, model={self.model}, mode={self.mode}"
        )

    def _create_realtime_client(self, callback: _RealtimeEventCallback) -> Any:
        """Build the SDK client, tolerating constructors without ``api_key``.

        A ``TypeError`` whose message does not mention ``api_key`` is
        re-raised unchanged.
        """
        init_kwargs = {
            "model": self.model,
            "callback": callback,
            "url": self.api_url,
        }
        try:
            return QwenTtsRealtime(  # type: ignore[misc]
                api_key=self.api_key,
                **init_kwargs,
            )
        except TypeError as exc:
            if "api_key" not in str(exc):
                raise
            logger.debug(
                "QwenTtsRealtime does not support `api_key` ctor arg; "
                "falling back to global dashscope.api_key auth"
            )
            return QwenTtsRealtime(**init_kwargs)  # type: ignore[misc]

    async def disconnect(self) -> None:
        """Close the client (if any) and reset the service state.

        The client reference, the event queue and ``state`` are reset even
        when the SDK's ``close`` raises; the exception still propagates.
        """
        self._cancel_event.set()
        try:
            if self._client:
                close_fn = getattr(self._client, "close", None)
                if callable(close_fn):
                    await asyncio.to_thread(close_fn)
        finally:
            self._client = None
            self._drain_event_queue()
            self.state = ServiceState.DISCONNECTED
        logger.info("DashScope realtime TTS service disconnected")

    async def synthesize(self, text: str) -> bytes:
        """Synthesize ``text`` and return all audio as one bytes object."""
        # Collect then join once; repeated `bytes +=` copies the buffer on
        # every chunk.
        parts = [chunk.audio async for chunk in self.synthesize_stream(text)]
        return b"".join(parts)

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """Synthesize ``text`` and yield audio in roughly 100 ms chunks.

        One chunk is always held back so that the last chunk actually emitted
        can be flagged ``is_final=True``. Blank text yields nothing. After
        ``cancel()`` incoming audio is discarded and the generator ends,
        without a final chunk, once ``response.done`` is received.

        Raises:
            RuntimeError: If the service is not connected, the provider sends
                an error event, or the websocket closes mid-response.
            asyncio.TimeoutError: If no event arrives within the wait window
                (20 s normally, 8 s once cancelled).
        """
        if not self._client:
            raise RuntimeError("DashScope TTS service not connected")
        if not text.strip():
            return

        async with self._synthesis_lock:
            self._cancel_event.clear()
            self._drain_event_queue()
            await self._clear_appended_text()
            await asyncio.to_thread(self._client.append_text, text)
            if self.mode == "commit":
                await asyncio.to_thread(self._client.commit)

            # 100ms of 16-bit mono PCM, kept a whole number of 2-byte samples
            # so that a chunk boundary never falls inside a sample.
            chunk_size = max(2, (self.sample_rate // 10) * 2)
            buffer = b""
            pending_chunk: Optional[bytes] = None
            resample_state: Any = None

            while True:
                timeout = 8.0 if self._cancel_event.is_set() else 20.0
                event = await self._next_event(timeout=timeout)
                event_type = str(event.get("type") or "").strip()

                if event_type in {"response.audio.delta", "response.audio.delta.raw"}:
                    if self._cancel_event.is_set():
                        # Keep consuming events until the response ends, but
                        # drop the audio.
                        continue
                    pcm = self._decode_audio_event(event)
                    if not pcm:
                        continue
                    pcm, resample_state = self._resample_if_needed(pcm, resample_state)
                    if not pcm:
                        continue
                    buffer += pcm
                    while len(buffer) >= chunk_size:
                        audio_chunk = buffer[:chunk_size]
                        buffer = buffer[chunk_size:]
                        if pending_chunk is not None:
                            yield TTSChunk(
                                audio=pending_chunk,
                                sample_rate=self.sample_rate,
                                is_final=False,
                            )
                        pending_chunk = audio_chunk
                    continue

                if event_type == "response.done":
                    break
                if event_type == "error":
                    raise RuntimeError(self._format_error_event(event))
                if event_type == "__close__":
                    reason = str(event.get("reason") or "unknown")
                    raise RuntimeError(f"DashScope TTS websocket closed unexpectedly: {reason}")
                # Any other event type is ignored.

            if self._cancel_event.is_set():
                return

            # Flush: the held-back chunk is final only when no partial tail
            # follows it.
            if pending_chunk is not None:
                yield TTSChunk(
                    audio=pending_chunk,
                    sample_rate=self.sample_rate,
                    is_final=not buffer,
                )
            if buffer:
                yield TTSChunk(audio=buffer, sample_rate=self.sample_rate, is_final=True)

    async def cancel(self) -> None:
        """Request that the current synthesis stop producing audio.

        In ``commit`` mode only the appended text is cleared. Otherwise the
        client's ``cancel_response`` (or ``cancel``) method is called when it
        exists; failures of that call are logged and ignored.
        """
        self._cancel_event.set()
        if self.mode == "commit":
            await self._clear_appended_text()
            return
        if not self._client:
            return
        cancel_fn = (
            getattr(self._client, "cancel_response", None)
            or getattr(self._client, "cancel", None)
        )
        if callable(cancel_fn):
            try:
                await asyncio.to_thread(cancel_fn)
            except Exception as exc:
                logger.debug(f"DashScope cancel failed: {exc}")

    async def _wait_for_session_ready(self) -> None:
        """Wait for ``session.updated`` or ``session.open``.

        A timeout (8 s without any event) is logged and tolerated; an error
        event raises ``RuntimeError``. Other events are skipped.
        """
        try:
            while True:
                event = await self._next_event(timeout=8.0)
                event_type = str(event.get("type") or "").strip()
                if event_type in {"session.updated", "session.open"}:
                    return
                if event_type == "error":
                    raise RuntimeError(self._format_error_event(event))
        except asyncio.TimeoutError:
            logger.debug("DashScope session update event timeout; continuing with active websocket")

    async def _clear_appended_text(self) -> None:
        """Clear text buffered on the client; a no-op outside ``commit`` mode.

        Also a no-op when there is no client or the client has no
        ``clear_appended_text`` method. Failures are logged and ignored.
        """
        if self.mode != "commit":
            return
        if not self._client:
            return
        clear_fn = getattr(self._client, "clear_appended_text", None)
        if callable(clear_fn):
            try:
                await asyncio.to_thread(clear_fn)
            except Exception as exc:
                logger.debug(f"DashScope clear_appended_text failed: {exc}")

    async def _next_event(self, timeout: float) -> Dict[str, Any]:
        """Return the next queued event, always as a dict.

        Raises:
            asyncio.TimeoutError: If nothing arrives within ``timeout`` seconds.
        """
        event = await asyncio.wait_for(self._event_queue.get(), timeout=timeout)
        if isinstance(event, dict):
            return event
        return {"type": "raw", "message": str(event)}

    def _drain_event_queue(self) -> None:
        """Discard every event currently waiting in the queue."""
        while True:
            try:
                self._event_queue.get_nowait()
            except asyncio.QueueEmpty:
                break

    def _decode_audio_event(self, event: Dict[str, Any]) -> bytes:
        """Extract PCM bytes from an audio event; ``b""`` when unavailable.

        ``response.audio.delta.raw`` events carry bytes under ``"audio"``.
        Other events carry ``"delta"``, either base64 text or raw bytes.
        """
        event_type = str(event.get("type") or "")
        if event_type == "response.audio.delta.raw":
            audio = event.get("audio")
            if isinstance(audio, (bytes, bytearray)):
                return bytes(audio)
            return b""

        delta = event.get("delta")
        if isinstance(delta, str):
            try:
                return base64.b64decode(delta)
            except Exception as exc:
                logger.warning(f"Failed to decode DashScope audio delta: {exc}")
                return b""
        if isinstance(delta, (bytes, bytearray)):
            return bytes(delta)
        return b""

    def _resample_if_needed(self, pcm: bytes, state: Any) -> Tuple[bytes, Any]:
        """Convert ``pcm`` from the provider rate to ``self.sample_rate``.

        Args:
            pcm: 16-bit mono PCM at ``PROVIDER_SAMPLE_RATE``.
            state: Converter state returned by the previous call, or ``None``
                for the first fragment of a stream.

        Returns:
            ``(converted_pcm, next_state)``. The input is returned unchanged
            when the rates already match, and also (with a warning) when the
            conversion raises.
        """
        if self.sample_rate == self.PROVIDER_SAMPLE_RATE:
            return pcm, state
        try:
            converted, next_state = audioop.ratecv(
                pcm,
                2,  # 16-bit PCM
                1,  # mono
                self.PROVIDER_SAMPLE_RATE,
                self.sample_rate,
                state,
            )
            return converted, next_state
        except Exception as exc:
            logger.warning(f"DashScope audio resample failed: {exc}; returning original sample rate data")
            return pcm, state

    @staticmethod
    def _format_error_event(event: Dict[str, Any]) -> str:
        """Render an error event as ``"code: message"`` or the best available text."""
        err = event.get("error")
        if isinstance(err, dict):
            code = str(err.get("code") or "").strip()
            message = str(err.get("message") or "").strip()
            if code and message:
                return f"{code}: {message}"
            return message or str(err)
        return str(err or "DashScope realtime TTS error")