diff --git a/docs/content/customization/asr.md b/docs/content/customization/asr.md index 2c11a87..e0b7e3a 100644 --- a/docs/content/customization/asr.md +++ b/docs/content/customization/asr.md @@ -5,7 +5,7 @@ ## 模式 - `offline`:引擎本地缓冲音频后触发识别(适用于 OpenAI-compatible / SiliconFlow)。 -- `streaming`:音频分片实时发送到服务端,服务端持续返回转写事件(适用于 DashScope Realtime ASR)。 +- `streaming`:音频分片实时发送到服务端,服务端持续返回转写事件(适用于 DashScope Realtime ASR、Volcengine BigASR)。 ## 配置项 @@ -14,6 +14,8 @@ | ASR 引擎 | 选择语音识别服务提供商 | | 模型 | 识别模型名称 | | `enable_interim` | 是否开启离线 ASR 中间结果(默认 `false`,仅离线模式生效) | +| `app_id` / `resource_id` | Volcengine 等厂商的应用标识与资源标识 | +| `request_params` | 厂商原生请求参数透传,例如 `end_window_size`、`force_to_speech_time`、`context` | | 语言 | 中文/英文/多语言 | | 热词 | 提升特定词汇识别准确率 | | 标点与规范化 | 是否自动补全标点、文本规范化 | @@ -23,7 +25,7 @@ - 客服场景建议开启热词并维护业务词表 - 多语言场景建议按会话入口显式指定语言 - 对延迟敏感场景优先选择流式识别模型 -- 当前支持提供商:`openai_compatible`、`siliconflow`、`dashscope`、`buffered`(回退) +- 当前支持提供商:`openai_compatible`、`siliconflow`、`dashscope`、`volcengine`、`buffered`(回退) ## 相关文档 diff --git a/engine/adapters/control_plane/backend.py b/engine/adapters/control_plane/backend.py index 09f145d..9f8914d 100644 --- a/engine/adapters/control_plane/backend.py +++ b/engine/adapters/control_plane/backend.py @@ -230,6 +230,14 @@ class LocalYamlAssistantConfigAdapter(NullBackendAdapter): tts_runtime["baseUrl"] = cls._as_str(tts.get("api_url")) if cls._as_str(tts.get("voice")): tts_runtime["voice"] = cls._as_str(tts.get("voice")) + if cls._as_str(tts.get("app_id")): + tts_runtime["appId"] = cls._as_str(tts.get("app_id")) + if cls._as_str(tts.get("resource_id")): + tts_runtime["resourceId"] = cls._as_str(tts.get("resource_id")) + if cls._as_str(tts.get("cluster")): + tts_runtime["cluster"] = cls._as_str(tts.get("cluster")) + if cls._as_str(tts.get("uid")): + tts_runtime["uid"] = cls._as_str(tts.get("uid")) if tts.get("speed") is not None: tts_runtime["speed"] = tts.get("speed") dashscope_mode = cls._as_str(tts.get("dashscope_mode")) or cls._as_str(tts.get("mode")) @@ -249,6 +257,16 @@ class LocalYamlAssistantConfigAdapter(NullBackendAdapter): asr_runtime["apiKey"] = cls._as_str(asr.get("api_key")) if cls._as_str(asr.get("api_url")): asr_runtime["baseUrl"] = cls._as_str(asr.get("api_url")) + if cls._as_str(asr.get("app_id")): + asr_runtime["appId"] = cls._as_str(asr.get("app_id")) + if cls._as_str(asr.get("resource_id")): + asr_runtime["resourceId"] = cls._as_str(asr.get("resource_id")) + if cls._as_str(asr.get("cluster")): + asr_runtime["cluster"] = cls._as_str(asr.get("cluster")) + if cls._as_str(asr.get("uid")): + asr_runtime["uid"] = cls._as_str(asr.get("uid")) + if isinstance(asr.get("request_params"), dict): + asr_runtime["requestParams"] = dict(asr.get("request_params") or {}) if asr.get("enable_interim") is not None: asr_runtime["enableInterim"] = asr.get("enable_interim") if asr.get("interim_interval_ms") is not None: diff --git a/engine/app/config.py b/engine/app/config.py index 1d8f47b..8edf7ce 100644 --- a/engine/app/config.py +++ b/engine/app/config.py @@ -71,11 +71,15 @@ class Settings(BaseSettings): # TTS Configuration tts_provider: str = Field( default="openai_compatible", - description="TTS provider (openai_compatible, siliconflow, dashscope)" + description="TTS provider (openai_compatible, siliconflow, dashscope, volcengine)" ) tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL") tts_model: Optional[str] = Field(default=None, description="TTS model name") tts_voice: str = Field(default="anna", description="TTS voice name") + tts_app_id: Optional[str] = Field(default=None, description="Provider-specific TTS app ID") + tts_resource_id: Optional[str] = Field(default=None, description="Provider-specific TTS resource ID") + tts_cluster: Optional[str] = Field(default=None, description="Provider-specific TTS cluster") + tts_uid: Optional[str] = Field(default=None, description="Provider-specific TTS user ID") tts_mode: str = Field( default="commit", description="DashScope-only TTS mode (commit, server_commit). Ignored for non-dashscope providers." @@ -85,10 +89,18 @@ class Settings(BaseSettings): # ASR Configuration asr_provider: str = Field( default="openai_compatible", - description="ASR provider (openai_compatible, buffered, siliconflow, dashscope)" + description="ASR provider (openai_compatible, buffered, siliconflow, dashscope, volcengine)" ) asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL") asr_model: Optional[str] = Field(default=None, description="ASR model name") + asr_app_id: Optional[str] = Field(default=None, description="Provider-specific ASR app ID") + asr_resource_id: Optional[str] = Field(default=None, description="Provider-specific ASR resource ID") + asr_cluster: Optional[str] = Field(default=None, description="Provider-specific ASR cluster") + asr_uid: Optional[str] = Field(default=None, description="Provider-specific ASR user ID") + asr_request_params_json: Optional[str] = Field( + default=None, + description="Provider-specific ASR request params as JSON string" + ) asr_enable_interim: bool = Field(default=False, description="Enable interim transcripts for offline ASR") asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms") asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result") diff --git a/engine/config/agents/example.yaml b/engine/config/agents/example.yaml index d4d6d5d..e68b6f3 100644 --- a/engine/config/agents/example.yaml +++ b/engine/config/agents/example.yaml @@ -21,12 +21,17 @@ agent: api_url: https://api.qnaigc.com/v1 tts: - # provider: openai_compatible | siliconflow | dashscope + # provider: openai_compatible | siliconflow | dashscope | volcengine # dashscope defaults (if omitted): # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime # model: qwen3-tts-flash-realtime # dashscope_mode: commit (engine splits) | server_commit (dashscope splits) # note: dashscope_mode/mode is ONLY used when provider=dashscope. + # volcengine defaults (if omitted): + # api_url: https://openspeech.bytedance.com/api/v3/tts/unidirectional + # resource_id: seed-tts-2.0 + # app_id: your volcengine app key + # api_key: your volcengine access key provider: openai_compatible api_key: your_tts_api_key api_url: https://api.siliconflow.cn/v1/audio/speech @@ -35,11 +40,21 @@ agent: speed: 1.0 asr: - # provider: buffered | openai_compatible | siliconflow | dashscope + # provider: buffered | openai_compatible | siliconflow | dashscope | volcengine # dashscope defaults (if omitted): # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime # model: qwen3-asr-flash-realtime # note: dashscope uses streaming ASR mode (chunk-by-chunk). + # volcengine defaults (if omitted): + # api_url: wss://openspeech.bytedance.com/api/v3/sauc/bigmodel + # model: bigmodel + # resource_id: volc.bigasr.sauc.duration + # app_id: your volcengine app key + # api_key: your volcengine access key + # request_params: + # end_window_size: 800 + # force_to_speech_time: 1000 + # note: volcengine uses streaming ASR mode (chunk-by-chunk). provider: openai_compatible api_key: you_asr_api_key api_url: https://api.siliconflow.cn/v1/audio/transcriptions diff --git a/engine/config/agents/tools.yaml b/engine/config/agents/tools.yaml index 26b43bf..11cd7c3 100644 --- a/engine/config/agents/tools.yaml +++ b/engine/config/agents/tools.yaml @@ -18,12 +18,17 @@ agent: api_url: https://api.qnaigc.com/v1 tts: - # provider: openai_compatible | siliconflow | dashscope + # provider: openai_compatible | siliconflow | dashscope | volcengine # dashscope defaults (if omitted): # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime # model: qwen3-tts-flash-realtime # dashscope_mode: commit (engine splits) | server_commit (dashscope splits) # note: dashscope_mode/mode is ONLY used when provider=dashscope. + # volcengine defaults (if omitted): + # api_url: https://openspeech.bytedance.com/api/v3/tts/unidirectional + # resource_id: seed-tts-2.0 + # app_id: your volcengine app key + # api_key: your volcengine access key provider: openai_compatible api_key: your_tts_api_key api_url: https://api.siliconflow.cn/v1/audio/speech @@ -32,11 +37,21 @@ agent: speed: 1.0 asr: - # provider: buffered | openai_compatible | siliconflow | dashscope + # provider: buffered | openai_compatible | siliconflow | dashscope | volcengine # dashscope defaults (if omitted): # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime # model: qwen3-asr-flash-realtime # note: dashscope uses streaming ASR mode (chunk-by-chunk). + # volcengine defaults (if omitted): + # api_url: wss://openspeech.bytedance.com/api/v3/sauc/bigmodel + # model: bigmodel + # resource_id: volc.bigasr.sauc.duration + # app_id: your volcengine app key + # api_key: your volcengine access key + # request_params: + # end_window_size: 800 + # force_to_speech_time: 1000 + # note: volcengine uses streaming ASR mode (chunk-by-chunk). provider: openai_compatible api_key: your_asr_api_key api_url: https://api.siliconflow.cn/v1/audio/transcriptions diff --git a/engine/docs/extension_ports.md b/engine/docs/extension_ports.md index 36e2aac..c0f65f6 100644 --- a/engine/docs/extension_ports.md +++ b/engine/docs/extension_ports.md @@ -36,10 +36,10 @@ This document defines the draft port set used to keep core runtime extensible. - supported providers: `openai`, `openai_compatible`, `openai-compatible`, `siliconflow` - fallback: `MockLLMService` - TTS: - - supported providers: `dashscope`, `openai_compatible`, `openai-compatible`, `siliconflow` + - supported providers: `dashscope`, `volcengine`, `openai_compatible`, `openai-compatible`, `siliconflow` - fallback: `MockTTSService` - ASR: - - supported providers: `openai_compatible`, `openai-compatible`, `siliconflow`, `dashscope` + - supported providers: `openai_compatible`, `openai-compatible`, `siliconflow`, `dashscope`, `volcengine` - fallback: `BufferedASRService` ## Notes diff --git a/engine/providers/asr/__init__.py b/engine/providers/asr/__init__.py index 5e659be..5e5dc29 100644 --- a/engine/providers/asr/__init__.py +++ b/engine/providers/asr/__init__.py @@ -3,6 +3,7 @@ from providers.asr.buffered import BufferedASRService, MockASRService from providers.asr.dashscope import DashScopeRealtimeASRService from providers.asr.openai_compatible import OpenAICompatibleASRService, SiliconFlowASRService +from providers.asr.volcengine import VolcengineRealtimeASRService __all__ = [ "BufferedASRService", @@ -10,4 +11,5 @@ __all__ = [ "DashScopeRealtimeASRService", "OpenAICompatibleASRService", "SiliconFlowASRService", + "VolcengineRealtimeASRService", ] diff --git a/engine/providers/asr/volcengine.py b/engine/providers/asr/volcengine.py new file mode 100644 index 0000000..1f7c18e --- /dev/null +++ b/engine/providers/asr/volcengine.py @@ -0,0 +1,666 @@ +"""Volcengine realtime ASR service. + +Supports both: +- Volcengine Edge Gateway realtime transcription websocket, and +- Volcengine BigASR Seed websocket at openspeech.bytedance.com/api/v3/sauc/bigmodel. +""" + +from __future__ import annotations + +import asyncio +import base64 +import gzip +import json +import os +import uuid +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Literal, Optional +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse + +import aiohttp +from loguru import logger + +from providers.common.base import ASRResult, BaseASRService, ServiceState + +VolcengineASRProtocol = Literal["gateway", "seed"] + + +class VolcengineRealtimeASRService(BaseASRService): + """Realtime streaming ASR backed by Volcengine websocket APIs.""" + + DEFAULT_WS_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel" + DEFAULT_GATEWAY_WS_URL = "wss://ai-gateway.vei.volces.com/v1/realtime" + DEFAULT_MODEL = "bigmodel" + DEFAULT_FINAL_TIMEOUT_MS = 1200 + DEFAULT_SEED_RESOURCE_ID = "volc.bigasr.sauc.duration" + _SEED_FRAME_MS = 100 + _SEED_PROTOCOL_VERSION = 0b0001 + _SEED_FULL_CLIENT_REQUEST = 0b0001 + _SEED_AUDIO_ONLY_REQUEST = 0b0010 + _SEED_FULL_SERVER_RESPONSE = 0b1001 + _SEED_SERVER_ACK = 0b1011 + _SEED_SERVER_ERROR_RESPONSE = 0b1111 + _SEED_NO_SEQUENCE = 0b0000 + _SEED_POS_SEQUENCE = 0b0001 + _SEED_NEG_WITH_SEQUENCE = 0b0011 + _SEED_NO_SERIALIZATION = 0b0000 + _SEED_JSON = 0b0001 + _SEED_NO_COMPRESSION = 0b0000 + _SEED_GZIP = 0b0001 + + def __init__( + self, + api_key: Optional[str] = None, + api_url: Optional[str] = None, + model: Optional[str] = None, + sample_rate: int = 16000, + language: str = "auto", + app_id: Optional[str] = None, + resource_id: Optional[str] = None, + uid: Optional[str] = None, + request_params: Optional[Dict[str, Any]] = None, + on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None, + ) -> None: + super().__init__(sample_rate=sample_rate, language=language) + self.mode = "streaming" + self.api_key = api_key or os.getenv("VOLCENGINE_ASR_API_KEY") or os.getenv("ASR_API_KEY") + self.model = str(model or os.getenv("VOLCENGINE_ASR_MODEL") or self.DEFAULT_MODEL).strip() + raw_api_url = api_url or os.getenv("VOLCENGINE_ASR_API_URL") or self.DEFAULT_WS_URL + self.protocol = self._detect_protocol(raw_api_url) + self.api_url = self._resolve_api_url(raw_api_url, self.model, self.protocol) + self.app_id = app_id or os.getenv("VOLCENGINE_ASR_APP_ID") or os.getenv("ASR_APP_ID") + self.resource_id = ( + resource_id + or os.getenv("VOLCENGINE_ASR_RESOURCE_ID") + or (self.DEFAULT_SEED_RESOURCE_ID if self.protocol == "seed" else None) + ) + self.uid = uid or os.getenv("VOLCENGINE_ASR_UID") + self.request_params = self._load_request_params(request_params) + self.on_transcript = on_transcript + + self._session: Optional[aiohttp.ClientSession] = None + self._ws: Optional[aiohttp.ClientWebSocketResponse] = None + self._reader_task: Optional[asyncio.Task[None]] = None + + self._running = False + self._session_ready = asyncio.Event() + self._transcript_queue: "asyncio.Queue[ASRResult]" = asyncio.Queue() + self._final_queue: "asyncio.Queue[str]" = asyncio.Queue() + + self._utterance_active = False + self._audio_sent_in_utterance = False + self._last_interim_text = "" + self._last_error: Optional[str] = None + + self._seed_audio_buffer = bytearray() + self._seed_sequence = 1 + self._seed_request_id: Optional[str] = None + self._seed_frame_bytes = max(2, int((self.sample_rate * self._SEED_FRAME_MS / 1000) * 2)) + + @classmethod + def _detect_protocol(cls, api_url: str) -> VolcengineASRProtocol: + parsed = urlparse(str(api_url or "").strip()) + host = parsed.netloc.lower() + path = parsed.path.lower() + if "openspeech.bytedance.com" in host and "/api/v3/sauc/bigmodel" in path: + return "seed" + return "gateway" + + @classmethod + def _resolve_api_url(cls, api_url: str, model: str, protocol: VolcengineASRProtocol) -> str: + raw = str(api_url or "").strip() + if not raw: + raw = cls.DEFAULT_WS_URL if protocol == "seed" else cls.DEFAULT_GATEWAY_WS_URL + if protocol != "gateway": + return raw + + parsed = urlparse(raw) + query = dict(parse_qsl(parsed.query, keep_blank_values=True)) + query.setdefault("model", model or cls.DEFAULT_MODEL) + return urlunparse(parsed._replace(query=urlencode(query))) + + @staticmethod + def _load_request_params(request_params: Optional[Dict[str, Any]]) -> Dict[str, Any]: + if isinstance(request_params, dict): + return dict(request_params) + raw = os.getenv("VOLCENGINE_ASR_REQUEST_PARAMS_JSON", "").strip() + if not raw: + return {} + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + logger.warning("Ignoring invalid VOLCENGINE_ASR_REQUEST_PARAMS_JSON") + return {} + if isinstance(parsed, dict): + return parsed + return {} + + async def connect(self) -> None: + if not self.api_key: + raise ValueError("Volcengine ASR API key not provided. Configure agent.asr.api_key in YAML.") + + timeout = aiohttp.ClientTimeout(total=None, sock_read=None, sock_connect=15) + self._session = aiohttp.ClientSession(timeout=timeout) + self._running = True + + if self.protocol == "gateway": + await self._connect_gateway() + logger.info( + "Volcengine gateway ASR connected: model={}, sample_rate={}, url={}", + self.model, + self.sample_rate, + self.api_url, + ) + else: + if not self.app_id: + raise ValueError("Volcengine ASR app_id not provided. Configure agent.asr.app_id in YAML.") + logger.info( + "Volcengine BigASR Seed ready: model={}, sample_rate={}, resource_id={}", + self.model, + self.sample_rate, + self.resource_id, + ) + + self.state = ServiceState.CONNECTED + + async def disconnect(self) -> None: + self._running = False + self._utterance_active = False + self._audio_sent_in_utterance = False + self._session_ready.clear() + self._seed_audio_buffer = bytearray() + self._drain_queue(self._final_queue) + self._drain_queue(self._transcript_queue) + + await self._close_ws() + + if self._session is not None: + await self._session.close() + self._session = None + + self.state = ServiceState.DISCONNECTED + logger.info("Volcengine ASR disconnected") + + async def begin_utterance(self) -> None: + self.clear_utterance() + if self.protocol == "seed": + await self._open_seed_stream() + self._utterance_active = True + + async def send_audio(self, audio: bytes) -> None: + if not audio: + return + + if self.protocol == "seed": + await self._send_seed_audio(audio) + return + + if not self._ws: + raise RuntimeError("Volcengine ASR websocket is not connected") + if not self._utterance_active: + self._utterance_active = True + + await self._ws.send_json( + { + "type": "input_audio_buffer.append", + "audio": base64.b64encode(audio).decode("ascii"), + } + ) + self._audio_sent_in_utterance = True + + async def end_utterance(self) -> None: + if not self._utterance_active: + return + + if self.protocol == "seed": + await self._end_seed_utterance() + return + + if not self._ws or not self._audio_sent_in_utterance: + return + await self._ws.send_json({"type": "input_audio_buffer.commit"}) + self._utterance_active = False + + async def wait_for_final_transcription(self, timeout_ms: int = DEFAULT_FINAL_TIMEOUT_MS) -> str: + if not self._audio_sent_in_utterance: + return "" + + timeout_sec = max(0.05, float(timeout_ms) / 1000.0) + try: + return str(await asyncio.wait_for(self._final_queue.get(), timeout=timeout_sec) or "").strip() + except asyncio.TimeoutError: + logger.debug("Volcengine ASR final timeout ({}ms), fallback to last interim", timeout_ms) + return str(self._last_interim_text or "").strip() + finally: + if self.protocol == "seed": + await self._close_ws() + + def clear_utterance(self) -> None: + self._utterance_active = False + self._audio_sent_in_utterance = False + self._last_interim_text = "" + self._last_error = None + self._seed_audio_buffer = bytearray() + self._seed_sequence = 1 + self._seed_request_id = None + self._drain_queue(self._final_queue) + + async def receive_transcripts(self) -> AsyncIterator[ASRResult]: + while self._running: + try: + yield await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1) + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + async def _connect_gateway(self) -> None: + assert self._session is not None + headers = {"Authorization": f"Bearer {self.api_key}"} + if self.resource_id: + headers["X-Api-Resource-Id"] = self.resource_id + + self._ws = await self._session.ws_connect(self.api_url, headers=headers, heartbeat=20) + self._reader_task = asyncio.create_task(self._reader_loop()) + await self._configure_gateway_session() + + async def _configure_gateway_session(self) -> None: + if not self._ws: + raise RuntimeError("Volcengine ASR websocket is not initialized") + + session_payload: Dict[str, Any] = { + "input_audio_format": "pcm", + "input_audio_codec": "raw", + "input_audio_sample_rate": self.sample_rate, + "input_audio_bits": 16, + "input_audio_channel": 1, + "result_type": 0, + "input_audio_transcription": { + "model": self.model, + }, + } + + await self._ws.send_json( + { + "type": "transcription_session.update", + "session": session_payload, + } + ) + + try: + await asyncio.wait_for(self._session_ready.wait(), timeout=8.0) + except asyncio.TimeoutError as exc: + raise RuntimeError("Volcengine ASR session update timeout") from exc + + async def _open_seed_stream(self) -> None: + if not self._session: + raise RuntimeError("Volcengine ASR session is not initialized") + + await self._close_ws() + self._seed_request_id = uuid.uuid4().hex + headers = self._build_seed_headers(self._seed_request_id) + self._ws = await self._session.ws_connect( + self.api_url, + headers=headers, + heartbeat=20, + max_msg_size=1_000_000_000, + ) + self._reader_task = asyncio.create_task(self._reader_loop()) + await self._ws.send_bytes(self._build_seed_start_request()) + + async def _send_seed_audio(self, audio: bytes) -> None: + if not self._utterance_active: + await self.begin_utterance() + if not self._ws: + raise RuntimeError("Volcengine BigASR websocket is not connected") + + self._seed_audio_buffer.extend(audio) + while len(self._seed_audio_buffer) >= self._seed_frame_bytes: + chunk = bytes(self._seed_audio_buffer[: self._seed_frame_bytes]) + del self._seed_audio_buffer[: self._seed_frame_bytes] + self._seed_sequence += 1 + await self._ws.send_bytes(self._build_seed_audio_request(chunk, sequence=self._seed_sequence)) + self._audio_sent_in_utterance = True + + async def _end_seed_utterance(self) -> None: + if not self._ws: + return + if not self._audio_sent_in_utterance and not self._seed_audio_buffer: + self._utterance_active = False + return + + final_chunk = bytes(self._seed_audio_buffer) + self._seed_audio_buffer = bytearray() + self._seed_sequence += 1 + await self._ws.send_bytes( + self._build_seed_audio_request(final_chunk, sequence=-self._seed_sequence, is_last=True) + ) + self._audio_sent_in_utterance = True + self._utterance_active = False + + async def _close_ws(self) -> None: + reader_task = self._reader_task + ws = self._ws + self._reader_task = None + self._ws = None + + if reader_task: + reader_task.cancel() + try: + await reader_task + except asyncio.CancelledError: + pass + + if ws is not None: + await ws.close() + + async def _reader_loop(self) -> None: + ws = self._ws + if ws is None: + return + + try: + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + if self.protocol == "gateway": + self._handle_gateway_event(msg.data) + else: + self._handle_seed_text(msg.data) + continue + if msg.type == aiohttp.WSMsgType.BINARY: + if self.protocol == "seed": + self._handle_seed_binary(msg.data) + continue + if msg.type == aiohttp.WSMsgType.ERROR: + self._last_error = str(ws.exception()) + logger.error("Volcengine ASR websocket error: {}", self._last_error) + break + if msg.type in {aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE}: + break + except asyncio.CancelledError: + raise + except Exception as exc: + self._last_error = str(exc) + logger.error("Volcengine ASR reader loop failed: {}", exc) + finally: + if self._ws is ws: + self._ws = None + + def _handle_gateway_event(self, message: str) -> None: + payload = self._coerce_event(message) + event_type = str(payload.get("type") or "").strip() + if not event_type: + return + + if event_type in {"transcription_session.created", "transcription_session.updated"}: + self._session_ready.set() + return + + if event_type == "error": + self._last_error = self._extract_text(payload, ("message", "error")) + logger.error("Volcengine ASR server error: {}", self._last_error or "unknown") + return + + if event_type.endswith(".failed"): + self._last_error = self._extract_text(payload, ("message", "error", "transcript")) + logger.error("Volcengine ASR failed event: {}", self._last_error or event_type) + return + + if event_type == "conversation.item.input_audio_transcription.result": + transcript = self._extract_text(payload, ("transcript", "result")) + self._emit_transcript_sync(transcript, is_final=False) + return + + if event_type == "conversation.item.input_audio_transcription.delta": + transcript = self._extract_text(payload, ("delta",)) + self._emit_transcript_sync(transcript, is_final=False) + return + + if event_type == "conversation.item.input_audio_transcription.completed": + transcript = self._extract_text(payload, ("transcript", "result")) + self._emit_transcript_sync(transcript, is_final=True) + + def _handle_seed_text(self, message: str) -> None: + payload = self._coerce_event(message) + if payload.get("type") == "error": + self._last_error = self._extract_text(payload, ("message", "error")) + logger.error("Volcengine BigASR error: {}", self._last_error or "unknown") + + def _handle_seed_binary(self, message: bytes) -> None: + payload = self._parse_seed_response(message) + if payload.get("code"): + self._last_error = self._extract_text(payload, ("payload_msg",)) + logger.error("Volcengine BigASR server error: {}", self._last_error or payload["code"]) + return + + body = payload.get("payload_msg") + if not isinstance(body, dict): + return + result = body.get("result") + if not isinstance(result, dict): + return + + text = str(result.get("text") or "").strip() + if not text: + return + + utterances = result.get("utterances") + if not isinstance(utterances, list) or not utterances: + return + first_utterance = utterances[0] if isinstance(utterances[0], dict) else {} + is_final = self._coerce_bool(first_utterance.get("definite")) is True + self._emit_transcript_sync(text, is_final=is_final) + + def _emit_transcript_sync(self, text: str, *, is_final: bool) -> None: + cleaned = str(text or "").strip() + if not cleaned: + return + + if not is_final: + self._last_interim_text = cleaned + else: + self._last_interim_text = "" + + result = ASRResult(text=cleaned, is_final=is_final) + try: + self._transcript_queue.put_nowait(result) + except asyncio.QueueFull: + logger.debug("Volcengine ASR transcript queue full; dropping transcript") + + if is_final: + try: + self._final_queue.put_nowait(cleaned) + except asyncio.QueueFull: + logger.debug("Volcengine ASR final queue full; dropping transcript") + + if self.on_transcript: + asyncio.create_task(self.on_transcript(cleaned, is_final)) + + def _build_seed_headers(self, request_id: str) -> Dict[str, str]: + if not self.app_id: + raise ValueError("Volcengine ASR app_id not provided. Configure agent.asr.app_id in YAML.") + if not self.api_key: + raise ValueError("Volcengine ASR api_key not provided. Configure agent.asr.api_key in YAML.") + + return { + "X-Api-App-Key": str(self.app_id), + "X-Api-Access-Key": str(self.api_key), + "X-Api-Resource-Id": str(self.resource_id or self.DEFAULT_SEED_RESOURCE_ID), + "X-Api-Request-Id": str(request_id), + } + + def _build_seed_start_payload(self) -> Dict[str, Any]: + user_payload: Dict[str, Any] = {"uid": str(self.uid or self._seed_request_id or self.app_id or uuid.uuid4().hex)} + audio_payload: Dict[str, Any] = { + "format": "pcm", + "rate": self.sample_rate, + "bits": 16, + "channels": 1, + "codec": "raw", + } + if self.language and self.language != "auto": + audio_payload["language"] = self.language + + request_payload: Dict[str, Any] = { + "model_name": self.model or self.DEFAULT_MODEL, + "enable_itn": False, + "enable_punc": True, + "enable_ddc": False, + "show_utterance": True, + "result_type": "single", + "vad_segment_duration": 3000, + "end_window_size": 500, + "force_to_speech_time": 1000, + } + + extra = dict(self.request_params) + user_payload.update(self._as_dict(extra.pop("user", None))) + audio_payload.update(self._as_dict(extra.pop("audio", None))) + request_payload.update(self._as_dict(extra.pop("request", None))) + request_payload.update(extra) + + return { + "user": user_payload, + "audio": audio_payload, + "request": request_payload, + } + + def _build_seed_start_request(self) -> bytes: + payload = gzip.compress(json.dumps(self._build_seed_start_payload()).encode("utf-8")) + frame = bytearray( + self._build_seed_header( + message_type=self._SEED_FULL_CLIENT_REQUEST, + message_type_specific_flags=self._SEED_POS_SEQUENCE, + ) + ) + frame.extend((1).to_bytes(4, "big", signed=True)) + frame.extend(len(payload).to_bytes(4, "big")) + frame.extend(payload) + return bytes(frame) + + def _build_seed_audio_request(self, chunk: bytes, *, sequence: int, is_last: bool = False) -> bytes: + payload = gzip.compress(chunk) + frame = bytearray( + self._build_seed_header( + message_type=self._SEED_AUDIO_ONLY_REQUEST, + message_type_specific_flags=self._SEED_NEG_WITH_SEQUENCE if is_last else self._SEED_POS_SEQUENCE, + ) + ) + frame.extend(int(sequence).to_bytes(4, "big", signed=True)) + frame.extend(len(payload).to_bytes(4, "big")) + frame.extend(payload) + return bytes(frame) + + @classmethod + def _build_seed_header( + cls, + *, + message_type: int, + message_type_specific_flags: int, + serial_method: int = _SEED_JSON, + compression_type: int = _SEED_GZIP, + reserved_data: int = 0x00, + ) -> bytes: + header = bytearray() + header.append((cls._SEED_PROTOCOL_VERSION << 4) | 0b0001) + header.append((message_type << 4) | message_type_specific_flags) + header.append((serial_method << 4) | compression_type) + header.append(reserved_data) + return bytes(header) + + @classmethod + def _parse_seed_response(cls, response: bytes) -> Dict[str, Any]: + header_size = response[0] & 0x0F + message_type = response[1] >> 4 + message_type_specific_flags = response[1] & 0x0F + serialization_method = response[2] >> 4 + compression_type = response[2] & 0x0F + payload = response[header_size * 4 :] + + result: Dict[str, Any] = {"is_last_package": False} + payload_message: Any = None + + if message_type_specific_flags & 0x01: + result["payload_sequence"] = int.from_bytes(payload[:4], "big", signed=True) + payload = payload[4:] + + if message_type_specific_flags & 0x02: + result["is_last_package"] = True + + if message_type == cls._SEED_FULL_SERVER_RESPONSE: + result["payload_size"] = int.from_bytes(payload[:4], "big", signed=True) + payload_message = payload[4:] + elif message_type == cls._SEED_SERVER_ACK: + result["seq"] = int.from_bytes(payload[:4], "big", signed=True) + if len(payload) >= 8: + result["payload_size"] = int.from_bytes(payload[4:8], "big", signed=False) + payload_message = payload[8:] + elif message_type == cls._SEED_SERVER_ERROR_RESPONSE: + result["code"] = int.from_bytes(payload[:4], "big", signed=False) + result["payload_size"] = int.from_bytes(payload[4:8], "big", signed=False) + payload_message = payload[8:] + + if payload_message is None: + return result + if compression_type == cls._SEED_GZIP: + payload_message = gzip.decompress(payload_message) + if serialization_method == cls._SEED_JSON: + payload_message = json.loads(payload_message.decode("utf-8")) + elif serialization_method != cls._SEED_NO_SERIALIZATION: + payload_message = payload_message.decode("utf-8") + + result["payload_msg"] = payload_message + return result + + @staticmethod + def _coerce_event(message: Any) -> Dict[str, Any]: + if isinstance(message, dict): + return message + if isinstance(message, str): + try: + loaded = json.loads(message) + if isinstance(loaded, dict): + return loaded + except json.JSONDecodeError: + return {"type": "raw", "message": message} + return {"type": "raw", "message": str(message)} + + @staticmethod + def _extract_text(payload: Dict[str, Any], keys: tuple[str, ...]) -> str: + for key in keys: + value = payload.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + if isinstance(value, dict): + for nested_key in ("message", "text", "transcript", "result", "delta"): + nested = value.get(nested_key) + if isinstance(nested, str) and nested.strip(): + return nested.strip() + return "" + + @staticmethod + def _coerce_bool(value: Any) -> Optional[bool]: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in {"1", "true", "yes", "on"}: + return True + if normalized in {"0", "false", "no", "off"}: + return False + return None + + @staticmethod + def _as_dict(value: Any) -> Dict[str, Any]: + if isinstance(value, dict): + return dict(value) + return {} + + @staticmethod + def _drain_queue(queue: "asyncio.Queue[Any]") -> None: + while True: + try: + queue.get_nowait() + except asyncio.QueueEmpty: + break diff --git a/engine/providers/factory/default.py b/engine/providers/factory/default.py index 3d51fe9..de72af6 100644 --- a/engine/providers/factory/default.py +++ b/engine/providers/factory/default.py @@ -17,14 +17,17 @@ from runtime.ports import ( ) from providers.asr.buffered import BufferedASRService from providers.asr.dashscope import DashScopeRealtimeASRService +from providers.asr.volcengine import VolcengineRealtimeASRService from providers.tts.dashscope import DashScopeTTSService from providers.llm.openai import MockLLMService, OpenAILLMService from providers.asr.openai_compatible import OpenAICompatibleASRService from providers.tts.openai_compatible import OpenAICompatibleTTSService from providers.tts.mock import MockTTSService +from providers.tts.volcengine import VolcengineTTSService _OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"} _DASHSCOPE_PROVIDERS = {"dashscope"} +_VOLCENGINE_PROVIDERS = {"volcengine"} _SUPPORTED_LLM_PROVIDERS = {"openai", *_OPENAI_COMPATIBLE_PROVIDERS} @@ -37,6 +40,10 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory): _DEFAULT_DASHSCOPE_ASR_MODEL = "qwen3-asr-flash-realtime" _DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL = "FunAudioLLM/CosyVoice2-0.5B" _DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall" + _DEFAULT_VOLCENGINE_TTS_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional" + _DEFAULT_VOLCENGINE_TTS_RESOURCE_ID = "seed-tts-2.0" + _DEFAULT_VOLCENGINE_ASR_REALTIME_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel" + _DEFAULT_VOLCENGINE_ASR_MODEL = "bigmodel" @staticmethod def _normalize_provider(provider: Any) -> str: @@ -81,6 +88,19 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory): speed=spec.speed, ) + if provider in _VOLCENGINE_PROVIDERS and spec.api_key: + return VolcengineTTSService( + api_key=spec.api_key, + api_url=spec.api_url or self._DEFAULT_VOLCENGINE_TTS_URL, + voice=spec.voice, + model=spec.model, + app_id=spec.app_id, + resource_id=spec.resource_id or self._DEFAULT_VOLCENGINE_TTS_RESOURCE_ID, + uid=spec.uid, + sample_rate=spec.sample_rate, + speed=spec.speed, + ) + if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key: return OpenAICompatibleTTSService( api_key=spec.api_key, @@ -110,6 +130,20 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory): on_transcript=spec.on_transcript, ) + if provider in _VOLCENGINE_PROVIDERS and spec.api_key: + return VolcengineRealtimeASRService( + api_key=spec.api_key, + api_url=spec.api_url or self._DEFAULT_VOLCENGINE_ASR_REALTIME_URL, + model=spec.model or self._DEFAULT_VOLCENGINE_ASR_MODEL, + sample_rate=spec.sample_rate, + language=spec.language, + app_id=spec.app_id, + resource_id=spec.resource_id, + uid=spec.uid, + request_params=spec.request_params, + on_transcript=spec.on_transcript, + ) + if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key: return OpenAICompatibleASRService( api_key=spec.api_key, diff --git a/engine/providers/tts/__init__.py b/engine/providers/tts/__init__.py index 531ecfa..b2b237a 100644 --- a/engine/providers/tts/__init__.py +++ b/engine/providers/tts/__init__.py @@ -1 +1,5 @@ """TTS providers.""" + +from providers.tts.volcengine import VolcengineTTSService + +__all__ = ["VolcengineTTSService"] diff --git a/engine/providers/tts/volcengine.py b/engine/providers/tts/volcengine.py new file mode 100644 index 0000000..d7502a1 --- /dev/null +++ b/engine/providers/tts/volcengine.py @@ -0,0 +1,219 @@ +"""Volcengine TTS service. + +Uses Volcengine's unidirectional HTTP streaming TTS API and adapts streamed +base64 audio chunks into engine-native ``TTSChunk`` events. +""" + +from __future__ import annotations + +import asyncio +import base64 +import codecs +import json +import os +import uuid +from typing import Any, AsyncIterator, Optional + +import aiohttp +from loguru import logger + +from providers.common.base import BaseTTSService, ServiceState, TTSChunk + + +class VolcengineTTSService(BaseTTSService): + """Streaming TTS adapter for Volcengine's HTTP v3 API.""" + + DEFAULT_API_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional" + DEFAULT_RESOURCE_ID = "seed-tts-2.0" + + def __init__( + self, + api_key: Optional[str] = None, + api_url: Optional[str] = None, + voice: str = "zh_female_shuangkuaisisi_moon_bigtts", + model: Optional[str] = None, + app_id: Optional[str] = None, + resource_id: Optional[str] = None, + uid: Optional[str] = None, + sample_rate: int = 16000, + speed: float = 1.0, + ) -> None: + super().__init__(voice=voice, sample_rate=sample_rate, speed=speed) + self.api_key = api_key or os.getenv("VOLCENGINE_TTS_API_KEY") or os.getenv("TTS_API_KEY") + self.api_url = api_url or os.getenv("VOLCENGINE_TTS_API_URL") or self.DEFAULT_API_URL + self.model = str(model or os.getenv("VOLCENGINE_TTS_MODEL") or "").strip() or None + self.app_id = app_id or os.getenv("VOLCENGINE_TTS_APP_ID") or os.getenv("TTS_APP_ID") + self.resource_id = resource_id or os.getenv("VOLCENGINE_TTS_RESOURCE_ID") or self.DEFAULT_RESOURCE_ID + self.uid = uid or os.getenv("VOLCENGINE_TTS_UID") + + self._session: Optional[aiohttp.ClientSession] = None + self._cancel_event = asyncio.Event() + self._synthesis_lock = asyncio.Lock() + self._pending_audio: list[bytes] = [] + + async def connect(self) -> None: + if not self.api_key: + raise ValueError("Volcengine TTS API key not provided. Configure agent.tts.api_key in YAML.") + if not self.app_id: + raise ValueError("Volcengine TTS app_id not provided. Configure agent.tts.app_id in YAML.") + + timeout = aiohttp.ClientTimeout(total=None, sock_read=None, sock_connect=15) + self._session = aiohttp.ClientSession(timeout=timeout) + self.state = ServiceState.CONNECTED + logger.info( + "Volcengine TTS service ready: speaker={}, sample_rate={}, resource_id={}", + self.voice, + self.sample_rate, + self.resource_id, + ) + + async def disconnect(self) -> None: + self._cancel_event.set() + if self._session is not None: + await self._session.close() + self._session = None + self.state = ServiceState.DISCONNECTED + logger.info("Volcengine TTS service disconnected") + + async def synthesize(self, text: str) -> bytes: + audio = b"" + async for chunk in self.synthesize_stream(text): + audio += chunk.audio + return audio + + async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]: + if not self._session: + raise RuntimeError("Volcengine TTS service not connected") + if not text.strip(): + return + + async with self._synthesis_lock: + self._cancel_event.clear() + + headers = { + "Content-Type": "application/json", + "X-Api-App-Key": str(self.app_id), + "X-Api-Access-Key": str(self.api_key), + "X-Api-Resource-Id": str(self.resource_id), + "X-Api-Request-Id": str(uuid.uuid4()), + } + payload = { + "user": { + "uid": str(self.uid or self.app_id), + }, + "req_params": { + "text": text, + "speaker": self.voice, + "audio_params": { + "format": "pcm", + "sample_rate": self.sample_rate, + "speech_rate": self._speech_rate_percent(self.speed), + }, + }, + } + if self.model: + payload["req_params"]["model"] = self.model + + chunk_size = max(1, self.sample_rate * 2 // 10) + audio_buffer = b"" + pending_chunk: Optional[bytes] = None + + try: + async with self._session.post(self.api_url, headers=headers, json=payload) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError(f"Volcengine TTS error {response.status}: {error_text}") + + async for audio_bytes in self._iter_audio_bytes(response): + if self._cancel_event.is_set(): + logger.info("Volcengine TTS synthesis cancelled") + return + + audio_buffer += audio_bytes + while len(audio_buffer) >= chunk_size: + emitted = audio_buffer[:chunk_size] + audio_buffer = audio_buffer[chunk_size:] + if pending_chunk is not None: + yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False) + pending_chunk = emitted + + if self._cancel_event.is_set(): + return + + if pending_chunk is not None: + if audio_buffer: + yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False) + pending_chunk = None + else: + yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=True) + pending_chunk = None + + if audio_buffer: + yield TTSChunk(audio=audio_buffer, sample_rate=self.sample_rate, is_final=True) + + except asyncio.CancelledError: + logger.info("Volcengine TTS synthesis cancelled via asyncio") + raise + except Exception as exc: + logger.error("Volcengine TTS synthesis error: {}", exc) + raise + + async def cancel(self) -> None: + self._cancel_event.set() + + async def _iter_audio_bytes(self, response: aiohttp.ClientResponse) -> AsyncIterator[bytes]: + decoder = json.JSONDecoder() + utf8_decoder = codecs.getincrementaldecoder("utf-8")() + text_buffer = "" + self._pending_audio.clear() + + async for raw_chunk in response.content.iter_any(): + text_buffer += utf8_decoder.decode(raw_chunk) + text_buffer = self._yield_audio_payloads(decoder, text_buffer) + while self._pending_audio: + yield self._pending_audio.pop(0) + + text_buffer += utf8_decoder.decode(b"", final=True) + text_buffer = self._yield_audio_payloads(decoder, text_buffer) + while self._pending_audio: + yield self._pending_audio.pop(0) + + def _yield_audio_payloads(self, decoder: json.JSONDecoder, text_buffer: str) -> str: + while True: + stripped = text_buffer.lstrip() + if not stripped: + return "" + if len(stripped) != len(text_buffer): + text_buffer = stripped + + try: + payload, idx = decoder.raw_decode(text_buffer) + except json.JSONDecodeError: + return text_buffer + + text_buffer = text_buffer[idx:] + audio = self._extract_audio_bytes(payload) + if audio: + self._pending_audio.append(audio) + + def _extract_audio_bytes(self, payload: Any) -> bytes: + if not isinstance(payload, dict): + return b"" + + code = payload.get("code") + if code not in (None, 0, 20000000): + message = str(payload.get("message") or "unknown error") + raise RuntimeError(f"Volcengine TTS stream error {code}: {message}") + + encoded = payload.get("data") + if isinstance(encoded, str) and encoded.strip(): + try: + return base64.b64decode(encoded) + except Exception as exc: + logger.warning("Failed to decode Volcengine TTS audio chunk: {}", exc) + return b"" + + @staticmethod + def _speech_rate_percent(speed: float) -> int: + clamped = max(0.5, min(2.0, float(speed or 1.0))) + return int(round((clamped - 1.0) * 100)) diff --git a/engine/runtime/pipeline/duplex.py b/engine/runtime/pipeline/duplex.py index cbfabb3..dcf198f 100644 --- a/engine/runtime/pipeline/duplex.py +++ b/engine/runtime/pipeline/duplex.py @@ -793,6 +793,23 @@ class DuplexPipeline: return False return None + @staticmethod + def _coerce_json_object(value: Any) -> Optional[Dict[str, Any]]: + if isinstance(value, dict): + return dict(value) + if isinstance(value, str): + raw = value.strip() + if not raw: + return None + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + logger.warning("Ignoring invalid JSON object config: {}", raw[:120]) + return None + if isinstance(parsed, dict): + return parsed + return None + @staticmethod def _is_dashscope_tts_provider(provider: Any) -> bool: normalized = str(provider or "").strip().lower() @@ -804,7 +821,7 @@ class DuplexPipeline: if normalized_mode in {"offline", "streaming"}: return normalized_mode # type: ignore[return-value] normalized_provider = str(provider or "").strip().lower() - if normalized_provider == "dashscope": + if normalized_provider in {"dashscope", "volcengine"}: return "streaming" return "offline" @@ -963,6 +980,10 @@ class DuplexPipeline: tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url tts_voice = self._runtime_tts.get("voice") or settings.tts_voice tts_model = self._runtime_tts.get("model") or settings.tts_model + tts_app_id = self._runtime_tts.get("appId") or settings.tts_app_id + tts_resource_id = self._runtime_tts.get("resourceId") or settings.tts_resource_id + tts_cluster = self._runtime_tts.get("cluster") or settings.tts_cluster + tts_uid = self._runtime_tts.get("uid") or settings.tts_uid tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed) tts_mode = self._resolved_dashscope_tts_mode() runtime_mode = str(self._runtime_tts.get("mode") or "").strip() @@ -978,6 +999,10 @@ class DuplexPipeline: api_url=str(tts_api_url).strip() if tts_api_url else None, voice=str(tts_voice), model=str(tts_model).strip() if tts_model else None, + app_id=str(tts_app_id).strip() if tts_app_id else None, + resource_id=str(tts_resource_id).strip() if tts_resource_id else None, + cluster=str(tts_cluster).strip() if tts_cluster else None, + uid=str(tts_uid).strip() if tts_uid else None, sample_rate=settings.sample_rate, speed=tts_speed, mode=str(tts_mode), @@ -1006,6 +1031,13 @@ class DuplexPipeline: asr_api_key = self._runtime_asr.get("apiKey") asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url asr_model = self._runtime_asr.get("model") or settings.asr_model + asr_app_id = self._runtime_asr.get("appId") or settings.asr_app_id + asr_resource_id = self._runtime_asr.get("resourceId") or settings.asr_resource_id + asr_cluster = self._runtime_asr.get("cluster") or settings.asr_cluster + asr_uid = self._runtime_asr.get("uid") or settings.asr_uid + asr_request_params = self._coerce_json_object(self._runtime_asr.get("requestParams")) + if asr_request_params is None: + asr_request_params = self._coerce_json_object(settings.asr_request_params_json) asr_enable_interim = self._coerce_bool(self._runtime_asr.get("enableInterim")) if asr_enable_interim is None: asr_enable_interim = bool(settings.asr_enable_interim) @@ -1022,6 +1054,11 @@ class DuplexPipeline: api_key=str(asr_api_key).strip() if asr_api_key else None, api_url=str(asr_api_url).strip() if asr_api_url else None, model=str(asr_model).strip() if asr_model else None, + app_id=str(asr_app_id).strip() if asr_app_id else None, + resource_id=str(asr_resource_id).strip() if asr_resource_id else None, + cluster=str(asr_cluster).strip() if asr_cluster else None, + uid=str(asr_uid).strip() if asr_uid else None, + request_params=asr_request_params, enable_interim=asr_enable_interim, interim_interval_ms=asr_interim_interval, min_audio_for_interim_ms=asr_min_audio_ms, diff --git a/engine/runtime/ports/asr.py b/engine/runtime/ports/asr.py index f3be1d1..b1310b1 100644 --- a/engine/runtime/ports/asr.py +++ b/engine/runtime/ports/asr.py @@ -3,7 +3,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import AsyncIterator, Awaitable, Callable, Literal, Optional, Protocol +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Literal, Optional, Protocol from providers.common.base import ASRResult @@ -22,6 +22,11 @@ class ASRServiceSpec: api_key: Optional[str] = None api_url: Optional[str] = None model: Optional[str] = None + app_id: Optional[str] = None + resource_id: Optional[str] = None + cluster: Optional[str] = None + uid: Optional[str] = None + request_params: Optional[Dict[str, Any]] = None enable_interim: bool = False interim_interval_ms: int = 500 min_audio_for_interim_ms: int = 300 diff --git a/engine/runtime/ports/tts.py b/engine/runtime/ports/tts.py index 523dc3c..a98e17d 100644 --- a/engine/runtime/ports/tts.py +++ b/engine/runtime/ports/tts.py @@ -19,6 +19,10 @@ class TTSServiceSpec: api_key: Optional[str] = None api_url: Optional[str] = None model: Optional[str] = None + app_id: Optional[str] = None + resource_id: Optional[str] = None + cluster: Optional[str] = None + uid: Optional[str] = None mode: str = "commit" diff --git a/engine/tests/test_asr_factory_modes.py b/engine/tests/test_asr_factory_modes.py index 5d3d436..5cd78f8 100644 --- a/engine/tests/test_asr_factory_modes.py +++ b/engine/tests/test_asr_factory_modes.py @@ -1,6 +1,7 @@ from providers.asr.buffered import BufferedASRService from providers.asr.dashscope import DashScopeRealtimeASRService from providers.asr.openai_compatible import OpenAICompatibleASRService +from providers.asr.volcengine import VolcengineRealtimeASRService from providers.factory.default import DefaultRealtimeServiceFactory from runtime.ports import ASRServiceSpec @@ -35,6 +36,29 @@ def test_create_asr_service_openai_compatible_returns_offline_provider(): assert service.enable_interim is False +def test_create_asr_service_volcengine_returns_streaming_provider(): + factory = DefaultRealtimeServiceFactory() + service = factory.create_asr_service( + ASRServiceSpec( + provider="volcengine", + mode="streaming", + sample_rate=16000, + api_key="test-key", + api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", + model="bigmodel", + app_id="app-1", + uid="caller-1", + request_params={"end_window_size": 800}, + ) + ) + assert isinstance(service, VolcengineRealtimeASRService) + assert service.mode == "streaming" + assert service.protocol == "seed" + assert service.app_id == "app-1" + assert service.uid == "caller-1" + assert service.request_params["end_window_size"] == 800 + + def test_create_asr_service_fallback_buffered_for_unsupported_provider(): factory = DefaultRealtimeServiceFactory() service = factory.create_asr_service( diff --git a/engine/tests/test_backend_adapters.py b/engine/tests/test_backend_adapters.py index e4faf81..9cce105 100644 --- a/engine/tests/test_backend_adapters.py +++ b/engine/tests/test_backend_adapters.py @@ -227,6 +227,62 @@ async def test_with_backend_url_uses_backend_for_assistant_config(monkeypatch, t assert payload["assistant"]["systemPrompt"] == "backend prompt" +def test_translate_agent_schema_maps_volcengine_fields(): + payload = { + "agent": { + "tts": { + "provider": "volcengine", + "api_key": "tts-key", + "api_url": "https://openspeech.bytedance.com/api/v3/tts/unidirectional", + "app_id": "app-123", + "resource_id": "seed-tts-2.0", + "uid": "caller-1", + "voice": "zh_female_shuangkuaisisi_moon_bigtts", + "speed": 1.1, + }, + "asr": { + "provider": "volcengine", + "api_key": "asr-key", + "api_url": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", + "model": "bigmodel", + "app_id": "app-123", + "resource_id": "volc.bigasr.sauc.duration", + "uid": "caller-1", + "request_params": { + "end_window_size": 800, + "force_to_speech_time": 1000, + }, + }, + } + } + + translated = LocalYamlAssistantConfigAdapter._translate_agent_schema("assistant_demo", payload) + assert translated is not None + assert translated["services"]["tts"] == { + "provider": "volcengine", + "apiKey": "tts-key", + "baseUrl": "https://openspeech.bytedance.com/api/v3/tts/unidirectional", + "voice": "zh_female_shuangkuaisisi_moon_bigtts", + "appId": "app-123", + "resourceId": "seed-tts-2.0", + "uid": "caller-1", + "speed": 1.1, + } + assert translated["services"]["asr"] == { + "provider": "volcengine", + "model": "bigmodel", + "apiKey": "asr-key", + "baseUrl": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", + "appId": "app-123", + "resourceId": "volc.bigasr.sauc.duration", + "uid": "caller-1", + "requestParams": { + "end_window_size": 800, + "force_to_speech_time": 1000, + }, + } + + @pytest.mark.asyncio async def test_backend_mode_disabled_uses_local_assistant_config_even_with_url(monkeypatch, tmp_path): class _FailIfCalledClientSession: diff --git a/engine/tests/test_tts_factory_modes.py b/engine/tests/test_tts_factory_modes.py new file mode 100644 index 0000000..987fc10 --- /dev/null +++ b/engine/tests/test_tts_factory_modes.py @@ -0,0 +1,45 @@ +from providers.factory.default import DefaultRealtimeServiceFactory +from providers.tts.mock import MockTTSService +from providers.tts.openai_compatible import OpenAICompatibleTTSService +from providers.tts.volcengine import VolcengineTTSService +from runtime.ports import TTSServiceSpec + + +def test_create_tts_service_volcengine_returns_native_provider(): + factory = DefaultRealtimeServiceFactory() + service = factory.create_tts_service( + TTSServiceSpec( + provider="volcengine", + api_key="test-key", + app_id="app-1", + resource_id="seed-tts-2.0", + voice="zh_female_shuangkuaisisi_moon_bigtts", + sample_rate=16000, + ) + ) + assert isinstance(service, VolcengineTTSService) + + +def test_create_tts_service_openai_compatible_returns_provider(): + factory = DefaultRealtimeServiceFactory() + service = factory.create_tts_service( + TTSServiceSpec( + provider="openai_compatible", + api_key="test-key", + voice="anna", + sample_rate=16000, + ) + ) + assert isinstance(service, OpenAICompatibleTTSService) + + +def test_create_tts_service_fallbacks_to_mock_without_key(): + factory = DefaultRealtimeServiceFactory() + service = factory.create_tts_service( + TTSServiceSpec( + provider="volcengine", + voice="anna", + sample_rate=16000, + ) + ) + assert isinstance(service, MockTTSService) diff --git a/engine/tests/test_volcengine_asr_provider.py b/engine/tests/test_volcengine_asr_provider.py new file mode 100644 index 0000000..c5756c0 --- /dev/null +++ b/engine/tests/test_volcengine_asr_provider.py @@ -0,0 +1,86 @@ +import gzip +import json + +from providers.asr.volcengine import VolcengineRealtimeASRService + + +def test_volcengine_seed_protocol_defaults_and_headers(): + service = VolcengineRealtimeASRService( + api_key="access-token", + api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", + app_id="app-1", + uid="caller-1", + ) + + assert service.protocol == "seed" + assert service.resource_id == "volc.bigasr.sauc.duration" + + headers = service._build_seed_headers("req-1") + assert headers == { + "X-Api-App-Key": "app-1", + "X-Api-Access-Key": "access-token", + "X-Api-Resource-Id": "volc.bigasr.sauc.duration", + "X-Api-Request-Id": "req-1", + } + + +def test_volcengine_seed_start_payload_merges_request_params(): + service = VolcengineRealtimeASRService( + api_key="access-token", + api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", + app_id="app-1", + uid="caller-1", + language="zh-CN", + request_params={ + "request": { + "end_window_size": 800, + "force_to_speech_time": 1000, + "context": "{\"hotwords\":[{\"word\":\"doubao\"}]}", + }, + "audio": {"codec": "raw"}, + }, + ) + + payload = service._build_seed_start_payload() + assert payload["user"] == {"uid": "caller-1"} + assert payload["audio"] == { + "format": "pcm", + "rate": 16000, + "bits": 16, + "channels": 1, + "codec": "raw", + "language": "zh-CN", + } + assert payload["request"]["model_name"] == "bigmodel" + assert payload["request"]["end_window_size"] == 800 + assert payload["request"]["force_to_speech_time"] == 1000 + assert payload["request"]["context"] == "{\"hotwords\":[{\"word\":\"doubao\"}]}" + + +def test_volcengine_seed_start_request_encodes_gzip_json_payload(): + service = VolcengineRealtimeASRService( + api_key="access-token", + api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", + app_id="app-1", + uid="caller-1", + ) + + frame = service._build_seed_start_request() + assert frame[0] == 0x11 + assert frame[1] == 0x11 + + payload_length = int.from_bytes(frame[8:12], "big") + payload = json.loads(gzip.decompress(frame[12 : 12 + payload_length]).decode("utf-8")) + assert payload["user"]["uid"] == "caller-1" + assert payload["request"]["model_name"] == "bigmodel" + + +def test_volcengine_gateway_protocol_keeps_model_query(): + service = VolcengineRealtimeASRService( + api_key="access-token", + api_url="wss://ai-gateway.vei.volces.com/v1/realtime", + model="bigmodel", + ) + + assert service.protocol == "gateway" + assert service.api_url == "wss://ai-gateway.vei.volces.com/v1/realtime?model=bigmodel"