Refactor project structure and enhance backend integration
- Expanded package inclusion in `pyproject.toml` to support new modules. - Introduced new `adapters` and `protocol` packages for better organization. - Added backend adapter implementations for control plane integration. - Updated main application imports to reflect new package structure. - Removed deprecated core components and adjusted documentation accordingly. - Enhanced architecture documentation to clarify the new runtime and integration layers.
This commit is contained in:
352
engine/providers/tts/dashscope.py
Normal file
352
engine/providers/tts/dashscope.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""DashScope realtime TTS service.
|
||||
|
||||
Implements DashScope's Qwen realtime TTS protocol via the official SDK.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import audioop
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from typing import Any, AsyncIterator, Dict, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from providers.common.base import BaseTTSService, ServiceState, TTSChunk
|
||||
|
||||
try:
|
||||
import dashscope
|
||||
from dashscope.audio.qwen_tts_realtime import AudioFormat, QwenTtsRealtime, QwenTtsRealtimeCallback
|
||||
|
||||
DASHSCOPE_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
dashscope = None # type: ignore[assignment]
|
||||
AudioFormat = None # type: ignore[assignment]
|
||||
QwenTtsRealtime = None # type: ignore[assignment]
|
||||
DASHSCOPE_SDK_AVAILABLE = False
|
||||
|
||||
class QwenTtsRealtimeCallback: # type: ignore[no-redef]
|
||||
"""Fallback callback base when DashScope SDK is unavailable."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class _RealtimeEventCallback(QwenTtsRealtimeCallback):
|
||||
"""Bridge SDK callback events into an asyncio queue."""
|
||||
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop, queue: "asyncio.Queue[Dict[str, Any]]"):
|
||||
super().__init__()
|
||||
self._loop = loop
|
||||
self._queue = queue
|
||||
|
||||
def _push(self, event: Dict[str, Any]) -> None:
|
||||
try:
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
|
||||
except RuntimeError:
|
||||
return
|
||||
|
||||
def on_open(self) -> None:
|
||||
self._push({"type": "session.open"})
|
||||
|
||||
def on_close(self, code: int, reason: str) -> None:
|
||||
self._push({"type": "__close__", "code": code, "reason": reason})
|
||||
|
||||
def on_error(self, message: str) -> None:
|
||||
self._push({"type": "error", "error": {"message": str(message)}})
|
||||
|
||||
def on_event(self, event: Any) -> None:
|
||||
if isinstance(event, dict):
|
||||
payload = event
|
||||
elif isinstance(event, str):
|
||||
try:
|
||||
payload = json.loads(event)
|
||||
except json.JSONDecodeError:
|
||||
payload = {"type": "raw", "message": event}
|
||||
else:
|
||||
payload = {"type": "raw", "message": str(event)}
|
||||
self._push(payload)
|
||||
|
||||
def on_data(self, data: bytes) -> None:
|
||||
# Some SDK versions provide audio via on_data directly.
|
||||
self._push({"type": "response.audio.delta.raw", "audio": data})
|
||||
|
||||
|
||||
class DashScopeTTSService(BaseTTSService):
|
||||
"""DashScope realtime TTS service using Qwen Realtime protocol."""
|
||||
|
||||
DEFAULT_WS_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
|
||||
DEFAULT_MODEL = "qwen3-tts-flash-realtime"
|
||||
PROVIDER_SAMPLE_RATE = 24000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
voice: str = "Cherry",
|
||||
model: Optional[str] = None,
|
||||
mode: str = "commit",
|
||||
sample_rate: int = 16000,
|
||||
speed: float = 1.0,
|
||||
):
|
||||
super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
|
||||
self.api_key = api_key
|
||||
self.api_url = (
|
||||
api_url
|
||||
or os.getenv("DASHSCOPE_TTS_API_URL")
|
||||
or os.getenv("TTS_API_URL")
|
||||
or self.DEFAULT_WS_URL
|
||||
)
|
||||
self.model = model or os.getenv("DASHSCOPE_TTS_MODEL") or self.DEFAULT_MODEL
|
||||
|
||||
normalized_mode = str(mode or "").strip().lower()
|
||||
if normalized_mode not in {"server_commit", "commit"}:
|
||||
logger.warning(f"Unknown DashScope mode '{mode}', fallback to server_commit")
|
||||
normalized_mode = "server_commit"
|
||||
self.mode = normalized_mode
|
||||
|
||||
self._client: Optional[Any] = None
|
||||
self._event_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
|
||||
self._callback: Optional[_RealtimeEventCallback] = None
|
||||
self._cancel_event = asyncio.Event()
|
||||
self._synthesis_lock = asyncio.Lock()
|
||||
|
||||
async def connect(self) -> None:
|
||||
if not DASHSCOPE_SDK_AVAILABLE:
|
||||
raise RuntimeError("dashscope package not installed; install with `pip install dashscope`")
|
||||
if not self.api_key:
|
||||
raise ValueError("DashScope API key not provided. Configure agent.tts.api_key in YAML.")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
self._callback = _RealtimeEventCallback(loop=loop, queue=self._event_queue)
|
||||
# The official Python SDK docs set key via global `dashscope.api_key`;
|
||||
# some SDK versions do not accept `api_key=` in QwenTtsRealtime ctor.
|
||||
if dashscope is not None:
|
||||
dashscope.api_key = self.api_key
|
||||
self._client = self._create_realtime_client(self._callback)
|
||||
|
||||
await asyncio.to_thread(self._client.connect)
|
||||
await asyncio.to_thread(
|
||||
self._client.update_session,
|
||||
voice=self.voice,
|
||||
response_format=AudioFormat.PCM_24000HZ_MONO_16BIT,
|
||||
mode=self.mode,
|
||||
)
|
||||
await self._wait_for_session_ready()
|
||||
|
||||
self.state = ServiceState.CONNECTED
|
||||
logger.info(
|
||||
"DashScope realtime TTS service ready: "
|
||||
f"voice={self.voice}, model={self.model}, mode={self.mode}"
|
||||
)
|
||||
|
||||
def _create_realtime_client(self, callback: _RealtimeEventCallback) -> Any:
|
||||
init_kwargs = {
|
||||
"model": self.model,
|
||||
"callback": callback,
|
||||
"url": self.api_url,
|
||||
}
|
||||
try:
|
||||
return QwenTtsRealtime( # type: ignore[misc]
|
||||
api_key=self.api_key,
|
||||
**init_kwargs,
|
||||
)
|
||||
except TypeError as exc:
|
||||
if "api_key" not in str(exc):
|
||||
raise
|
||||
logger.debug(
|
||||
"QwenTtsRealtime does not support `api_key` ctor arg; "
|
||||
"falling back to global dashscope.api_key auth"
|
||||
)
|
||||
return QwenTtsRealtime(**init_kwargs) # type: ignore[misc]
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._cancel_event.set()
|
||||
if self._client:
|
||||
close_fn = getattr(self._client, "close", None)
|
||||
if callable(close_fn):
|
||||
await asyncio.to_thread(close_fn)
|
||||
self._client = None
|
||||
self._drain_event_queue()
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("DashScope realtime TTS service disconnected")
|
||||
|
||||
async def synthesize(self, text: str) -> bytes:
|
||||
audio = b""
|
||||
async for chunk in self.synthesize_stream(text):
|
||||
audio += chunk.audio
|
||||
return audio
|
||||
|
||||
async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
|
||||
if not self._client:
|
||||
raise RuntimeError("DashScope TTS service not connected")
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
async with self._synthesis_lock:
|
||||
self._cancel_event.clear()
|
||||
self._drain_event_queue()
|
||||
|
||||
await self._clear_appended_text()
|
||||
await asyncio.to_thread(self._client.append_text, text)
|
||||
if self.mode == "commit":
|
||||
await asyncio.to_thread(self._client.commit)
|
||||
|
||||
chunk_size = max(1, self.sample_rate * 2 // 10) # 100ms
|
||||
buffer = b""
|
||||
pending_chunk: Optional[bytes] = None
|
||||
resample_state: Any = None
|
||||
|
||||
while True:
|
||||
timeout = 8.0 if self._cancel_event.is_set() else 20.0
|
||||
event = await self._next_event(timeout=timeout)
|
||||
event_type = str(event.get("type") or "").strip()
|
||||
|
||||
if event_type in {"response.audio.delta", "response.audio.delta.raw"}:
|
||||
if self._cancel_event.is_set():
|
||||
continue
|
||||
|
||||
pcm = self._decode_audio_event(event)
|
||||
if not pcm:
|
||||
continue
|
||||
pcm, resample_state = self._resample_if_needed(pcm, resample_state)
|
||||
if not pcm:
|
||||
continue
|
||||
|
||||
buffer += pcm
|
||||
while len(buffer) >= chunk_size:
|
||||
audio_chunk = buffer[:chunk_size]
|
||||
buffer = buffer[chunk_size:]
|
||||
if pending_chunk is not None:
|
||||
yield TTSChunk(
|
||||
audio=pending_chunk,
|
||||
sample_rate=self.sample_rate,
|
||||
is_final=False,
|
||||
)
|
||||
pending_chunk = audio_chunk
|
||||
continue
|
||||
|
||||
if event_type == "response.done":
|
||||
break
|
||||
|
||||
if event_type == "error":
|
||||
raise RuntimeError(self._format_error_event(event))
|
||||
|
||||
if event_type == "__close__":
|
||||
reason = str(event.get("reason") or "unknown")
|
||||
raise RuntimeError(f"DashScope TTS websocket closed unexpectedly: {reason}")
|
||||
|
||||
if self._cancel_event.is_set():
|
||||
return
|
||||
|
||||
if pending_chunk is not None:
|
||||
if buffer:
|
||||
yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False)
|
||||
pending_chunk = None
|
||||
else:
|
||||
yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=True)
|
||||
pending_chunk = None
|
||||
|
||||
if buffer:
|
||||
yield TTSChunk(audio=buffer, sample_rate=self.sample_rate, is_final=True)
|
||||
|
||||
async def cancel(self) -> None:
|
||||
self._cancel_event.set()
|
||||
if self.mode == "commit":
|
||||
await self._clear_appended_text()
|
||||
return
|
||||
|
||||
if not self._client:
|
||||
return
|
||||
cancel_fn = (
|
||||
getattr(self._client, "cancel_response", None)
|
||||
or getattr(self._client, "cancel", None)
|
||||
)
|
||||
if callable(cancel_fn):
|
||||
try:
|
||||
await asyncio.to_thread(cancel_fn)
|
||||
except Exception as exc:
|
||||
logger.debug(f"DashScope cancel failed: {exc}")
|
||||
|
||||
async def _wait_for_session_ready(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
event = await self._next_event(timeout=8.0)
|
||||
event_type = str(event.get("type") or "").strip()
|
||||
if event_type in {"session.updated", "session.open"}:
|
||||
return
|
||||
if event_type == "error":
|
||||
raise RuntimeError(self._format_error_event(event))
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug("DashScope session update event timeout; continuing with active websocket")
|
||||
|
||||
async def _clear_appended_text(self) -> None:
|
||||
if self.mode != "commit":
|
||||
return
|
||||
if not self._client:
|
||||
return
|
||||
clear_fn = getattr(self._client, "clear_appended_text", None)
|
||||
if callable(clear_fn):
|
||||
try:
|
||||
await asyncio.to_thread(clear_fn)
|
||||
except Exception as exc:
|
||||
logger.debug(f"DashScope clear_appended_text failed: {exc}")
|
||||
|
||||
async def _next_event(self, timeout: float) -> Dict[str, Any]:
|
||||
event = await asyncio.wait_for(self._event_queue.get(), timeout=timeout)
|
||||
if isinstance(event, dict):
|
||||
return event
|
||||
return {"type": "raw", "message": str(event)}
|
||||
|
||||
def _drain_event_queue(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
self._event_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
def _decode_audio_event(self, event: Dict[str, Any]) -> bytes:
|
||||
event_type = str(event.get("type") or "")
|
||||
if event_type == "response.audio.delta.raw":
|
||||
audio = event.get("audio")
|
||||
if isinstance(audio, (bytes, bytearray)):
|
||||
return bytes(audio)
|
||||
return b""
|
||||
|
||||
delta = event.get("delta")
|
||||
if isinstance(delta, str):
|
||||
try:
|
||||
return base64.b64decode(delta)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to decode DashScope audio delta: {exc}")
|
||||
return b""
|
||||
if isinstance(delta, (bytes, bytearray)):
|
||||
return bytes(delta)
|
||||
return b""
|
||||
|
||||
def _resample_if_needed(self, pcm: bytes, state: Any) -> Tuple[bytes, Any]:
|
||||
if self.sample_rate == self.PROVIDER_SAMPLE_RATE:
|
||||
return pcm, state
|
||||
try:
|
||||
converted, next_state = audioop.ratecv(
|
||||
pcm,
|
||||
2, # 16-bit PCM
|
||||
1, # mono
|
||||
self.PROVIDER_SAMPLE_RATE,
|
||||
self.sample_rate,
|
||||
state,
|
||||
)
|
||||
return converted, next_state
|
||||
except Exception as exc:
|
||||
logger.warning(f"DashScope audio resample failed: {exc}; returning original sample rate data")
|
||||
return pcm, state
|
||||
|
||||
@staticmethod
|
||||
def _format_error_event(event: Dict[str, Any]) -> str:
|
||||
err = event.get("error")
|
||||
if isinstance(err, dict):
|
||||
code = str(err.get("code") or "").strip()
|
||||
message = str(err.get("message") or "").strip()
|
||||
if code and message:
|
||||
return f"{code}: {message}"
|
||||
return message or str(err)
|
||||
return str(err or "DashScope realtime TTS error")
|
||||
Reference in New Issue
Block a user