diff --git a/engine/README.md b/engine/README.md index 3353270..9d0949b 100644 --- a/engine/README.md +++ b/engine/README.md @@ -37,6 +37,10 @@ Agent 配置路径优先级 - Agent 相关配置是严格模式:YAML 缺少必须项会直接报错,不会回退到 `.env` 或代码默认值。 - 如果要引用环境变量,请在 YAML 显式写 `${ENV_VAR}`。 - `siliconflow` 独立 section 已移除;请在 `agent.llm / agent.tts / agent.asr` 内通过 `provider`、`api_key`、`api_url`、`model` 配置。 +- `agent.tts.provider` 现支持 `dashscope`(Realtime 协议,非 OpenAI-compatible);默认 URL 为 `wss://dashscope.aliyuncs.com/api-ws/v1/realtime`,默认模型为 `qwen3-tts-flash-realtime`。 +- `agent.tts.dashscope_mode`(兼容旧写法 `agent.tts.mode`)支持 `commit | server_commit`,且仅在 `provider=dashscope` 时生效: + - `commit`:Engine 先按句切分,再逐句提交给 DashScope。 + - `server_commit`:Engine 不再逐句切分,由 DashScope 对整段文本自行切分。 - 现在支持在 Agent YAML 中配置 `agent.tools`(列表),用于声明运行时可调用工具。 - 工具配置示例见 `config/agents/tools.yaml`。 diff --git a/engine/app/config.py b/engine/app/config.py index 2d1a680..a441359 100644 --- a/engine/app/config.py +++ b/engine/app/config.py @@ -41,6 +41,8 @@ _AGENT_SECTION_KEY_MAP: Dict[str, Dict[str, str]] = { "api_url": "tts_api_url", "model": "tts_model", "voice": "tts_voice", + "dashscope_mode": "tts_mode", + "mode": "tts_mode", "speed": "tts_speed", }, "asr": { @@ -80,6 +82,7 @@ _AGENT_SETTING_KEYS = { "tts_api_url", "tts_model", "tts_voice", + "tts_mode", "tts_speed", "asr_provider", "asr_api_key", @@ -120,7 +123,10 @@ _BASE_REQUIRED_AGENT_SETTING_KEYS = { "barge_in_min_duration_ms", "barge_in_silence_tolerance_ms", } -_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"} +_OPENAI_COMPATIBLE_LLM_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"} +_OPENAI_COMPATIBLE_TTS_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"} +_DASHSCOPE_TTS_PROVIDERS = {"dashscope"} +_OPENAI_COMPATIBLE_ASR_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"} def _normalized_provider(overrides: Dict[str, Any], key: str, default: str) -> str: @@ -285,21 +291,24 @@ def _missing_required_keys(overrides: Dict[str, Any]) -> List[str]: missing.add(key) llm_provider = _normalized_provider(overrides, "llm_provider", "openai") - if llm_provider in _OPENAI_COMPATIBLE_PROVIDERS or llm_provider == "openai": + if llm_provider in _OPENAI_COMPATIBLE_LLM_PROVIDERS or llm_provider == "openai": if "llm_api_key" not in overrides or _is_blank(overrides.get("llm_api_key")): missing.add("llm_api_key") tts_provider = _normalized_provider(overrides, "tts_provider", "openai_compatible") - if tts_provider in _OPENAI_COMPATIBLE_PROVIDERS: + if tts_provider in _OPENAI_COMPATIBLE_TTS_PROVIDERS: if "tts_api_key" not in overrides or _is_blank(overrides.get("tts_api_key")): missing.add("tts_api_key") if "tts_api_url" not in overrides or _is_blank(overrides.get("tts_api_url")): missing.add("tts_api_url") if "tts_model" not in overrides or _is_blank(overrides.get("tts_model")): missing.add("tts_model") + elif tts_provider in _DASHSCOPE_TTS_PROVIDERS: + if "tts_api_key" not in overrides or _is_blank(overrides.get("tts_api_key")): + missing.add("tts_api_key") asr_provider = _normalized_provider(overrides, "asr_provider", "openai_compatible") - if asr_provider in _OPENAI_COMPATIBLE_PROVIDERS: + if asr_provider in _OPENAI_COMPATIBLE_ASR_PROVIDERS: if "asr_api_key" not in overrides or _is_blank(overrides.get("asr_api_key")): missing.add("asr_api_key") if "asr_api_url" not in overrides or _is_blank(overrides.get("asr_api_url")): @@ -401,12 +410,16 @@ class Settings(BaseSettings): # TTS Configuration tts_provider: str = Field( default="openai_compatible", - description="TTS provider (edge, openai_compatible, siliconflow)" + description="TTS provider (edge, openai_compatible, siliconflow, dashscope)" ) tts_api_key: Optional[str] = Field(default=None, description="TTS provider API key") tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL") tts_model: Optional[str] = Field(default=None, description="TTS model name") tts_voice: str = Field(default="anna", description="TTS voice name") + tts_mode: str = Field( + default="commit", + description="DashScope-only TTS mode (commit, server_commit). Ignored for non-dashscope providers." + ) tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier") # ASR Configuration diff --git a/engine/config/agents/example.yaml b/engine/config/agents/example.yaml index 114830e..dd0e927 100644 --- a/engine/config/agents/example.yaml +++ b/engine/config/agents/example.yaml @@ -21,7 +21,12 @@ agent: api_url: https://api.qnaigc.com/v1 tts: - # provider: edge | openai_compatible | siliconflow + # provider: edge | openai_compatible | siliconflow | dashscope + # dashscope defaults (if omitted): + # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime + # model: qwen3-tts-flash-realtime + # dashscope_mode: commit (engine splits) | server_commit (dashscope splits) + # note: dashscope_mode/mode is ONLY used when provider=dashscope. provider: openai_compatible api_key: your_tts_api_key api_url: https://api.siliconflow.cn/v1/audio/speech diff --git a/engine/config/agents/tools.yaml b/engine/config/agents/tools.yaml index 9734bff..4d8bd72 100644 --- a/engine/config/agents/tools.yaml +++ b/engine/config/agents/tools.yaml @@ -18,7 +18,12 @@ agent: api_url: https://api.qnaigc.com/v1 tts: - # provider: edge | openai_compatible | siliconflow + # provider: edge | openai_compatible | siliconflow | dashscope + # dashscope defaults (if omitted): + # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime + # model: qwen3-tts-flash-realtime + # dashscope_mode: commit (engine splits) | server_commit (dashscope splits) + # note: dashscope_mode/mode is ONLY used when provider=dashscope. provider: openai_compatible api_key: your_tts_api_key api_url: https://api.siliconflow.cn/v1/audio/speech diff --git a/engine/core/duplex_pipeline.py b/engine/core/duplex_pipeline.py index 5722cea..1a10d9e 100644 --- a/engine/core/duplex_pipeline.py +++ b/engine/core/duplex_pipeline.py @@ -30,6 +30,7 @@ from processors.eou import EouDetector from processors.vad import SileroVAD, VADProcessor from services.asr import BufferedASRService from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage, LLMStreamEvent +from services.dashscope_tts import DashScopeTTSService from services.llm import MockLLMService, OpenAILLMService from services.openai_compatible_asr import OpenAICompatibleASRService from services.openai_compatible_tts import OpenAICompatibleTTSService @@ -349,6 +350,21 @@ class DuplexPipeline: if not output_mode: output_mode = "audio" if self._tts_output_enabled() else "text" + tts_model = str( + self._runtime_tts.get("model") + or settings.tts_model + or (self._default_dashscope_tts_model() if self._is_dashscope_tts_provider(tts_provider) else "") + ) + tts_config = { + "enabled": self._tts_output_enabled(), + "provider": tts_provider, + "model": tts_model, + "voice": str(self._runtime_tts.get("voice") or settings.tts_voice), + "speed": float(self._runtime_tts.get("speed") or settings.tts_speed), + } + if self._is_dashscope_tts_provider(tts_provider): + tts_config["mode"] = self._resolved_dashscope_tts_mode() + return { "output": {"mode": output_mode}, "services": { @@ -363,13 +379,7 @@ class DuplexPipeline: "interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms), "minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms), }, - "tts": { - "enabled": self._tts_output_enabled(), - "provider": tts_provider, - "model": str(self._runtime_tts.get("model") or settings.tts_model or ""), - "voice": str(self._runtime_tts.get("voice") or settings.tts_voice), - "speed": float(self._runtime_tts.get("speed") or settings.tts_speed), - }, + "tts": tts_config, }, "tools": { "allowlist": self._resolved_tool_allowlist(), @@ -484,6 +494,11 @@ class DuplexPipeline: normalized = str(provider or "").strip().lower() return normalized in {"openai_compatible", "openai-compatible", "siliconflow"} + @staticmethod + def _is_dashscope_tts_provider(provider: Any) -> bool: + normalized = str(provider or "").strip().lower() + return normalized == "dashscope" + @staticmethod def _is_llm_provider_supported(provider: Any) -> bool: normalized = str(provider or "").strip().lower() @@ -496,6 +511,28 @@ class DuplexPipeline: return "https://api.siliconflow.cn/v1" return None + @staticmethod + def _default_dashscope_tts_realtime_url() -> str: + return "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + + @staticmethod + def _default_dashscope_tts_model() -> str: + return "qwen3-tts-flash-realtime" + + def _resolved_dashscope_tts_mode(self) -> str: + raw_mode = str(self._runtime_tts.get("mode") or settings.tts_mode or "commit").strip().lower() + if raw_mode in {"commit", "server_commit"}: + return raw_mode + return "commit" + + def _use_engine_sentence_split_for_tts(self) -> bool: + tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).strip().lower() + if not self._is_dashscope_tts_provider(tts_provider): + return True + # DashScope commit mode is client-driven and expects engine-side segmentation. + # server_commit mode lets DashScope handle segmentation on appended text. + return self._resolved_dashscope_tts_mode() != "server_commit" + def _tts_output_enabled(self) -> bool: enabled = self._coerce_bool(self._runtime_tts.get("enabled")) if enabled is not None: @@ -610,8 +647,26 @@ class DuplexPipeline: tts_voice = self._runtime_tts.get("voice") or settings.tts_voice tts_model = self._runtime_tts.get("model") or settings.tts_model tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed) + tts_mode = self._resolved_dashscope_tts_mode() + runtime_mode = str(self._runtime_tts.get("mode") or "").strip() + if runtime_mode and not self._is_dashscope_tts_provider(tts_provider): + logger.warning( + "services.tts.mode is DashScope-only and will be ignored " + f"for provider={tts_provider}" + ) - if self._is_openai_compatible_provider(tts_provider) and tts_api_key: + if self._is_dashscope_tts_provider(tts_provider) and tts_api_key: + self.tts_service = DashScopeTTSService( + api_key=tts_api_key, + api_url=tts_api_url or self._default_dashscope_tts_realtime_url(), + voice=tts_voice, + model=tts_model or self._default_dashscope_tts_model(), + mode=str(tts_mode), + sample_rate=settings.sample_rate, + speed=tts_speed + ) + logger.info("Using DashScope realtime TTS service") + elif self._is_openai_compatible_provider(tts_provider) and tts_api_key: self.tts_service = OpenAICompatibleTTSService( api_key=tts_api_key, api_url=tts_api_url, @@ -1379,6 +1434,7 @@ class DuplexPipeline: round_response = "" tool_calls: List[Dict[str, Any]] = [] allow_text_output = True + use_engine_sentence_split = self._use_engine_sentence_split_for_tts() async for raw_event in self.llm_service.generate_stream(messages): if self._interrupt_event.is_set(): @@ -1446,52 +1502,56 @@ class DuplexPipeline: ): await self._flush_pending_llm_delta() - while True: - split_result = extract_tts_sentence( - sentence_buffer, - end_chars=self._SENTENCE_END_CHARS, - trailing_chars=self._SENTENCE_TRAILING_CHARS, - closers=self._SENTENCE_CLOSERS, - min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS, - hold_trailing_at_buffer_end=True, - force=False, - ) - if not split_result: - break - sentence, sentence_buffer = split_result - if not sentence: - continue - - sentence = f"{pending_punctuation}{sentence}".strip() - pending_punctuation = "" - if not sentence: - continue - - if not has_spoken_content(sentence): - pending_punctuation = sentence - continue - - if self._tts_output_enabled() and not self._interrupt_event.is_set(): - if not first_audio_sent: - self._start_tts() - await self._send_event( - { - **ev( - "output.audio.start", - trackId=self.track_audio_out, - ) - }, - priority=10, - ) - first_audio_sent = True - - await self._speak_sentence( - sentence, - fade_in_ms=0, - fade_out_ms=8, + if use_engine_sentence_split: + while True: + split_result = extract_tts_sentence( + sentence_buffer, + end_chars=self._SENTENCE_END_CHARS, + trailing_chars=self._SENTENCE_TRAILING_CHARS, + closers=self._SENTENCE_CLOSERS, + min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS, + hold_trailing_at_buffer_end=True, + force=False, ) + if not split_result: + break + sentence, sentence_buffer = split_result + if not sentence: + continue - remaining_text = f"{pending_punctuation}{sentence_buffer}".strip() + sentence = f"{pending_punctuation}{sentence}".strip() + pending_punctuation = "" + if not sentence: + continue + + if not has_spoken_content(sentence): + pending_punctuation = sentence + continue + + if self._tts_output_enabled() and not self._interrupt_event.is_set(): + if not first_audio_sent: + self._start_tts() + await self._send_event( + { + **ev( + "output.audio.start", + trackId=self.track_audio_out, + ) + }, + priority=10, + ) + first_audio_sent = True + + await self._speak_sentence( + sentence, + fade_in_ms=0, + fade_out_ms=8, + ) + + if use_engine_sentence_split: + remaining_text = f"{pending_punctuation}{sentence_buffer}".strip() + else: + remaining_text = sentence_buffer.strip() await self._flush_pending_llm_delta() if ( self._tts_output_enabled() diff --git a/engine/requirements.txt b/engine/requirements.txt index d117414..a32b7d2 100644 --- a/engine/requirements.txt +++ b/engine/requirements.txt @@ -27,6 +27,7 @@ aiohttp>=3.9.1 # AI Services - LLM openai>=1.0.0 +dashscope>=1.25.11 # AI Services - TTS edge-tts>=6.1.0 diff --git a/engine/services/__init__.py b/engine/services/__init__.py index 0bab6b3..0e46834 100644 --- a/engine/services/__init__.py +++ b/engine/services/__init__.py @@ -13,6 +13,7 @@ from services.base import ( BaseTTSService, ) from services.llm import OpenAILLMService, MockLLMService +from services.dashscope_tts import DashScopeTTSService from services.tts import EdgeTTSService, MockTTSService from services.asr import BufferedASRService, MockASRService from services.openai_compatible_asr import OpenAICompatibleASRService, SiliconFlowASRService @@ -33,6 +34,7 @@ __all__ = [ "OpenAILLMService", "MockLLMService", # TTS + "DashScopeTTSService", "EdgeTTSService", "MockTTSService", # ASR diff --git a/engine/services/dashscope_tts.py b/engine/services/dashscope_tts.py new file mode 100644 index 0000000..ef01f3b --- /dev/null +++ b/engine/services/dashscope_tts.py @@ -0,0 +1,352 @@ +"""DashScope realtime TTS service. + +Implements DashScope's Qwen realtime TTS protocol via the official SDK. +""" + +import asyncio +import audioop +import base64 +import json +import os +from typing import Any, AsyncIterator, Dict, Optional, Tuple + +from loguru import logger + +from services.base import BaseTTSService, ServiceState, TTSChunk + +try: + import dashscope + from dashscope.audio.qwen_tts_realtime import AudioFormat, QwenTtsRealtime, QwenTtsRealtimeCallback + + DASHSCOPE_SDK_AVAILABLE = True +except ImportError: + dashscope = None # type: ignore[assignment] + AudioFormat = None # type: ignore[assignment] + QwenTtsRealtime = None # type: ignore[assignment] + DASHSCOPE_SDK_AVAILABLE = False + + class QwenTtsRealtimeCallback: # type: ignore[no-redef] + """Fallback callback base when DashScope SDK is unavailable.""" + + pass + + +class _RealtimeEventCallback(QwenTtsRealtimeCallback): + """Bridge SDK callback events into an asyncio queue.""" + + def __init__(self, loop: asyncio.AbstractEventLoop, queue: "asyncio.Queue[Dict[str, Any]]"): + super().__init__() + self._loop = loop + self._queue = queue + + def _push(self, event: Dict[str, Any]) -> None: + try: + self._loop.call_soon_threadsafe(self._queue.put_nowait, event) + except RuntimeError: + return + + def on_open(self) -> None: + self._push({"type": "session.open"}) + + def on_close(self, code: int, reason: str) -> None: + self._push({"type": "__close__", "code": code, "reason": reason}) + + def on_error(self, message: str) -> None: + self._push({"type": "error", "error": {"message": str(message)}}) + + def on_event(self, event: Any) -> None: + if isinstance(event, dict): + payload = event + elif isinstance(event, str): + try: + payload = json.loads(event) + except json.JSONDecodeError: + payload = {"type": "raw", "message": event} + else: + payload = {"type": "raw", "message": str(event)} + self._push(payload) + + def on_data(self, data: bytes) -> None: + # Some SDK versions provide audio via on_data directly. + self._push({"type": "response.audio.delta.raw", "audio": data}) + + +class DashScopeTTSService(BaseTTSService): + """DashScope realtime TTS service using Qwen Realtime protocol.""" + + DEFAULT_WS_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + DEFAULT_MODEL = "qwen3-tts-flash-realtime" + PROVIDER_SAMPLE_RATE = 24000 + + def __init__( + self, + api_key: Optional[str] = None, + api_url: Optional[str] = None, + voice: str = "Cherry", + model: Optional[str] = None, + mode: str = "server_commit", + sample_rate: int = 16000, + speed: float = 1.0, + ): + super().__init__(voice=voice, sample_rate=sample_rate, speed=speed) + self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY") or os.getenv("TTS_API_KEY") + self.api_url = ( + api_url + or os.getenv("DASHSCOPE_TTS_API_URL") + or os.getenv("TTS_API_URL") + or self.DEFAULT_WS_URL + ) + self.model = model or os.getenv("DASHSCOPE_TTS_MODEL") or self.DEFAULT_MODEL + + normalized_mode = str(mode or "").strip().lower() + if normalized_mode not in {"server_commit", "commit"}: + logger.warning(f"Unknown DashScope mode '{mode}', fallback to server_commit") + normalized_mode = "server_commit" + self.mode = normalized_mode + + self._client: Optional[Any] = None + self._event_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue() + self._callback: Optional[_RealtimeEventCallback] = None + self._cancel_event = asyncio.Event() + self._synthesis_lock = asyncio.Lock() + + async def connect(self) -> None: + if not DASHSCOPE_SDK_AVAILABLE: + raise RuntimeError("dashscope package not installed; install with `pip install dashscope`") + if not self.api_key: + raise ValueError("DashScope API key not provided. Configure agent.tts.api_key in YAML.") + + loop = asyncio.get_running_loop() + self._callback = _RealtimeEventCallback(loop=loop, queue=self._event_queue) + # The official Python SDK docs set key via global `dashscope.api_key`; + # some SDK versions do not accept `api_key=` in QwenTtsRealtime ctor. + if dashscope is not None: + dashscope.api_key = self.api_key + self._client = self._create_realtime_client(self._callback) + + await asyncio.to_thread(self._client.connect) + await asyncio.to_thread( + self._client.update_session, + voice=self.voice, + response_format=AudioFormat.PCM_24000HZ_MONO_16BIT, + mode=self.mode, + ) + await self._wait_for_session_ready() + + self.state = ServiceState.CONNECTED + logger.info( + "DashScope realtime TTS service ready: " + f"voice={self.voice}, model={self.model}, mode={self.mode}" + ) + + def _create_realtime_client(self, callback: _RealtimeEventCallback) -> Any: + init_kwargs = { + "model": self.model, + "callback": callback, + "url": self.api_url, + } + try: + return QwenTtsRealtime( # type: ignore[misc] + api_key=self.api_key, + **init_kwargs, + ) + except TypeError as exc: + if "api_key" not in str(exc): + raise + logger.debug( + "QwenTtsRealtime does not support `api_key` ctor arg; " + "falling back to global dashscope.api_key auth" + ) + return QwenTtsRealtime(**init_kwargs) # type: ignore[misc] + + async def disconnect(self) -> None: + self._cancel_event.set() + if self._client: + close_fn = getattr(self._client, "close", None) + if callable(close_fn): + await asyncio.to_thread(close_fn) + self._client = None + self._drain_event_queue() + self.state = ServiceState.DISCONNECTED + logger.info("DashScope realtime TTS service disconnected") + + async def synthesize(self, text: str) -> bytes: + audio = b"" + async for chunk in self.synthesize_stream(text): + audio += chunk.audio + return audio + + async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]: + if not self._client: + raise RuntimeError("DashScope TTS service not connected") + if not text.strip(): + return + + async with self._synthesis_lock: + self._cancel_event.clear() + self._drain_event_queue() + + await self._clear_appended_text() + await asyncio.to_thread(self._client.append_text, text) + if self.mode == "commit": + await asyncio.to_thread(self._client.commit) + + chunk_size = max(1, self.sample_rate * 2 // 10) # 100ms + buffer = b"" + pending_chunk: Optional[bytes] = None + resample_state: Any = None + + while True: + timeout = 8.0 if self._cancel_event.is_set() else 20.0 + event = await self._next_event(timeout=timeout) + event_type = str(event.get("type") or "").strip() + + if event_type in {"response.audio.delta", "response.audio.delta.raw"}: + if self._cancel_event.is_set(): + continue + + pcm = self._decode_audio_event(event) + if not pcm: + continue + pcm, resample_state = self._resample_if_needed(pcm, resample_state) + if not pcm: + continue + + buffer += pcm + while len(buffer) >= chunk_size: + audio_chunk = buffer[:chunk_size] + buffer = buffer[chunk_size:] + if pending_chunk is not None: + yield TTSChunk( + audio=pending_chunk, + sample_rate=self.sample_rate, + is_final=False, + ) + pending_chunk = audio_chunk + continue + + if event_type == "response.done": + break + + if event_type == "error": + raise RuntimeError(self._format_error_event(event)) + + if event_type == "__close__": + reason = str(event.get("reason") or "unknown") + raise RuntimeError(f"DashScope TTS websocket closed unexpectedly: {reason}") + + if self._cancel_event.is_set(): + return + + if pending_chunk is not None: + if buffer: + yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False) + pending_chunk = None + else: + yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=True) + pending_chunk = None + + if buffer: + yield TTSChunk(audio=buffer, sample_rate=self.sample_rate, is_final=True) + + async def cancel(self) -> None: + self._cancel_event.set() + if self.mode == "commit": + await self._clear_appended_text() + return + + if not self._client: + return + cancel_fn = ( + getattr(self._client, "cancel_response", None) + or getattr(self._client, "cancel", None) + ) + if callable(cancel_fn): + try: + await asyncio.to_thread(cancel_fn) + except Exception as exc: + logger.debug(f"DashScope cancel failed: {exc}") + + async def _wait_for_session_ready(self) -> None: + try: + while True: + event = await self._next_event(timeout=8.0) + event_type = str(event.get("type") or "").strip() + if event_type in {"session.updated", "session.open"}: + return + if event_type == "error": + raise RuntimeError(self._format_error_event(event)) + except asyncio.TimeoutError: + logger.debug("DashScope session update event timeout; continuing with active websocket") + + async def _clear_appended_text(self) -> None: + if self.mode != "commit": + return + if not self._client: + return + clear_fn = getattr(self._client, "clear_appended_text", None) + if callable(clear_fn): + try: + await asyncio.to_thread(clear_fn) + except Exception as exc: + logger.debug(f"DashScope clear_appended_text failed: {exc}") + + async def _next_event(self, timeout: float) -> Dict[str, Any]: + event = await asyncio.wait_for(self._event_queue.get(), timeout=timeout) + if isinstance(event, dict): + return event + return {"type": "raw", "message": str(event)} + + def _drain_event_queue(self) -> None: + while True: + try: + self._event_queue.get_nowait() + except asyncio.QueueEmpty: + break + + def _decode_audio_event(self, event: Dict[str, Any]) -> bytes: + event_type = str(event.get("type") or "") + if event_type == "response.audio.delta.raw": + audio = event.get("audio") + if isinstance(audio, (bytes, bytearray)): + return bytes(audio) + return b"" + + delta = event.get("delta") + if isinstance(delta, str): + try: + return base64.b64decode(delta) + except Exception as exc: + logger.warning(f"Failed to decode DashScope audio delta: {exc}") + return b"" + if isinstance(delta, (bytes, bytearray)): + return bytes(delta) + return b"" + + def _resample_if_needed(self, pcm: bytes, state: Any) -> Tuple[bytes, Any]: + if self.sample_rate == self.PROVIDER_SAMPLE_RATE: + return pcm, state + try: + converted, next_state = audioop.ratecv( + pcm, + 2, # 16-bit PCM + 1, # mono + self.PROVIDER_SAMPLE_RATE, + self.sample_rate, + state, + ) + return converted, next_state + except Exception as exc: + logger.warning(f"DashScope audio resample failed: {exc}; returning original sample rate data") + return pcm, state + + @staticmethod + def _format_error_event(event: Dict[str, Any]) -> str: + err = event.get("error") + if isinstance(err, dict): + code = str(err.get("code") or "").strip() + message = str(err.get("message") or "").strip() + if code and message: + return f"{code}: {message}" + return message or str(err) + return str(err or "DashScope realtime TTS error") diff --git a/engine/tests/test_agent_config.py b/engine/tests/test_agent_config.py index c8698cb..6432581 100644 --- a/engine/tests/test_agent_config.py +++ b/engine/tests/test_agent_config.py @@ -61,6 +61,25 @@ agent: """.strip() +def _dashscope_tts_yaml() -> str: + return _full_agent_yaml().replace( + """ tts: + provider: openai_compatible + api_key: test-tts-key + api_url: https://example-tts.invalid/v1/audio/speech + model: FunAudioLLM/CosyVoice2-0.5B + voice: anna + speed: 1.0 +""", + """ tts: + provider: dashscope + api_key: test-dashscope-key + voice: Cherry + speed: 1.0 +""", + ) + + def test_cli_profile_loads_agent_yaml(monkeypatch, tmp_path): monkeypatch.chdir(tmp_path) config_dir = tmp_path / "config" / "agents" @@ -152,6 +171,28 @@ def test_missing_tts_api_url_fails(monkeypatch, tmp_path): load_settings(argv=["--agent-config", str(file_path)]) +def test_dashscope_tts_allows_default_url_and_model(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "dashscope-tts.yaml" + _write_yaml(file_path, _dashscope_tts_yaml()) + + settings = load_settings(argv=["--agent-config", str(file_path)]) + + assert settings.tts_provider == "dashscope" + assert settings.tts_api_key == "test-dashscope-key" + assert settings.tts_api_url is None + assert settings.tts_model is None + + +def test_dashscope_tts_requires_api_key(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + file_path = tmp_path / "dashscope-tts-missing-key.yaml" + _write_yaml(file_path, _dashscope_tts_yaml().replace(" api_key: test-dashscope-key\n", "")) + + with pytest.raises(ValueError, match="Missing required agent settings in YAML"): + load_settings(argv=["--agent-config", str(file_path)]) + + def test_missing_asr_api_url_fails(monkeypatch, tmp_path): monkeypatch.chdir(tmp_path) file_path = tmp_path / "missing-asr-url.yaml"