"""DashScope realtime TTS service. Implements DashScope's Qwen realtime TTS protocol via the official SDK. """ import asyncio import audioop import base64 import json import os from typing import Any, AsyncIterator, Dict, Optional, Tuple from loguru import logger from services.base import BaseTTSService, ServiceState, TTSChunk try: import dashscope from dashscope.audio.qwen_tts_realtime import AudioFormat, QwenTtsRealtime, QwenTtsRealtimeCallback DASHSCOPE_SDK_AVAILABLE = True except ImportError: dashscope = None # type: ignore[assignment] AudioFormat = None # type: ignore[assignment] QwenTtsRealtime = None # type: ignore[assignment] DASHSCOPE_SDK_AVAILABLE = False class QwenTtsRealtimeCallback: # type: ignore[no-redef] """Fallback callback base when DashScope SDK is unavailable.""" pass class _RealtimeEventCallback(QwenTtsRealtimeCallback): """Bridge SDK callback events into an asyncio queue.""" def __init__(self, loop: asyncio.AbstractEventLoop, queue: "asyncio.Queue[Dict[str, Any]]"): super().__init__() self._loop = loop self._queue = queue def _push(self, event: Dict[str, Any]) -> None: try: self._loop.call_soon_threadsafe(self._queue.put_nowait, event) except RuntimeError: return def on_open(self) -> None: self._push({"type": "session.open"}) def on_close(self, code: int, reason: str) -> None: self._push({"type": "__close__", "code": code, "reason": reason}) def on_error(self, message: str) -> None: self._push({"type": "error", "error": {"message": str(message)}}) def on_event(self, event: Any) -> None: if isinstance(event, dict): payload = event elif isinstance(event, str): try: payload = json.loads(event) except json.JSONDecodeError: payload = {"type": "raw", "message": event} else: payload = {"type": "raw", "message": str(event)} self._push(payload) def on_data(self, data: bytes) -> None: # Some SDK versions provide audio via on_data directly. self._push({"type": "response.audio.delta.raw", "audio": data}) class DashScopeTTSService(BaseTTSService): """DashScope realtime TTS service using Qwen Realtime protocol.""" DEFAULT_WS_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" DEFAULT_MODEL = "qwen3-tts-flash-realtime" PROVIDER_SAMPLE_RATE = 24000 def __init__( self, api_key: Optional[str] = None, api_url: Optional[str] = None, voice: str = "Cherry", model: Optional[str] = None, mode: str = "server_commit", sample_rate: int = 16000, speed: float = 1.0, ): super().__init__(voice=voice, sample_rate=sample_rate, speed=speed) self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY") or os.getenv("TTS_API_KEY") self.api_url = ( api_url or os.getenv("DASHSCOPE_TTS_API_URL") or os.getenv("TTS_API_URL") or self.DEFAULT_WS_URL ) self.model = model or os.getenv("DASHSCOPE_TTS_MODEL") or self.DEFAULT_MODEL normalized_mode = str(mode or "").strip().lower() if normalized_mode not in {"server_commit", "commit"}: logger.warning(f"Unknown DashScope mode '{mode}', fallback to server_commit") normalized_mode = "server_commit" self.mode = normalized_mode self._client: Optional[Any] = None self._event_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue() self._callback: Optional[_RealtimeEventCallback] = None self._cancel_event = asyncio.Event() self._synthesis_lock = asyncio.Lock() async def connect(self) -> None: if not DASHSCOPE_SDK_AVAILABLE: raise RuntimeError("dashscope package not installed; install with `pip install dashscope`") if not self.api_key: raise ValueError("DashScope API key not provided. Configure agent.tts.api_key in YAML.") loop = asyncio.get_running_loop() self._callback = _RealtimeEventCallback(loop=loop, queue=self._event_queue) # The official Python SDK docs set key via global `dashscope.api_key`; # some SDK versions do not accept `api_key=` in QwenTtsRealtime ctor. if dashscope is not None: dashscope.api_key = self.api_key self._client = self._create_realtime_client(self._callback) await asyncio.to_thread(self._client.connect) await asyncio.to_thread( self._client.update_session, voice=self.voice, response_format=AudioFormat.PCM_24000HZ_MONO_16BIT, mode=self.mode, ) await self._wait_for_session_ready() self.state = ServiceState.CONNECTED logger.info( "DashScope realtime TTS service ready: " f"voice={self.voice}, model={self.model}, mode={self.mode}" ) def _create_realtime_client(self, callback: _RealtimeEventCallback) -> Any: init_kwargs = { "model": self.model, "callback": callback, "url": self.api_url, } try: return QwenTtsRealtime( # type: ignore[misc] api_key=self.api_key, **init_kwargs, ) except TypeError as exc: if "api_key" not in str(exc): raise logger.debug( "QwenTtsRealtime does not support `api_key` ctor arg; " "falling back to global dashscope.api_key auth" ) return QwenTtsRealtime(**init_kwargs) # type: ignore[misc] async def disconnect(self) -> None: self._cancel_event.set() if self._client: close_fn = getattr(self._client, "close", None) if callable(close_fn): await asyncio.to_thread(close_fn) self._client = None self._drain_event_queue() self.state = ServiceState.DISCONNECTED logger.info("DashScope realtime TTS service disconnected") async def synthesize(self, text: str) -> bytes: audio = b"" async for chunk in self.synthesize_stream(text): audio += chunk.audio return audio async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]: if not self._client: raise RuntimeError("DashScope TTS service not connected") if not text.strip(): return async with self._synthesis_lock: self._cancel_event.clear() self._drain_event_queue() await self._clear_appended_text() await asyncio.to_thread(self._client.append_text, text) if self.mode == "commit": await asyncio.to_thread(self._client.commit) chunk_size = max(1, self.sample_rate * 2 // 10) # 100ms buffer = b"" pending_chunk: Optional[bytes] = None resample_state: Any = None while True: timeout = 8.0 if self._cancel_event.is_set() else 20.0 event = await self._next_event(timeout=timeout) event_type = str(event.get("type") or "").strip() if event_type in {"response.audio.delta", "response.audio.delta.raw"}: if self._cancel_event.is_set(): continue pcm = self._decode_audio_event(event) if not pcm: continue pcm, resample_state = self._resample_if_needed(pcm, resample_state) if not pcm: continue buffer += pcm while len(buffer) >= chunk_size: audio_chunk = buffer[:chunk_size] buffer = buffer[chunk_size:] if pending_chunk is not None: yield TTSChunk( audio=pending_chunk, sample_rate=self.sample_rate, is_final=False, ) pending_chunk = audio_chunk continue if event_type == "response.done": break if event_type == "error": raise RuntimeError(self._format_error_event(event)) if event_type == "__close__": reason = str(event.get("reason") or "unknown") raise RuntimeError(f"DashScope TTS websocket closed unexpectedly: {reason}") if self._cancel_event.is_set(): return if pending_chunk is not None: if buffer: yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False) pending_chunk = None else: yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=True) pending_chunk = None if buffer: yield TTSChunk(audio=buffer, sample_rate=self.sample_rate, is_final=True) async def cancel(self) -> None: self._cancel_event.set() if self.mode == "commit": await self._clear_appended_text() return if not self._client: return cancel_fn = ( getattr(self._client, "cancel_response", None) or getattr(self._client, "cancel", None) ) if callable(cancel_fn): try: await asyncio.to_thread(cancel_fn) except Exception as exc: logger.debug(f"DashScope cancel failed: {exc}") async def _wait_for_session_ready(self) -> None: try: while True: event = await self._next_event(timeout=8.0) event_type = str(event.get("type") or "").strip() if event_type in {"session.updated", "session.open"}: return if event_type == "error": raise RuntimeError(self._format_error_event(event)) except asyncio.TimeoutError: logger.debug("DashScope session update event timeout; continuing with active websocket") async def _clear_appended_text(self) -> None: if self.mode != "commit": return if not self._client: return clear_fn = getattr(self._client, "clear_appended_text", None) if callable(clear_fn): try: await asyncio.to_thread(clear_fn) except Exception as exc: logger.debug(f"DashScope clear_appended_text failed: {exc}") async def _next_event(self, timeout: float) -> Dict[str, Any]: event = await asyncio.wait_for(self._event_queue.get(), timeout=timeout) if isinstance(event, dict): return event return {"type": "raw", "message": str(event)} def _drain_event_queue(self) -> None: while True: try: self._event_queue.get_nowait() except asyncio.QueueEmpty: break def _decode_audio_event(self, event: Dict[str, Any]) -> bytes: event_type = str(event.get("type") or "") if event_type == "response.audio.delta.raw": audio = event.get("audio") if isinstance(audio, (bytes, bytearray)): return bytes(audio) return b"" delta = event.get("delta") if isinstance(delta, str): try: return base64.b64decode(delta) except Exception as exc: logger.warning(f"Failed to decode DashScope audio delta: {exc}") return b"" if isinstance(delta, (bytes, bytearray)): return bytes(delta) return b"" def _resample_if_needed(self, pcm: bytes, state: Any) -> Tuple[bytes, Any]: if self.sample_rate == self.PROVIDER_SAMPLE_RATE: return pcm, state try: converted, next_state = audioop.ratecv( pcm, 2, # 16-bit PCM 1, # mono self.PROVIDER_SAMPLE_RATE, self.sample_rate, state, ) return converted, next_state except Exception as exc: logger.warning(f"DashScope audio resample failed: {exc}; returning original sample rate data") return pcm, state @staticmethod def _format_error_event(event: Dict[str, Any]) -> str: err = event.get("error") if isinstance(err, dict): code = str(err.get("code") or "").strip() message = str(err.get("message") or "").strip() if code and message: return f"{code}: {message}" return message or str(err) return str(err or "DashScope realtime TTS error")