add dashscope tts
This commit is contained in:
@@ -37,6 +37,10 @@ Agent 配置路径优先级
|
||||
- Agent 相关配置是严格模式:YAML 缺少必须项会直接报错,不会回退到 `.env` 或代码默认值。
|
||||
- 如果要引用环境变量,请在 YAML 显式写 `${ENV_VAR}`。
|
||||
- `siliconflow` 独立 section 已移除;请在 `agent.llm / agent.tts / agent.asr` 内通过 `provider`、`api_key`、`api_url`、`model` 配置。
|
||||
- `agent.tts.provider` 现支持 `dashscope`(Realtime 协议,非 OpenAI-compatible);默认 URL 为 `wss://dashscope.aliyuncs.com/api-ws/v1/realtime`,默认模型为 `qwen3-tts-flash-realtime`。
|
||||
- `agent.tts.dashscope_mode`(兼容旧写法 `agent.tts.mode`)支持 `commit | server_commit`,且仅在 `provider=dashscope` 时生效:
|
||||
- `commit`:Engine 先按句切分,再逐句提交给 DashScope。
|
||||
- `server_commit`:Engine 不再逐句切分,由 DashScope 对整段文本自行切分。
|
||||
- 现在支持在 Agent YAML 中配置 `agent.tools`(列表),用于声明运行时可调用工具。
|
||||
- 工具配置示例见 `config/agents/tools.yaml`。
|
||||
|
||||
|
||||
@@ -41,6 +41,8 @@ _AGENT_SECTION_KEY_MAP: Dict[str, Dict[str, str]] = {
|
||||
"api_url": "tts_api_url",
|
||||
"model": "tts_model",
|
||||
"voice": "tts_voice",
|
||||
"dashscope_mode": "tts_mode",
|
||||
"mode": "tts_mode",
|
||||
"speed": "tts_speed",
|
||||
},
|
||||
"asr": {
|
||||
@@ -80,6 +82,7 @@ _AGENT_SETTING_KEYS = {
|
||||
"tts_api_url",
|
||||
"tts_model",
|
||||
"tts_voice",
|
||||
"tts_mode",
|
||||
"tts_speed",
|
||||
"asr_provider",
|
||||
"asr_api_key",
|
||||
@@ -120,7 +123,10 @@ _BASE_REQUIRED_AGENT_SETTING_KEYS = {
|
||||
"barge_in_min_duration_ms",
|
||||
"barge_in_silence_tolerance_ms",
|
||||
}
|
||||
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
|
||||
_OPENAI_COMPATIBLE_LLM_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
|
||||
_OPENAI_COMPATIBLE_TTS_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
|
||||
_DASHSCOPE_TTS_PROVIDERS = {"dashscope"}
|
||||
_OPENAI_COMPATIBLE_ASR_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
|
||||
|
||||
|
||||
def _normalized_provider(overrides: Dict[str, Any], key: str, default: str) -> str:
|
||||
@@ -285,21 +291,24 @@ def _missing_required_keys(overrides: Dict[str, Any]) -> List[str]:
|
||||
missing.add(key)
|
||||
|
||||
llm_provider = _normalized_provider(overrides, "llm_provider", "openai")
|
||||
if llm_provider in _OPENAI_COMPATIBLE_PROVIDERS or llm_provider == "openai":
|
||||
if llm_provider in _OPENAI_COMPATIBLE_LLM_PROVIDERS or llm_provider == "openai":
|
||||
if "llm_api_key" not in overrides or _is_blank(overrides.get("llm_api_key")):
|
||||
missing.add("llm_api_key")
|
||||
|
||||
tts_provider = _normalized_provider(overrides, "tts_provider", "openai_compatible")
|
||||
if tts_provider in _OPENAI_COMPATIBLE_PROVIDERS:
|
||||
if tts_provider in _OPENAI_COMPATIBLE_TTS_PROVIDERS:
|
||||
if "tts_api_key" not in overrides or _is_blank(overrides.get("tts_api_key")):
|
||||
missing.add("tts_api_key")
|
||||
if "tts_api_url" not in overrides or _is_blank(overrides.get("tts_api_url")):
|
||||
missing.add("tts_api_url")
|
||||
if "tts_model" not in overrides or _is_blank(overrides.get("tts_model")):
|
||||
missing.add("tts_model")
|
||||
elif tts_provider in _DASHSCOPE_TTS_PROVIDERS:
|
||||
if "tts_api_key" not in overrides or _is_blank(overrides.get("tts_api_key")):
|
||||
missing.add("tts_api_key")
|
||||
|
||||
asr_provider = _normalized_provider(overrides, "asr_provider", "openai_compatible")
|
||||
if asr_provider in _OPENAI_COMPATIBLE_PROVIDERS:
|
||||
if asr_provider in _OPENAI_COMPATIBLE_ASR_PROVIDERS:
|
||||
if "asr_api_key" not in overrides or _is_blank(overrides.get("asr_api_key")):
|
||||
missing.add("asr_api_key")
|
||||
if "asr_api_url" not in overrides or _is_blank(overrides.get("asr_api_url")):
|
||||
@@ -401,12 +410,16 @@ class Settings(BaseSettings):
|
||||
# TTS Configuration
|
||||
tts_provider: str = Field(
|
||||
default="openai_compatible",
|
||||
description="TTS provider (edge, openai_compatible, siliconflow)"
|
||||
description="TTS provider (edge, openai_compatible, siliconflow, dashscope)"
|
||||
)
|
||||
tts_api_key: Optional[str] = Field(default=None, description="TTS provider API key")
|
||||
tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL")
|
||||
tts_model: Optional[str] = Field(default=None, description="TTS model name")
|
||||
tts_voice: str = Field(default="anna", description="TTS voice name")
|
||||
tts_mode: str = Field(
|
||||
default="commit",
|
||||
description="DashScope-only TTS mode (commit, server_commit). Ignored for non-dashscope providers."
|
||||
)
|
||||
tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier")
|
||||
|
||||
# ASR Configuration
|
||||
|
||||
@@ -21,7 +21,12 @@ agent:
|
||||
api_url: https://api.qnaigc.com/v1
|
||||
|
||||
tts:
|
||||
# provider: edge | openai_compatible | siliconflow
|
||||
# provider: edge | openai_compatible | siliconflow | dashscope
|
||||
# dashscope defaults (if omitted):
|
||||
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
|
||||
# model: qwen3-tts-flash-realtime
|
||||
# dashscope_mode: commit (engine splits) | server_commit (dashscope splits)
|
||||
# note: dashscope_mode/mode is ONLY used when provider=dashscope.
|
||||
provider: openai_compatible
|
||||
api_key: your_tts_api_key
|
||||
api_url: https://api.siliconflow.cn/v1/audio/speech
|
||||
|
||||
@@ -18,7 +18,12 @@ agent:
|
||||
api_url: https://api.qnaigc.com/v1
|
||||
|
||||
tts:
|
||||
# provider: edge | openai_compatible | siliconflow
|
||||
# provider: edge | openai_compatible | siliconflow | dashscope
|
||||
# dashscope defaults (if omitted):
|
||||
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
|
||||
# model: qwen3-tts-flash-realtime
|
||||
# dashscope_mode: commit (engine splits) | server_commit (dashscope splits)
|
||||
# note: dashscope_mode/mode is ONLY used when provider=dashscope.
|
||||
provider: openai_compatible
|
||||
api_key: your_tts_api_key
|
||||
api_url: https://api.siliconflow.cn/v1/audio/speech
|
||||
|
||||
@@ -30,6 +30,7 @@ from processors.eou import EouDetector
|
||||
from processors.vad import SileroVAD, VADProcessor
|
||||
from services.asr import BufferedASRService
|
||||
from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage, LLMStreamEvent
|
||||
from services.dashscope_tts import DashScopeTTSService
|
||||
from services.llm import MockLLMService, OpenAILLMService
|
||||
from services.openai_compatible_asr import OpenAICompatibleASRService
|
||||
from services.openai_compatible_tts import OpenAICompatibleTTSService
|
||||
@@ -349,6 +350,21 @@ class DuplexPipeline:
|
||||
if not output_mode:
|
||||
output_mode = "audio" if self._tts_output_enabled() else "text"
|
||||
|
||||
tts_model = str(
|
||||
self._runtime_tts.get("model")
|
||||
or settings.tts_model
|
||||
or (self._default_dashscope_tts_model() if self._is_dashscope_tts_provider(tts_provider) else "")
|
||||
)
|
||||
tts_config = {
|
||||
"enabled": self._tts_output_enabled(),
|
||||
"provider": tts_provider,
|
||||
"model": tts_model,
|
||||
"voice": str(self._runtime_tts.get("voice") or settings.tts_voice),
|
||||
"speed": float(self._runtime_tts.get("speed") or settings.tts_speed),
|
||||
}
|
||||
if self._is_dashscope_tts_provider(tts_provider):
|
||||
tts_config["mode"] = self._resolved_dashscope_tts_mode()
|
||||
|
||||
return {
|
||||
"output": {"mode": output_mode},
|
||||
"services": {
|
||||
@@ -363,13 +379,7 @@ class DuplexPipeline:
|
||||
"interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
|
||||
"minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
|
||||
},
|
||||
"tts": {
|
||||
"enabled": self._tts_output_enabled(),
|
||||
"provider": tts_provider,
|
||||
"model": str(self._runtime_tts.get("model") or settings.tts_model or ""),
|
||||
"voice": str(self._runtime_tts.get("voice") or settings.tts_voice),
|
||||
"speed": float(self._runtime_tts.get("speed") or settings.tts_speed),
|
||||
},
|
||||
"tts": tts_config,
|
||||
},
|
||||
"tools": {
|
||||
"allowlist": self._resolved_tool_allowlist(),
|
||||
@@ -484,6 +494,11 @@ class DuplexPipeline:
|
||||
normalized = str(provider or "").strip().lower()
|
||||
return normalized in {"openai_compatible", "openai-compatible", "siliconflow"}
|
||||
|
||||
@staticmethod
def _is_dashscope_tts_provider(provider: Any) -> bool:
    """Return True when *provider* names the DashScope TTS backend (case/space-insensitive)."""
    return str(provider or "").strip().lower() == "dashscope"
|
||||
|
||||
@staticmethod
|
||||
def _is_llm_provider_supported(provider: Any) -> bool:
|
||||
normalized = str(provider or "").strip().lower()
|
||||
@@ -496,6 +511,28 @@ class DuplexPipeline:
|
||||
return "https://api.siliconflow.cn/v1"
|
||||
return None
|
||||
|
||||
@staticmethod
def _default_dashscope_tts_realtime_url() -> str:
    """Default DashScope realtime TTS websocket endpoint (used when agent.tts.api_url is omitted)."""
    return "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
|
||||
|
||||
@staticmethod
def _default_dashscope_tts_model() -> str:
    """Default DashScope realtime TTS model (used when agent.tts.model is omitted)."""
    return "qwen3-tts-flash-realtime"
|
||||
|
||||
def _resolved_dashscope_tts_mode(self) -> str:
    """Resolve the DashScope commit mode from runtime config or settings.

    Unknown values fall back to "commit"; only "commit" and
    "server_commit" are valid.
    """
    configured = str(self._runtime_tts.get("mode") or settings.tts_mode or "commit")
    mode = configured.strip().lower()
    return mode if mode in {"commit", "server_commit"} else "commit"
|
||||
|
||||
def _use_engine_sentence_split_for_tts(self) -> bool:
    """Decide whether the engine should split LLM output into sentences for TTS.

    Non-DashScope providers always use engine-side splitting. For DashScope,
    commit mode is client-driven and expects engine-side segmentation, while
    server_commit lets DashScope segment the appended text itself.
    """
    provider = str(self._runtime_tts.get("provider") or settings.tts_provider).strip().lower()
    if self._is_dashscope_tts_provider(provider):
        return self._resolved_dashscope_tts_mode() != "server_commit"
    return True
|
||||
|
||||
def _tts_output_enabled(self) -> bool:
|
||||
enabled = self._coerce_bool(self._runtime_tts.get("enabled"))
|
||||
if enabled is not None:
|
||||
@@ -610,8 +647,26 @@ class DuplexPipeline:
|
||||
tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
|
||||
tts_model = self._runtime_tts.get("model") or settings.tts_model
|
||||
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
|
||||
tts_mode = self._resolved_dashscope_tts_mode()
|
||||
runtime_mode = str(self._runtime_tts.get("mode") or "").strip()
|
||||
if runtime_mode and not self._is_dashscope_tts_provider(tts_provider):
|
||||
logger.warning(
|
||||
"services.tts.mode is DashScope-only and will be ignored "
|
||||
f"for provider={tts_provider}"
|
||||
)
|
||||
|
||||
if self._is_openai_compatible_provider(tts_provider) and tts_api_key:
|
||||
if self._is_dashscope_tts_provider(tts_provider) and tts_api_key:
|
||||
self.tts_service = DashScopeTTSService(
|
||||
api_key=tts_api_key,
|
||||
api_url=tts_api_url or self._default_dashscope_tts_realtime_url(),
|
||||
voice=tts_voice,
|
||||
model=tts_model or self._default_dashscope_tts_model(),
|
||||
mode=str(tts_mode),
|
||||
sample_rate=settings.sample_rate,
|
||||
speed=tts_speed
|
||||
)
|
||||
logger.info("Using DashScope realtime TTS service")
|
||||
elif self._is_openai_compatible_provider(tts_provider) and tts_api_key:
|
||||
self.tts_service = OpenAICompatibleTTSService(
|
||||
api_key=tts_api_key,
|
||||
api_url=tts_api_url,
|
||||
@@ -1379,6 +1434,7 @@ class DuplexPipeline:
|
||||
round_response = ""
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
allow_text_output = True
|
||||
use_engine_sentence_split = self._use_engine_sentence_split_for_tts()
|
||||
|
||||
async for raw_event in self.llm_service.generate_stream(messages):
|
||||
if self._interrupt_event.is_set():
|
||||
@@ -1446,52 +1502,56 @@ class DuplexPipeline:
|
||||
):
|
||||
await self._flush_pending_llm_delta()
|
||||
|
||||
while True:
|
||||
split_result = extract_tts_sentence(
|
||||
sentence_buffer,
|
||||
end_chars=self._SENTENCE_END_CHARS,
|
||||
trailing_chars=self._SENTENCE_TRAILING_CHARS,
|
||||
closers=self._SENTENCE_CLOSERS,
|
||||
min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS,
|
||||
hold_trailing_at_buffer_end=True,
|
||||
force=False,
|
||||
)
|
||||
if not split_result:
|
||||
break
|
||||
sentence, sentence_buffer = split_result
|
||||
if not sentence:
|
||||
continue
|
||||
|
||||
sentence = f"{pending_punctuation}{sentence}".strip()
|
||||
pending_punctuation = ""
|
||||
if not sentence:
|
||||
continue
|
||||
|
||||
if not has_spoken_content(sentence):
|
||||
pending_punctuation = sentence
|
||||
continue
|
||||
|
||||
if self._tts_output_enabled() and not self._interrupt_event.is_set():
|
||||
if not first_audio_sent:
|
||||
self._start_tts()
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"output.audio.start",
|
||||
trackId=self.track_audio_out,
|
||||
)
|
||||
},
|
||||
priority=10,
|
||||
)
|
||||
first_audio_sent = True
|
||||
|
||||
await self._speak_sentence(
|
||||
sentence,
|
||||
fade_in_ms=0,
|
||||
fade_out_ms=8,
|
||||
if use_engine_sentence_split:
|
||||
while True:
|
||||
split_result = extract_tts_sentence(
|
||||
sentence_buffer,
|
||||
end_chars=self._SENTENCE_END_CHARS,
|
||||
trailing_chars=self._SENTENCE_TRAILING_CHARS,
|
||||
closers=self._SENTENCE_CLOSERS,
|
||||
min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS,
|
||||
hold_trailing_at_buffer_end=True,
|
||||
force=False,
|
||||
)
|
||||
if not split_result:
|
||||
break
|
||||
sentence, sentence_buffer = split_result
|
||||
if not sentence:
|
||||
continue
|
||||
|
||||
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
|
||||
sentence = f"{pending_punctuation}{sentence}".strip()
|
||||
pending_punctuation = ""
|
||||
if not sentence:
|
||||
continue
|
||||
|
||||
if not has_spoken_content(sentence):
|
||||
pending_punctuation = sentence
|
||||
continue
|
||||
|
||||
if self._tts_output_enabled() and not self._interrupt_event.is_set():
|
||||
if not first_audio_sent:
|
||||
self._start_tts()
|
||||
await self._send_event(
|
||||
{
|
||||
**ev(
|
||||
"output.audio.start",
|
||||
trackId=self.track_audio_out,
|
||||
)
|
||||
},
|
||||
priority=10,
|
||||
)
|
||||
first_audio_sent = True
|
||||
|
||||
await self._speak_sentence(
|
||||
sentence,
|
||||
fade_in_ms=0,
|
||||
fade_out_ms=8,
|
||||
)
|
||||
|
||||
if use_engine_sentence_split:
|
||||
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
|
||||
else:
|
||||
remaining_text = sentence_buffer.strip()
|
||||
await self._flush_pending_llm_delta()
|
||||
if (
|
||||
self._tts_output_enabled()
|
||||
|
||||
@@ -27,6 +27,7 @@ aiohttp>=3.9.1
|
||||
|
||||
# AI Services - LLM
|
||||
openai>=1.0.0
|
||||
dashscope>=1.25.11
|
||||
|
||||
# AI Services - TTS
|
||||
edge-tts>=6.1.0
|
||||
|
||||
@@ -13,6 +13,7 @@ from services.base import (
|
||||
BaseTTSService,
|
||||
)
|
||||
from services.llm import OpenAILLMService, MockLLMService
|
||||
from services.dashscope_tts import DashScopeTTSService
|
||||
from services.tts import EdgeTTSService, MockTTSService
|
||||
from services.asr import BufferedASRService, MockASRService
|
||||
from services.openai_compatible_asr import OpenAICompatibleASRService, SiliconFlowASRService
|
||||
@@ -33,6 +34,7 @@ __all__ = [
|
||||
"OpenAILLMService",
|
||||
"MockLLMService",
|
||||
# TTS
|
||||
"DashScopeTTSService",
|
||||
"EdgeTTSService",
|
||||
"MockTTSService",
|
||||
# ASR
|
||||
|
||||
352
engine/services/dashscope_tts.py
Normal file
352
engine/services/dashscope_tts.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""DashScope realtime TTS service.
|
||||
|
||||
Implements DashScope's Qwen realtime TTS protocol via the official SDK.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import audioop
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from typing import Any, AsyncIterator, Dict, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from services.base import BaseTTSService, ServiceState, TTSChunk
|
||||
|
||||
# Optional DashScope SDK import: the service degrades gracefully when the
# `dashscope` package is missing (connect() raises a clear RuntimeError at
# call time instead of failing at import time).
try:
    import dashscope
    from dashscope.audio.qwen_tts_realtime import AudioFormat, QwenTtsRealtime, QwenTtsRealtimeCallback

    DASHSCOPE_SDK_AVAILABLE = True
except ImportError:
    # Keep all SDK names bound so module-level references stay valid.
    dashscope = None  # type: ignore[assignment]
    AudioFormat = None  # type: ignore[assignment]
    QwenTtsRealtime = None  # type: ignore[assignment]
    DASHSCOPE_SDK_AVAILABLE = False

    class QwenTtsRealtimeCallback:  # type: ignore[no-redef]
        """Fallback callback base when DashScope SDK is unavailable."""

        pass
|
||||
|
||||
|
||||
class _RealtimeEventCallback(QwenTtsRealtimeCallback):
    """Forward DashScope SDK callback events into an asyncio queue.

    The SDK fires these callbacks on its own worker thread, so every event
    is handed to the owning event loop via ``call_soon_threadsafe`` and
    consumed with ``await queue.get()`` on the loop side.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop, queue: "asyncio.Queue[Dict[str, Any]]"):
        super().__init__()
        self._loop = loop
        self._queue = queue

    def _push(self, event: Dict[str, Any]) -> None:
        # Runs on the SDK thread; a RuntimeError means the loop has been
        # closed already, in which case the event is silently dropped.
        try:
            self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
        except RuntimeError:
            return

    def on_open(self) -> None:
        self._push({"type": "session.open"})

    def on_close(self, code: int, reason: str) -> None:
        self._push({"type": "__close__", "code": code, "reason": reason})

    def on_error(self, message: str) -> None:
        self._push({"type": "error", "error": {"message": str(message)}})

    def on_event(self, event: Any) -> None:
        # Normalize whatever the SDK hands us into a dict payload.
        if isinstance(event, dict):
            self._push(event)
            return
        if isinstance(event, str):
            try:
                parsed = json.loads(event)
            except json.JSONDecodeError:
                parsed = {"type": "raw", "message": event}
            self._push(parsed)
            return
        self._push({"type": "raw", "message": str(event)})

    def on_data(self, data: bytes) -> None:
        # Some SDK versions provide audio via on_data directly.
        self._push({"type": "response.audio.delta.raw", "audio": data})
|
||||
|
||||
|
||||
class DashScopeTTSService(BaseTTSService):
    """DashScope realtime TTS service using Qwen Realtime protocol.

    Streams 16-bit mono PCM over DashScope's realtime websocket via the
    official SDK. Two commit modes are supported:

    * ``commit`` — the caller is expected to submit pre-segmented text and
      this service commits each ``synthesize_stream`` call explicitly.
    * ``server_commit`` — text is only appended; DashScope decides how to
      segment and synthesize it.
    """

    DEFAULT_WS_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
    DEFAULT_MODEL = "qwen3-tts-flash-realtime"
    # Session is configured with AudioFormat.PCM_24000HZ_MONO_16BIT in
    # connect(), so provider audio arrives at 24 kHz and may need resampling.
    PROVIDER_SAMPLE_RATE = 24000

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        voice: str = "Cherry",
        model: Optional[str] = None,
        mode: str = "server_commit",
        sample_rate: int = 16000,
        speed: float = 1.0,
    ):
        """Create the service; key/url/model fall back to environment variables.

        Args:
            api_key: DashScope API key; defaults to ``DASHSCOPE_API_KEY`` then
                ``TTS_API_KEY``.
            api_url: Realtime websocket URL; defaults to
                ``DASHSCOPE_TTS_API_URL``/``TTS_API_URL`` then DEFAULT_WS_URL.
            voice: Voice name sent in the session update.
            model: Realtime model; defaults to ``DASHSCOPE_TTS_MODEL`` then
                DEFAULT_MODEL.
            mode: ``commit`` or ``server_commit``; anything else falls back to
                ``server_commit`` with a warning.
            sample_rate: Desired output rate; provider audio is resampled when
                it differs from PROVIDER_SAMPLE_RATE.
            speed: Speech speed multiplier (stored by the base class).
        """
        super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
        self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY") or os.getenv("TTS_API_KEY")
        self.api_url = (
            api_url
            or os.getenv("DASHSCOPE_TTS_API_URL")
            or os.getenv("TTS_API_URL")
            or self.DEFAULT_WS_URL
        )
        self.model = model or os.getenv("DASHSCOPE_TTS_MODEL") or self.DEFAULT_MODEL

        normalized_mode = str(mode or "").strip().lower()
        if normalized_mode not in {"server_commit", "commit"}:
            logger.warning(f"Unknown DashScope mode '{mode}', fallback to server_commit")
            normalized_mode = "server_commit"
        self.mode = normalized_mode

        # SDK client (QwenTtsRealtime); None until connect().
        self._client: Optional[Any] = None
        # Events bridged from the SDK callback thread (see _RealtimeEventCallback).
        self._event_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
        self._callback: Optional[_RealtimeEventCallback] = None
        # Set to abort an in-flight synthesis; checked inside synthesize_stream.
        self._cancel_event = asyncio.Event()
        # Serializes synthesize_stream calls over the single realtime session.
        self._synthesis_lock = asyncio.Lock()

    async def connect(self) -> None:
        """Open the realtime websocket and configure the session.

        Raises:
            RuntimeError: if the dashscope SDK is not installed.
            ValueError: if no API key is configured.
        """
        if not DASHSCOPE_SDK_AVAILABLE:
            raise RuntimeError("dashscope package not installed; install with `pip install dashscope`")
        if not self.api_key:
            raise ValueError("DashScope API key not provided. Configure agent.tts.api_key in YAML.")

        loop = asyncio.get_running_loop()
        self._callback = _RealtimeEventCallback(loop=loop, queue=self._event_queue)
        # The official Python SDK docs set key via global `dashscope.api_key`;
        # some SDK versions do not accept `api_key=` in QwenTtsRealtime ctor.
        if dashscope is not None:
            dashscope.api_key = self.api_key
        self._client = self._create_realtime_client(self._callback)

        # SDK calls are blocking; keep them off the event loop thread.
        await asyncio.to_thread(self._client.connect)
        await asyncio.to_thread(
            self._client.update_session,
            voice=self.voice,
            response_format=AudioFormat.PCM_24000HZ_MONO_16BIT,
            mode=self.mode,
        )
        await self._wait_for_session_ready()

        self.state = ServiceState.CONNECTED
        logger.info(
            "DashScope realtime TTS service ready: "
            f"voice={self.voice}, model={self.model}, mode={self.mode}"
        )

    def _create_realtime_client(self, callback: _RealtimeEventCallback) -> Any:
        """Build a QwenTtsRealtime client, tolerating SDKs without `api_key=`."""
        init_kwargs = {
            "model": self.model,
            "callback": callback,
            "url": self.api_url,
        }
        try:
            return QwenTtsRealtime(  # type: ignore[misc]
                api_key=self.api_key,
                **init_kwargs,
            )
        except TypeError as exc:
            # Only swallow the specific "unexpected keyword 'api_key'" failure.
            if "api_key" not in str(exc):
                raise
            logger.debug(
                "QwenTtsRealtime does not support `api_key` ctor arg; "
                "falling back to global dashscope.api_key auth"
            )
            return QwenTtsRealtime(**init_kwargs)  # type: ignore[misc]

    async def disconnect(self) -> None:
        """Close the websocket (if any) and reset service state."""
        self._cancel_event.set()
        if self._client:
            close_fn = getattr(self._client, "close", None)
            if callable(close_fn):
                await asyncio.to_thread(close_fn)
            self._client = None
        self._drain_event_queue()
        self.state = ServiceState.DISCONNECTED
        logger.info("DashScope realtime TTS service disconnected")

    async def synthesize(self, text: str) -> bytes:
        """Synthesize *text* fully and return the concatenated PCM bytes."""
        audio = b""
        async for chunk in self.synthesize_stream(text):
            audio += chunk.audio
        return audio

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """Yield ~100 ms TTSChunk objects for *text*; last chunk has is_final=True.

        Appends *text* to the session (committing it in commit mode), then
        consumes bridged events until `response.done`.

        Raises:
            RuntimeError: when not connected, on provider `error` events, or
                if the websocket closes mid-stream.
            asyncio.TimeoutError: if no event arrives within the per-event
                timeout (propagated from _next_event).
        """
        if not self._client:
            raise RuntimeError("DashScope TTS service not connected")
        if not text.strip():
            return

        async with self._synthesis_lock:
            self._cancel_event.clear()
            self._drain_event_queue()

            # Commit mode: drop any text left over from a cancelled synthesis.
            await self._clear_appended_text()
            await asyncio.to_thread(self._client.append_text, text)
            if self.mode == "commit":
                await asyncio.to_thread(self._client.commit)

            chunk_size = max(1, self.sample_rate * 2 // 10)  # 100ms of 16-bit mono PCM
            buffer = b""
            # One chunk is always held back so the final one can be yielded
            # with is_final=True once the stream ends.
            pending_chunk: Optional[bytes] = None
            resample_state: Any = None

            while True:
                # Shorter timeout once cancelled: we only drain until done/close.
                timeout = 8.0 if self._cancel_event.is_set() else 20.0
                event = await self._next_event(timeout=timeout)
                event_type = str(event.get("type") or "").strip()

                if event_type in {"response.audio.delta", "response.audio.delta.raw"}:
                    if self._cancel_event.is_set():
                        continue

                    pcm = self._decode_audio_event(event)
                    if not pcm:
                        continue
                    pcm, resample_state = self._resample_if_needed(pcm, resample_state)
                    if not pcm:
                        continue

                    buffer += pcm
                    while len(buffer) >= chunk_size:
                        audio_chunk = buffer[:chunk_size]
                        buffer = buffer[chunk_size:]
                        if pending_chunk is not None:
                            yield TTSChunk(
                                audio=pending_chunk,
                                sample_rate=self.sample_rate,
                                is_final=False,
                            )
                        pending_chunk = audio_chunk
                    continue

                if event_type == "response.done":
                    break

                if event_type == "error":
                    raise RuntimeError(self._format_error_event(event))

                if event_type == "__close__":
                    reason = str(event.get("reason") or "unknown")
                    raise RuntimeError(f"DashScope TTS websocket closed unexpectedly: {reason}")

            # Cancelled mid-stream: discard buffered audio without yielding.
            if self._cancel_event.is_set():
                return

            # Flush: held-back chunk is final only when no tail bytes remain.
            if pending_chunk is not None:
                if buffer:
                    yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False)
                    pending_chunk = None
                else:
                    yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=True)
                    pending_chunk = None

            if buffer:
                yield TTSChunk(audio=buffer, sample_rate=self.sample_rate, is_final=True)

    async def cancel(self) -> None:
        """Request cancellation of the in-flight synthesis (best effort).

        In commit mode only the appended text is cleared; otherwise the
        SDK's cancel entry point is invoked when available.
        """
        self._cancel_event.set()
        if self.mode == "commit":
            await self._clear_appended_text()
            return

        if not self._client:
            return
        cancel_fn = (
            getattr(self._client, "cancel_response", None)
            or getattr(self._client, "cancel", None)
        )
        if callable(cancel_fn):
            try:
                await asyncio.to_thread(cancel_fn)
            except Exception as exc:
                # Best effort: a failed cancel only delays the drain.
                logger.debug(f"DashScope cancel failed: {exc}")

    async def _wait_for_session_ready(self) -> None:
        """Wait for a session ready event; tolerate servers that never send one."""
        try:
            while True:
                event = await self._next_event(timeout=8.0)
                event_type = str(event.get("type") or "").strip()
                if event_type in {"session.updated", "session.open"}:
                    return
                if event_type == "error":
                    raise RuntimeError(self._format_error_event(event))
        except asyncio.TimeoutError:
            logger.debug("DashScope session update event timeout; continuing with active websocket")

    async def _clear_appended_text(self) -> None:
        """Best-effort clear of uncommitted text; no-op outside commit mode."""
        if self.mode != "commit":
            return
        if not self._client:
            return
        clear_fn = getattr(self._client, "clear_appended_text", None)
        if callable(clear_fn):
            try:
                await asyncio.to_thread(clear_fn)
            except Exception as exc:
                logger.debug(f"DashScope clear_appended_text failed: {exc}")

    async def _next_event(self, timeout: float) -> Dict[str, Any]:
        """Pop the next bridged event, coercing non-dict payloads to a raw event.

        Raises:
            asyncio.TimeoutError: if nothing arrives within *timeout* seconds.
        """
        event = await asyncio.wait_for(self._event_queue.get(), timeout=timeout)
        if isinstance(event, dict):
            return event
        return {"type": "raw", "message": str(event)}

    def _drain_event_queue(self) -> None:
        """Discard all queued events (stale ones from a previous synthesis)."""
        while True:
            try:
                self._event_queue.get_nowait()
            except asyncio.QueueEmpty:
                break

    def _decode_audio_event(self, event: Dict[str, Any]) -> bytes:
        """Extract PCM bytes from an audio event; empty bytes when undecodable.

        `.raw` events carry bytes directly (from on_data); protocol events
        carry base64 text in `delta`.
        """
        event_type = str(event.get("type") or "")
        if event_type == "response.audio.delta.raw":
            audio = event.get("audio")
            if isinstance(audio, (bytes, bytearray)):
                return bytes(audio)
            return b""

        delta = event.get("delta")
        if isinstance(delta, str):
            try:
                return base64.b64decode(delta)
            except Exception as exc:
                logger.warning(f"Failed to decode DashScope audio delta: {exc}")
                return b""
        if isinstance(delta, (bytes, bytearray)):
            return bytes(delta)
        return b""

    def _resample_if_needed(self, pcm: bytes, state: Any) -> Tuple[bytes, Any]:
        """Resample 24 kHz provider PCM to the configured rate, carrying ratecv state.

        NOTE(review): `audioop` is deprecated since Python 3.11 and removed in
        3.13 — verify the target runtime or plan a replacement.
        """
        if self.sample_rate == self.PROVIDER_SAMPLE_RATE:
            return pcm, state
        try:
            converted, next_state = audioop.ratecv(
                pcm,
                2,  # 16-bit PCM
                1,  # mono
                self.PROVIDER_SAMPLE_RATE,
                self.sample_rate,
                state,
            )
            return converted, next_state
        except Exception as exc:
            logger.warning(f"DashScope audio resample failed: {exc}; returning original sample rate data")
            return pcm, state

    @staticmethod
    def _format_error_event(event: Dict[str, Any]) -> str:
        """Render a provider error event as a "code: message" string."""
        err = event.get("error")
        if isinstance(err, dict):
            code = str(err.get("code") or "").strip()
            message = str(err.get("message") or "").strip()
            if code and message:
                return f"{code}: {message}"
            return message or str(err)
        return str(err or "DashScope realtime TTS error")
|
||||
@@ -61,6 +61,25 @@ agent:
|
||||
""".strip()
|
||||
|
||||
|
||||
def _dashscope_tts_yaml() -> str:
    """Return the full agent YAML with its TTS section swapped for a DashScope one.

    The DashScope variant intentionally omits api_url/model to exercise the
    provider defaults.
    """
    # NOTE(review): the indentation inside these YAML literals must mirror
    # _full_agent_yaml() exactly for the replace() to match — verify against
    # the original file (whitespace may have been mangled in transit).
    return _full_agent_yaml().replace(
        """ tts:
provider: openai_compatible
api_key: test-tts-key
api_url: https://example-tts.invalid/v1/audio/speech
model: FunAudioLLM/CosyVoice2-0.5B
voice: anna
speed: 1.0
""",
        """ tts:
provider: dashscope
api_key: test-dashscope-key
voice: Cherry
speed: 1.0
""",
    )
|
||||
|
||||
|
||||
def test_cli_profile_loads_agent_yaml(monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
config_dir = tmp_path / "config" / "agents"
|
||||
@@ -152,6 +171,28 @@ def test_missing_tts_api_url_fails(monkeypatch, tmp_path):
|
||||
load_settings(argv=["--agent-config", str(file_path)])
|
||||
|
||||
|
||||
def test_dashscope_tts_allows_default_url_and_model(monkeypatch, tmp_path):
    """provider=dashscope needs only an api_key; URL and model may stay unset."""
    monkeypatch.chdir(tmp_path)
    yaml_path = tmp_path / "dashscope-tts.yaml"
    _write_yaml(yaml_path, _dashscope_tts_yaml())

    loaded = load_settings(argv=["--agent-config", str(yaml_path)])

    assert loaded.tts_provider == "dashscope"
    assert loaded.tts_api_key == "test-dashscope-key"
    assert loaded.tts_api_url is None
    assert loaded.tts_model is None
||||
|
||||
|
||||
def test_dashscope_tts_requires_api_key(monkeypatch, tmp_path):
    """Strict validation must reject provider=dashscope without an api_key."""
    monkeypatch.chdir(tmp_path)
    yaml_path = tmp_path / "dashscope-tts-missing-key.yaml"
    yaml_without_key = _dashscope_tts_yaml().replace(" api_key: test-dashscope-key\n", "")
    _write_yaml(yaml_path, yaml_without_key)

    with pytest.raises(ValueError, match="Missing required agent settings in YAML"):
        load_settings(argv=["--agent-config", str(yaml_path)])
|
||||
|
||||
|
||||
def test_missing_asr_api_url_fails(monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path / "missing-asr-url.yaml"
|
||||
|
||||
Reference in New Issue
Block a user