Add Volcengine support for TTS and ASR services

- Introduced Volcengine as a new provider for both TTS and ASR services.
- Updated configuration files to include Volcengine-specific parameters such as app_id, resource_id, and uid.
- Enhanced the ASR service to support streaming mode with Volcengine's API.
- Modified existing tests to validate the integration of Volcengine services.
- Updated documentation to reflect the addition of Volcengine as a supported provider for TTS and ASR.
- Refactored service factory to accommodate Volcengine alongside existing providers.
This commit is contained in:
Xin Wang
2026-03-08 23:09:50 +08:00
parent 3604db21eb
commit aeeeee20d1
18 changed files with 1256 additions and 12 deletions

View File

@@ -5,7 +5,7 @@
## 模式 ## 模式
- `offline`:引擎本地缓冲音频后触发识别(适用于 OpenAI-compatible / SiliconFlow - `offline`:引擎本地缓冲音频后触发识别(适用于 OpenAI-compatible / SiliconFlow
- `streaming`:音频分片实时发送到服务端,服务端持续返回转写事件(适用于 DashScope Realtime ASR - `streaming`:音频分片实时发送到服务端,服务端持续返回转写事件(适用于 DashScope Realtime ASR、Volcengine BigASR)。
## 配置项 ## 配置项
@@ -14,6 +14,8 @@
| ASR 引擎 | 选择语音识别服务提供商 | | ASR 引擎 | 选择语音识别服务提供商 |
| 模型 | 识别模型名称 | | 模型 | 识别模型名称 |
| `enable_interim` | 是否开启离线 ASR 中间结果(默认 `false`,仅离线模式生效) | | `enable_interim` | 是否开启离线 ASR 中间结果(默认 `false`,仅离线模式生效) |
| `app_id` / `resource_id` | Volcengine 等厂商的应用标识与资源标识 |
| `request_params` | 厂商原生请求参数透传,例如 `end_window_size`、`force_to_speech_time`、`context` |
| 语言 | 中文/英文/多语言 | | 语言 | 中文/英文/多语言 |
| 热词 | 提升特定词汇识别准确率 | | 热词 | 提升特定词汇识别准确率 |
| 标点与规范化 | 是否自动补全标点、文本规范化 | | 标点与规范化 | 是否自动补全标点、文本规范化 |
@@ -23,7 +25,7 @@
- 客服场景建议开启热词并维护业务词表 - 客服场景建议开启热词并维护业务词表
- 多语言场景建议按会话入口显式指定语言 - 多语言场景建议按会话入口显式指定语言
- 对延迟敏感场景优先选择流式识别模型 - 对延迟敏感场景优先选择流式识别模型
- 当前支持提供商:`openai_compatible``siliconflow``dashscope``buffered`(回退) - 当前支持提供商:`openai_compatible``siliconflow``dashscope``volcengine``buffered`(回退)
## 相关文档 ## 相关文档

View File

@@ -230,6 +230,14 @@ class LocalYamlAssistantConfigAdapter(NullBackendAdapter):
tts_runtime["baseUrl"] = cls._as_str(tts.get("api_url")) tts_runtime["baseUrl"] = cls._as_str(tts.get("api_url"))
if cls._as_str(tts.get("voice")): if cls._as_str(tts.get("voice")):
tts_runtime["voice"] = cls._as_str(tts.get("voice")) tts_runtime["voice"] = cls._as_str(tts.get("voice"))
if cls._as_str(tts.get("app_id")):
tts_runtime["appId"] = cls._as_str(tts.get("app_id"))
if cls._as_str(tts.get("resource_id")):
tts_runtime["resourceId"] = cls._as_str(tts.get("resource_id"))
if cls._as_str(tts.get("cluster")):
tts_runtime["cluster"] = cls._as_str(tts.get("cluster"))
if cls._as_str(tts.get("uid")):
tts_runtime["uid"] = cls._as_str(tts.get("uid"))
if tts.get("speed") is not None: if tts.get("speed") is not None:
tts_runtime["speed"] = tts.get("speed") tts_runtime["speed"] = tts.get("speed")
dashscope_mode = cls._as_str(tts.get("dashscope_mode")) or cls._as_str(tts.get("mode")) dashscope_mode = cls._as_str(tts.get("dashscope_mode")) or cls._as_str(tts.get("mode"))
@@ -249,6 +257,16 @@ class LocalYamlAssistantConfigAdapter(NullBackendAdapter):
asr_runtime["apiKey"] = cls._as_str(asr.get("api_key")) asr_runtime["apiKey"] = cls._as_str(asr.get("api_key"))
if cls._as_str(asr.get("api_url")): if cls._as_str(asr.get("api_url")):
asr_runtime["baseUrl"] = cls._as_str(asr.get("api_url")) asr_runtime["baseUrl"] = cls._as_str(asr.get("api_url"))
if cls._as_str(asr.get("app_id")):
asr_runtime["appId"] = cls._as_str(asr.get("app_id"))
if cls._as_str(asr.get("resource_id")):
asr_runtime["resourceId"] = cls._as_str(asr.get("resource_id"))
if cls._as_str(asr.get("cluster")):
asr_runtime["cluster"] = cls._as_str(asr.get("cluster"))
if cls._as_str(asr.get("uid")):
asr_runtime["uid"] = cls._as_str(asr.get("uid"))
if isinstance(asr.get("request_params"), dict):
asr_runtime["requestParams"] = dict(asr.get("request_params") or {})
if asr.get("enable_interim") is not None: if asr.get("enable_interim") is not None:
asr_runtime["enableInterim"] = asr.get("enable_interim") asr_runtime["enableInterim"] = asr.get("enable_interim")
if asr.get("interim_interval_ms") is not None: if asr.get("interim_interval_ms") is not None:

View File

@@ -71,11 +71,15 @@ class Settings(BaseSettings):
# TTS Configuration # TTS Configuration
tts_provider: str = Field( tts_provider: str = Field(
default="openai_compatible", default="openai_compatible",
description="TTS provider (openai_compatible, siliconflow, dashscope)" description="TTS provider (openai_compatible, siliconflow, dashscope, volcengine)"
) )
tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL") tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL")
tts_model: Optional[str] = Field(default=None, description="TTS model name") tts_model: Optional[str] = Field(default=None, description="TTS model name")
tts_voice: str = Field(default="anna", description="TTS voice name") tts_voice: str = Field(default="anna", description="TTS voice name")
tts_app_id: Optional[str] = Field(default=None, description="Provider-specific TTS app ID")
tts_resource_id: Optional[str] = Field(default=None, description="Provider-specific TTS resource ID")
tts_cluster: Optional[str] = Field(default=None, description="Provider-specific TTS cluster")
tts_uid: Optional[str] = Field(default=None, description="Provider-specific TTS user ID")
tts_mode: str = Field( tts_mode: str = Field(
default="commit", default="commit",
description="DashScope-only TTS mode (commit, server_commit). Ignored for non-dashscope providers." description="DashScope-only TTS mode (commit, server_commit). Ignored for non-dashscope providers."
@@ -85,10 +89,18 @@ class Settings(BaseSettings):
# ASR Configuration # ASR Configuration
asr_provider: str = Field( asr_provider: str = Field(
default="openai_compatible", default="openai_compatible",
description="ASR provider (openai_compatible, buffered, siliconflow, dashscope)" description="ASR provider (openai_compatible, buffered, siliconflow, dashscope, volcengine)"
) )
asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL") asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL")
asr_model: Optional[str] = Field(default=None, description="ASR model name") asr_model: Optional[str] = Field(default=None, description="ASR model name")
asr_app_id: Optional[str] = Field(default=None, description="Provider-specific ASR app ID")
asr_resource_id: Optional[str] = Field(default=None, description="Provider-specific ASR resource ID")
asr_cluster: Optional[str] = Field(default=None, description="Provider-specific ASR cluster")
asr_uid: Optional[str] = Field(default=None, description="Provider-specific ASR user ID")
asr_request_params_json: Optional[str] = Field(
default=None,
description="Provider-specific ASR request params as JSON string"
)
asr_enable_interim: bool = Field(default=False, description="Enable interim transcripts for offline ASR") asr_enable_interim: bool = Field(default=False, description="Enable interim transcripts for offline ASR")
asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms") asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms")
asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result") asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result")

View File

@@ -21,12 +21,17 @@ agent:
api_url: https://api.qnaigc.com/v1 api_url: https://api.qnaigc.com/v1
tts: tts:
# provider: openai_compatible | siliconflow | dashscope # provider: openai_compatible | siliconflow | dashscope | volcengine
# dashscope defaults (if omitted): # dashscope defaults (if omitted):
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
# model: qwen3-tts-flash-realtime # model: qwen3-tts-flash-realtime
# dashscope_mode: commit (engine splits) | server_commit (dashscope splits) # dashscope_mode: commit (engine splits) | server_commit (dashscope splits)
# note: dashscope_mode/mode is ONLY used when provider=dashscope. # note: dashscope_mode/mode is ONLY used when provider=dashscope.
# volcengine defaults (if omitted):
# api_url: https://openspeech.bytedance.com/api/v3/tts/unidirectional
# resource_id: seed-tts-2.0
# app_id: your volcengine app key
# api_key: your volcengine access key
provider: openai_compatible provider: openai_compatible
api_key: your_tts_api_key api_key: your_tts_api_key
api_url: https://api.siliconflow.cn/v1/audio/speech api_url: https://api.siliconflow.cn/v1/audio/speech
@@ -35,11 +40,21 @@ agent:
speed: 1.0 speed: 1.0
asr: asr:
# provider: buffered | openai_compatible | siliconflow | dashscope # provider: buffered | openai_compatible | siliconflow | dashscope | volcengine
# dashscope defaults (if omitted): # dashscope defaults (if omitted):
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
# model: qwen3-asr-flash-realtime # model: qwen3-asr-flash-realtime
# note: dashscope uses streaming ASR mode (chunk-by-chunk). # note: dashscope uses streaming ASR mode (chunk-by-chunk).
# volcengine defaults (if omitted):
# api_url: wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
# model: bigmodel
# resource_id: volc.bigasr.sauc.duration
# app_id: your volcengine app key
# api_key: your volcengine access key
# request_params:
# end_window_size: 800
# force_to_speech_time: 1000
# note: volcengine uses streaming ASR mode (chunk-by-chunk).
provider: openai_compatible provider: openai_compatible
api_key: your_asr_api_key api_key: your_asr_api_key
api_url: https://api.siliconflow.cn/v1/audio/transcriptions api_url: https://api.siliconflow.cn/v1/audio/transcriptions

View File

@@ -18,12 +18,17 @@ agent:
api_url: https://api.qnaigc.com/v1 api_url: https://api.qnaigc.com/v1
tts: tts:
# provider: openai_compatible | siliconflow | dashscope # provider: openai_compatible | siliconflow | dashscope | volcengine
# dashscope defaults (if omitted): # dashscope defaults (if omitted):
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
# model: qwen3-tts-flash-realtime # model: qwen3-tts-flash-realtime
# dashscope_mode: commit (engine splits) | server_commit (dashscope splits) # dashscope_mode: commit (engine splits) | server_commit (dashscope splits)
# note: dashscope_mode/mode is ONLY used when provider=dashscope. # note: dashscope_mode/mode is ONLY used when provider=dashscope.
# volcengine defaults (if omitted):
# api_url: https://openspeech.bytedance.com/api/v3/tts/unidirectional
# resource_id: seed-tts-2.0
# app_id: your volcengine app key
# api_key: your volcengine access key
provider: openai_compatible provider: openai_compatible
api_key: your_tts_api_key api_key: your_tts_api_key
api_url: https://api.siliconflow.cn/v1/audio/speech api_url: https://api.siliconflow.cn/v1/audio/speech
@@ -32,11 +37,21 @@ agent:
speed: 1.0 speed: 1.0
asr: asr:
# provider: buffered | openai_compatible | siliconflow | dashscope # provider: buffered | openai_compatible | siliconflow | dashscope | volcengine
# dashscope defaults (if omitted): # dashscope defaults (if omitted):
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
# model: qwen3-asr-flash-realtime # model: qwen3-asr-flash-realtime
# note: dashscope uses streaming ASR mode (chunk-by-chunk). # note: dashscope uses streaming ASR mode (chunk-by-chunk).
# volcengine defaults (if omitted):
# api_url: wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
# model: bigmodel
# resource_id: volc.bigasr.sauc.duration
# app_id: your volcengine app key
# api_key: your volcengine access key
# request_params:
# end_window_size: 800
# force_to_speech_time: 1000
# note: volcengine uses streaming ASR mode (chunk-by-chunk).
provider: openai_compatible provider: openai_compatible
api_key: your_asr_api_key api_key: your_asr_api_key
api_url: https://api.siliconflow.cn/v1/audio/transcriptions api_url: https://api.siliconflow.cn/v1/audio/transcriptions

View File

@@ -36,10 +36,10 @@ This document defines the draft port set used to keep core runtime extensible.
- supported providers: `openai`, `openai_compatible`, `openai-compatible`, `siliconflow` - supported providers: `openai`, `openai_compatible`, `openai-compatible`, `siliconflow`
- fallback: `MockLLMService` - fallback: `MockLLMService`
- TTS: - TTS:
- supported providers: `dashscope`, `openai_compatible`, `openai-compatible`, `siliconflow` - supported providers: `dashscope`, `volcengine`, `openai_compatible`, `openai-compatible`, `siliconflow`
- fallback: `MockTTSService` - fallback: `MockTTSService`
- ASR: - ASR:
- supported providers: `openai_compatible`, `openai-compatible`, `siliconflow`, `dashscope` - supported providers: `openai_compatible`, `openai-compatible`, `siliconflow`, `dashscope`, `volcengine`
- fallback: `BufferedASRService` - fallback: `BufferedASRService`
## Notes ## Notes

View File

@@ -3,6 +3,7 @@
from providers.asr.buffered import BufferedASRService, MockASRService from providers.asr.buffered import BufferedASRService, MockASRService
from providers.asr.dashscope import DashScopeRealtimeASRService from providers.asr.dashscope import DashScopeRealtimeASRService
from providers.asr.openai_compatible import OpenAICompatibleASRService, SiliconFlowASRService from providers.asr.openai_compatible import OpenAICompatibleASRService, SiliconFlowASRService
from providers.asr.volcengine import VolcengineRealtimeASRService
__all__ = [ __all__ = [
"BufferedASRService", "BufferedASRService",
@@ -10,4 +11,5 @@ __all__ = [
"DashScopeRealtimeASRService", "DashScopeRealtimeASRService",
"OpenAICompatibleASRService", "OpenAICompatibleASRService",
"SiliconFlowASRService", "SiliconFlowASRService",
"VolcengineRealtimeASRService",
] ]

View File

@@ -0,0 +1,666 @@
"""Volcengine realtime ASR service.
Supports both:
- Volcengine Edge Gateway realtime transcription websocket, and
- Volcengine BigASR Seed websocket at openspeech.bytedance.com/api/v3/sauc/bigmodel.
"""
from __future__ import annotations
import asyncio
import base64
import gzip
import json
import os
import uuid
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Literal, Optional
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
import aiohttp
from loguru import logger
from providers.common.base import ASRResult, BaseASRService, ServiceState
# Wire protocol selected from the API URL: "gateway" = Edge AI Gateway JSON
# events; "seed" = BigASR Seed binary framing (see _detect_protocol).
VolcengineASRProtocol = Literal["gateway", "seed"]
class VolcengineRealtimeASRService(BaseASRService):
    """Realtime streaming ASR backed by Volcengine websocket APIs.

    Two wire protocols are supported and auto-detected from the API URL:

    - ``gateway``: Edge AI Gateway realtime websocket speaking a JSON event
      protocol (``input_audio_buffer.append`` / ``.commit`` requests and
      ``transcription_session.*`` / ``conversation.item.*`` events).
    - ``seed``: BigASR Seed websocket at
      ``openspeech.bytedance.com/api/v3/sauc/bigmodel`` speaking a custom
      binary framing: a 4-byte header, an optional signed 4-byte sequence
      number, a 4-byte payload size, and a gzip-compressed JSON/PCM payload.

    Transcripts are pushed onto an internal queue (consumed via
    :meth:`receive_transcripts`); final transcripts are additionally queued
    for :meth:`wait_for_final_transcription`.
    """
    # Default websocket endpoints and model for each protocol.
    DEFAULT_WS_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel"
    DEFAULT_GATEWAY_WS_URL = "wss://ai-gateway.vei.volces.com/v1/realtime"
    DEFAULT_MODEL = "bigmodel"
    # How long wait_for_final_transcription blocks before falling back to the
    # last interim text.
    DEFAULT_FINAL_TIMEOUT_MS = 1200
    DEFAULT_SEED_RESOURCE_ID = "volc.bigasr.sauc.duration"
    # Audio is chunked into 100 ms frames before being sent to the Seed API.
    _SEED_FRAME_MS = 100
    # --- Seed binary protocol: 4-bit header field values ---
    _SEED_PROTOCOL_VERSION = 0b0001
    # Message types (client -> server and server -> client).
    _SEED_FULL_CLIENT_REQUEST = 0b0001
    _SEED_AUDIO_ONLY_REQUEST = 0b0010
    _SEED_FULL_SERVER_RESPONSE = 0b1001
    _SEED_SERVER_ACK = 0b1011
    _SEED_SERVER_ERROR_RESPONSE = 0b1111
    # Message-type-specific flags (sequence-number presence/sign).
    _SEED_NO_SEQUENCE = 0b0000
    _SEED_POS_SEQUENCE = 0b0001
    _SEED_NEG_WITH_SEQUENCE = 0b0011
    # Payload serialization / compression markers.
    _SEED_NO_SERIALIZATION = 0b0000
    _SEED_JSON = 0b0001
    _SEED_NO_COMPRESSION = 0b0000
    _SEED_GZIP = 0b0001
    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        model: Optional[str] = None,
        sample_rate: int = 16000,
        language: str = "auto",
        app_id: Optional[str] = None,
        resource_id: Optional[str] = None,
        uid: Optional[str] = None,
        request_params: Optional[Dict[str, Any]] = None,
        on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None,
    ) -> None:
        """Create the service.

        Args:
            api_key: Volcengine access key. Falls back to the
                ``VOLCENGINE_ASR_API_KEY`` / ``ASR_API_KEY`` env vars.
            api_url: Websocket URL; its host/path decides the protocol.
            model: ASR model name (defaults to ``bigmodel``).
            sample_rate: PCM sample rate of the input audio in Hz.
            language: Language hint; ``"auto"`` leaves detection to the server.
            app_id: Volcengine app key (required for the Seed protocol).
            resource_id: Volcengine resource id; Seed gets a default.
            uid: Optional user id forwarded in the Seed start payload.
            request_params: Provider-native request params merged into the
                Seed start payload (or read from
                ``VOLCENGINE_ASR_REQUEST_PARAMS_JSON``).
            on_transcript: Optional async callback invoked as
                ``on_transcript(text, is_final)`` for every transcript.
        """
        super().__init__(sample_rate=sample_rate, language=language)
        self.mode = "streaming"
        self.api_key = api_key or os.getenv("VOLCENGINE_ASR_API_KEY") or os.getenv("ASR_API_KEY")
        self.model = str(model or os.getenv("VOLCENGINE_ASR_MODEL") or self.DEFAULT_MODEL).strip()
        raw_api_url = api_url or os.getenv("VOLCENGINE_ASR_API_URL") or self.DEFAULT_WS_URL
        self.protocol = self._detect_protocol(raw_api_url)
        self.api_url = self._resolve_api_url(raw_api_url, self.model, self.protocol)
        self.app_id = app_id or os.getenv("VOLCENGINE_ASR_APP_ID") or os.getenv("ASR_APP_ID")
        # Seed requires a resource id; fall back to its default only there.
        self.resource_id = (
            resource_id
            or os.getenv("VOLCENGINE_ASR_RESOURCE_ID")
            or (self.DEFAULT_SEED_RESOURCE_ID if self.protocol == "seed" else None)
        )
        self.uid = uid or os.getenv("VOLCENGINE_ASR_UID")
        self.request_params = self._load_request_params(request_params)
        self.on_transcript = on_transcript
        # Connection state (session is long-lived; the Seed websocket is
        # re-opened per utterance by _open_seed_stream).
        self._session: Optional[aiohttp.ClientSession] = None
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        self._reader_task: Optional[asyncio.Task[None]] = None
        self._running = False
        self._session_ready = asyncio.Event()
        # All transcripts (interim + final) vs. finals only.
        self._transcript_queue: "asyncio.Queue[ASRResult]" = asyncio.Queue()
        self._final_queue: "asyncio.Queue[str]" = asyncio.Queue()
        self._utterance_active = False
        self._audio_sent_in_utterance = False
        self._last_interim_text = ""
        self._last_error: Optional[str] = None
        # Seed per-utterance framing state.
        self._seed_audio_buffer = bytearray()
        self._seed_sequence = 1
        self._seed_request_id: Optional[str] = None
        # Bytes per 100 ms frame of 16-bit mono PCM (min 1 sample).
        self._seed_frame_bytes = max(2, int((self.sample_rate * self._SEED_FRAME_MS / 1000) * 2))
    @classmethod
    def _detect_protocol(cls, api_url: str) -> VolcengineASRProtocol:
        """Return "seed" for the BigASR Seed endpoint, else "gateway"."""
        parsed = urlparse(str(api_url or "").strip())
        host = parsed.netloc.lower()
        path = parsed.path.lower()
        if "openspeech.bytedance.com" in host and "/api/v3/sauc/bigmodel" in path:
            return "seed"
        return "gateway"
    @classmethod
    def _resolve_api_url(cls, api_url: str, model: str, protocol: VolcengineASRProtocol) -> str:
        """Normalize the websocket URL; gateway URLs get ``model`` in the query."""
        raw = str(api_url or "").strip()
        if not raw:
            raw = cls.DEFAULT_WS_URL if protocol == "seed" else cls.DEFAULT_GATEWAY_WS_URL
        if protocol != "gateway":
            return raw
        parsed = urlparse(raw)
        query = dict(parse_qsl(parsed.query, keep_blank_values=True))
        # Do not override a model already present in the caller's URL.
        query.setdefault("model", model or cls.DEFAULT_MODEL)
        return urlunparse(parsed._replace(query=urlencode(query)))
    @staticmethod
    def _load_request_params(request_params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Return a copy of explicit params, else parse the env-var JSON.

        Invalid JSON or a non-dict value is ignored with a warning.
        """
        if isinstance(request_params, dict):
            return dict(request_params)
        raw = os.getenv("VOLCENGINE_ASR_REQUEST_PARAMS_JSON", "").strip()
        if not raw:
            return {}
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("Ignoring invalid VOLCENGINE_ASR_REQUEST_PARAMS_JSON")
            return {}
        if isinstance(parsed, dict):
            return parsed
        return {}
    async def connect(self) -> None:
        """Create the HTTP session and, for gateway, open the websocket.

        The Seed websocket is opened lazily per utterance instead.

        Raises:
            ValueError: if required credentials are missing.
        """
        if not self.api_key:
            raise ValueError("Volcengine ASR API key not provided. Configure agent.asr.api_key in YAML.")
        # No total/read timeout: the websocket stays open across utterances.
        timeout = aiohttp.ClientTimeout(total=None, sock_read=None, sock_connect=15)
        self._session = aiohttp.ClientSession(timeout=timeout)
        self._running = True
        if self.protocol == "gateway":
            await self._connect_gateway()
            logger.info(
                "Volcengine gateway ASR connected: model={}, sample_rate={}, url={}",
                self.model,
                self.sample_rate,
                self.api_url,
            )
        else:
            if not self.app_id:
                raise ValueError("Volcengine ASR app_id not provided. Configure agent.asr.app_id in YAML.")
            logger.info(
                "Volcengine BigASR Seed ready: model={}, sample_rate={}, resource_id={}",
                self.model,
                self.sample_rate,
                self.resource_id,
            )
        self.state = ServiceState.CONNECTED
    async def disconnect(self) -> None:
        """Stop the reader, drain queues, and close websocket and session."""
        self._running = False
        self._utterance_active = False
        self._audio_sent_in_utterance = False
        self._session_ready.clear()
        self._seed_audio_buffer = bytearray()
        self._drain_queue(self._final_queue)
        self._drain_queue(self._transcript_queue)
        await self._close_ws()
        if self._session is not None:
            await self._session.close()
            self._session = None
        self.state = ServiceState.DISCONNECTED
        logger.info("Volcengine ASR disconnected")
    async def begin_utterance(self) -> None:
        """Reset per-utterance state; Seed also opens a fresh websocket."""
        self.clear_utterance()
        if self.protocol == "seed":
            await self._open_seed_stream()
        self._utterance_active = True
    async def send_audio(self, audio: bytes) -> None:
        """Forward a PCM chunk to the active utterance (no-op if empty).

        Gateway sends base64 JSON append events; Seed buffers and emits
        fixed-size binary frames.
        """
        if not audio:
            return
        if self.protocol == "seed":
            await self._send_seed_audio(audio)
            return
        if not self._ws:
            raise RuntimeError("Volcengine ASR websocket is not connected")
        if not self._utterance_active:
            self._utterance_active = True
        await self._ws.send_json(
            {
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(audio).decode("ascii"),
            }
        )
        self._audio_sent_in_utterance = True
    async def end_utterance(self) -> None:
        """Signal end of speech: Seed sends the last (negative-seq) frame,
        gateway commits the audio buffer."""
        if not self._utterance_active:
            return
        if self.protocol == "seed":
            await self._end_seed_utterance()
            return
        if not self._ws or not self._audio_sent_in_utterance:
            return
        await self._ws.send_json({"type": "input_audio_buffer.commit"})
        self._utterance_active = False
    async def wait_for_final_transcription(self, timeout_ms: int = DEFAULT_FINAL_TIMEOUT_MS) -> str:
        """Block for the final transcript of the current utterance.

        Falls back to the last interim text on timeout; returns "" if no
        audio was sent. The Seed websocket is closed afterwards since it is
        per-utterance.
        """
        if not self._audio_sent_in_utterance:
            return ""
        timeout_sec = max(0.05, float(timeout_ms) / 1000.0)
        try:
            return str(await asyncio.wait_for(self._final_queue.get(), timeout=timeout_sec) or "").strip()
        except asyncio.TimeoutError:
            logger.debug("Volcengine ASR final timeout ({}ms), fallback to last interim", timeout_ms)
            return str(self._last_interim_text or "").strip()
        finally:
            if self.protocol == "seed":
                await self._close_ws()
    def clear_utterance(self) -> None:
        """Reset all per-utterance state and drop any pending finals."""
        self._utterance_active = False
        self._audio_sent_in_utterance = False
        self._last_interim_text = ""
        self._last_error = None
        self._seed_audio_buffer = bytearray()
        self._seed_sequence = 1
        self._seed_request_id = None
        self._drain_queue(self._final_queue)
    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """Yield interim and final transcripts until the service stops.

        Polls the queue with a short timeout so the loop notices
        ``_running`` flipping to False.
        """
        while self._running:
            try:
                yield await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break
    async def _connect_gateway(self) -> None:
        """Open the gateway websocket, start the reader, configure session."""
        assert self._session is not None
        headers = {"Authorization": f"Bearer {self.api_key}"}
        if self.resource_id:
            headers["X-Api-Resource-Id"] = self.resource_id
        self._ws = await self._session.ws_connect(self.api_url, headers=headers, heartbeat=20)
        self._reader_task = asyncio.create_task(self._reader_loop())
        await self._configure_gateway_session()
    async def _configure_gateway_session(self) -> None:
        """Send transcription_session.update and wait for the server ack.

        Raises:
            RuntimeError: if no websocket, or the ack does not arrive in 8s.
        """
        if not self._ws:
            raise RuntimeError("Volcengine ASR websocket is not initialized")
        session_payload: Dict[str, Any] = {
            "input_audio_format": "pcm",
            "input_audio_codec": "raw",
            "input_audio_sample_rate": self.sample_rate,
            "input_audio_bits": 16,
            "input_audio_channel": 1,
            "result_type": 0,
            "input_audio_transcription": {
                "model": self.model,
            },
        }
        await self._ws.send_json(
            {
                "type": "transcription_session.update",
                "session": session_payload,
            }
        )
        try:
            # _session_ready is set by the reader on session created/updated.
            await asyncio.wait_for(self._session_ready.wait(), timeout=8.0)
        except asyncio.TimeoutError as exc:
            raise RuntimeError("Volcengine ASR session update timeout") from exc
    async def _open_seed_stream(self) -> None:
        """Open a fresh Seed websocket and send the full-client start frame."""
        if not self._session:
            raise RuntimeError("Volcengine ASR session is not initialized")
        # One websocket per utterance: close any previous stream first.
        await self._close_ws()
        self._seed_request_id = uuid.uuid4().hex
        headers = self._build_seed_headers(self._seed_request_id)
        self._ws = await self._session.ws_connect(
            self.api_url,
            headers=headers,
            heartbeat=20,
            max_msg_size=1_000_000_000,
        )
        self._reader_task = asyncio.create_task(self._reader_loop())
        await self._ws.send_bytes(self._build_seed_start_request())
    async def _send_seed_audio(self, audio: bytes) -> None:
        """Buffer audio and flush complete fixed-size frames to the server."""
        if not self._utterance_active:
            await self.begin_utterance()
        if not self._ws:
            raise RuntimeError("Volcengine BigASR websocket is not connected")
        self._seed_audio_buffer.extend(audio)
        while len(self._seed_audio_buffer) >= self._seed_frame_bytes:
            chunk = bytes(self._seed_audio_buffer[: self._seed_frame_bytes])
            del self._seed_audio_buffer[: self._seed_frame_bytes]
            self._seed_sequence += 1
            await self._ws.send_bytes(self._build_seed_audio_request(chunk, sequence=self._seed_sequence))
            self._audio_sent_in_utterance = True
    async def _end_seed_utterance(self) -> None:
        """Flush remaining audio as the last frame (negative sequence)."""
        if not self._ws:
            return
        if not self._audio_sent_in_utterance and not self._seed_audio_buffer:
            self._utterance_active = False
            return
        final_chunk = bytes(self._seed_audio_buffer)
        self._seed_audio_buffer = bytearray()
        self._seed_sequence += 1
        # A negative sequence number marks the final audio frame.
        await self._ws.send_bytes(
            self._build_seed_audio_request(final_chunk, sequence=-self._seed_sequence, is_last=True)
        )
        self._audio_sent_in_utterance = True
        self._utterance_active = False
    async def _close_ws(self) -> None:
        """Cancel the reader task and close the websocket, if open."""
        reader_task = self._reader_task
        ws = self._ws
        self._reader_task = None
        self._ws = None
        if reader_task:
            reader_task.cancel()
            try:
                await reader_task
            except asyncio.CancelledError:
                pass
        if ws is not None:
            await ws.close()
    async def _reader_loop(self) -> None:
        """Consume websocket messages, dispatching by protocol and type."""
        ws = self._ws
        if ws is None:
            return
        try:
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    if self.protocol == "gateway":
                        self._handle_gateway_event(msg.data)
                    else:
                        self._handle_seed_text(msg.data)
                    continue
                if msg.type == aiohttp.WSMsgType.BINARY:
                    # Only Seed uses the binary framing.
                    if self.protocol == "seed":
                        self._handle_seed_binary(msg.data)
                    continue
                if msg.type == aiohttp.WSMsgType.ERROR:
                    self._last_error = str(ws.exception())
                    logger.error("Volcengine ASR websocket error: {}", self._last_error)
                    break
                if msg.type in {aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE}:
                    break
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            self._last_error = str(exc)
            logger.error("Volcengine ASR reader loop failed: {}", exc)
        finally:
            # Only clear _ws if it still points at the socket we were reading.
            if self._ws is ws:
                self._ws = None
    def _handle_gateway_event(self, message: str) -> None:
        """Dispatch one gateway JSON event (session, error, transcript)."""
        payload = self._coerce_event(message)
        event_type = str(payload.get("type") or "").strip()
        if not event_type:
            return
        if event_type in {"transcription_session.created", "transcription_session.updated"}:
            self._session_ready.set()
            return
        if event_type == "error":
            self._last_error = self._extract_text(payload, ("message", "error"))
            logger.error("Volcengine ASR server error: {}", self._last_error or "unknown")
            return
        if event_type.endswith(".failed"):
            self._last_error = self._extract_text(payload, ("message", "error", "transcript"))
            logger.error("Volcengine ASR failed event: {}", self._last_error or event_type)
            return
        if event_type == "conversation.item.input_audio_transcription.result":
            transcript = self._extract_text(payload, ("transcript", "result"))
            self._emit_transcript_sync(transcript, is_final=False)
            return
        if event_type == "conversation.item.input_audio_transcription.delta":
            transcript = self._extract_text(payload, ("delta",))
            self._emit_transcript_sync(transcript, is_final=False)
            return
        if event_type == "conversation.item.input_audio_transcription.completed":
            transcript = self._extract_text(payload, ("transcript", "result"))
            self._emit_transcript_sync(transcript, is_final=True)
    def _handle_seed_text(self, message: str) -> None:
        """Record error events; Seed otherwise communicates in binary frames."""
        payload = self._coerce_event(message)
        if payload.get("type") == "error":
            self._last_error = self._extract_text(payload, ("message", "error"))
            logger.error("Volcengine BigASR error: {}", self._last_error or "unknown")
    def _handle_seed_binary(self, message: bytes) -> None:
        """Parse a Seed binary frame and emit any transcript it carries.

        A transcript is final when the first utterance is marked "definite".
        """
        payload = self._parse_seed_response(message)
        if payload.get("code"):
            self._last_error = self._extract_text(payload, ("payload_msg",))
            logger.error("Volcengine BigASR server error: {}", self._last_error or payload["code"])
            return
        body = payload.get("payload_msg")
        if not isinstance(body, dict):
            return
        result = body.get("result")
        if not isinstance(result, dict):
            return
        text = str(result.get("text") or "").strip()
        if not text:
            return
        utterances = result.get("utterances")
        if not isinstance(utterances, list) or not utterances:
            return
        first_utterance = utterances[0] if isinstance(utterances[0], dict) else {}
        is_final = self._coerce_bool(first_utterance.get("definite")) is True
        self._emit_transcript_sync(text, is_final=is_final)
    def _emit_transcript_sync(self, text: str, *, is_final: bool) -> None:
        """Queue a transcript and fire the optional async callback.

        Called from the reader task; uses put_nowait so it never blocks.
        """
        cleaned = str(text or "").strip()
        if not cleaned:
            return
        if not is_final:
            self._last_interim_text = cleaned
        else:
            self._last_interim_text = ""
        result = ASRResult(text=cleaned, is_final=is_final)
        try:
            self._transcript_queue.put_nowait(result)
        except asyncio.QueueFull:
            logger.debug("Volcengine ASR transcript queue full; dropping transcript")
        if is_final:
            try:
                self._final_queue.put_nowait(cleaned)
            except asyncio.QueueFull:
                logger.debug("Volcengine ASR final queue full; dropping transcript")
        if self.on_transcript:
            # NOTE(review): fire-and-forget task; no reference is kept, so an
            # exception in the callback is only surfaced by asyncio's default
            # handler — confirm this is intended.
            asyncio.create_task(self.on_transcript(cleaned, is_final))
    def _build_seed_headers(self, request_id: str) -> Dict[str, str]:
        """Build the Seed handshake headers; app_id and api_key are required."""
        if not self.app_id:
            raise ValueError("Volcengine ASR app_id not provided. Configure agent.asr.app_id in YAML.")
        if not self.api_key:
            raise ValueError("Volcengine ASR api_key not provided. Configure agent.asr.api_key in YAML.")
        return {
            "X-Api-App-Key": str(self.app_id),
            "X-Api-Access-Key": str(self.api_key),
            "X-Api-Resource-Id": str(self.resource_id or self.DEFAULT_SEED_RESOURCE_ID),
            "X-Api-Request-Id": str(request_id),
        }
    def _build_seed_start_payload(self) -> Dict[str, Any]:
        """Build the JSON start payload (user/audio/request sections).

        Defaults may be overridden via request_params: the "user", "audio"
        and "request" keys merge into their sections; remaining top-level
        keys merge into "request".
        """
        user_payload: Dict[str, Any] = {"uid": str(self.uid or self._seed_request_id or self.app_id or uuid.uuid4().hex)}
        audio_payload: Dict[str, Any] = {
            "format": "pcm",
            "rate": self.sample_rate,
            "bits": 16,
            "channels": 1,
            "codec": "raw",
        }
        if self.language and self.language != "auto":
            audio_payload["language"] = self.language
        request_payload: Dict[str, Any] = {
            "model_name": self.model or self.DEFAULT_MODEL,
            "enable_itn": False,
            "enable_punc": True,
            "enable_ddc": False,
            "show_utterance": True,
            "result_type": "single",
            "vad_segment_duration": 3000,
            "end_window_size": 500,
            "force_to_speech_time": 1000,
        }
        extra = dict(self.request_params)
        user_payload.update(self._as_dict(extra.pop("user", None)))
        audio_payload.update(self._as_dict(extra.pop("audio", None)))
        request_payload.update(self._as_dict(extra.pop("request", None)))
        request_payload.update(extra)
        return {
            "user": user_payload,
            "audio": audio_payload,
            "request": request_payload,
        }
    def _build_seed_start_request(self) -> bytes:
        """Encode the start payload as a full-client-request binary frame.

        Layout: header | sequence=1 (int32 BE) | payload size | gzip(JSON).
        """
        payload = gzip.compress(json.dumps(self._build_seed_start_payload()).encode("utf-8"))
        frame = bytearray(
            self._build_seed_header(
                message_type=self._SEED_FULL_CLIENT_REQUEST,
                message_type_specific_flags=self._SEED_POS_SEQUENCE,
            )
        )
        frame.extend((1).to_bytes(4, "big", signed=True))
        frame.extend(len(payload).to_bytes(4, "big"))
        frame.extend(payload)
        return bytes(frame)
    def _build_seed_audio_request(self, chunk: bytes, *, sequence: int, is_last: bool = False) -> bytes:
        """Encode a PCM chunk as an audio-only frame.

        The last frame carries a negative sequence and the negative-sequence
        flag so the server knows the stream is complete.
        """
        payload = gzip.compress(chunk)
        frame = bytearray(
            self._build_seed_header(
                message_type=self._SEED_AUDIO_ONLY_REQUEST,
                message_type_specific_flags=self._SEED_NEG_WITH_SEQUENCE if is_last else self._SEED_POS_SEQUENCE,
            )
        )
        frame.extend(int(sequence).to_bytes(4, "big", signed=True))
        frame.extend(len(payload).to_bytes(4, "big"))
        frame.extend(payload)
        return bytes(frame)
    @classmethod
    def _build_seed_header(
        cls,
        *,
        message_type: int,
        message_type_specific_flags: int,
        serial_method: int = _SEED_JSON,
        compression_type: int = _SEED_GZIP,
        reserved_data: int = 0x00,
    ) -> bytes:
        """Build the 4-byte Seed frame header.

        Bytes: [version|header_size=1] [msg_type|flags]
        [serialization|compression] [reserved].
        """
        header = bytearray()
        header.append((cls._SEED_PROTOCOL_VERSION << 4) | 0b0001)
        header.append((message_type << 4) | message_type_specific_flags)
        header.append((serial_method << 4) | compression_type)
        header.append(reserved_data)
        return bytes(header)
    @classmethod
    def _parse_seed_response(cls, response: bytes) -> Dict[str, Any]:
        """Decode a Seed server frame into a dict.

        Returned keys: "is_last_package", optional "payload_sequence",
        "payload_size"/"seq"/"code" depending on message type, and
        "payload_msg" (decompressed and, for JSON frames, parsed).
        """
        # Header size is expressed in 4-byte words.
        header_size = response[0] & 0x0F
        message_type = response[1] >> 4
        message_type_specific_flags = response[1] & 0x0F
        serialization_method = response[2] >> 4
        compression_type = response[2] & 0x0F
        payload = response[header_size * 4 :]
        result: Dict[str, Any] = {"is_last_package": False}
        payload_message: Any = None
        if message_type_specific_flags & 0x01:
            # Frame carries a signed sequence number before the body.
            result["payload_sequence"] = int.from_bytes(payload[:4], "big", signed=True)
            payload = payload[4:]
        if message_type_specific_flags & 0x02:
            result["is_last_package"] = True
        if message_type == cls._SEED_FULL_SERVER_RESPONSE:
            result["payload_size"] = int.from_bytes(payload[:4], "big", signed=True)
            payload_message = payload[4:]
        elif message_type == cls._SEED_SERVER_ACK:
            result["seq"] = int.from_bytes(payload[:4], "big", signed=True)
            if len(payload) >= 8:
                result["payload_size"] = int.from_bytes(payload[4:8], "big", signed=False)
                payload_message = payload[8:]
        elif message_type == cls._SEED_SERVER_ERROR_RESPONSE:
            result["code"] = int.from_bytes(payload[:4], "big", signed=False)
            result["payload_size"] = int.from_bytes(payload[4:8], "big", signed=False)
            payload_message = payload[8:]
        if payload_message is None:
            return result
        if compression_type == cls._SEED_GZIP:
            payload_message = gzip.decompress(payload_message)
        if serialization_method == cls._SEED_JSON:
            payload_message = json.loads(payload_message.decode("utf-8"))
        elif serialization_method != cls._SEED_NO_SERIALIZATION:
            payload_message = payload_message.decode("utf-8")
        result["payload_msg"] = payload_message
        return result
    @staticmethod
    def _coerce_event(message: Any) -> Dict[str, Any]:
        """Best-effort conversion of a websocket text message into a dict.

        Non-JSON or non-dict input is wrapped as {"type": "raw", ...}.
        """
        if isinstance(message, dict):
            return message
        if isinstance(message, str):
            try:
                loaded = json.loads(message)
                if isinstance(loaded, dict):
                    return loaded
            except json.JSONDecodeError:
                return {"type": "raw", "message": message}
        return {"type": "raw", "message": str(message)}
    @staticmethod
    def _extract_text(payload: Dict[str, Any], keys: tuple[str, ...]) -> str:
        """Return the first non-empty string found under keys, searching one
        level of nesting for common text-bearing fields; "" if none."""
        for key in keys:
            value = payload.get(key)
            if isinstance(value, str) and value.strip():
                return value.strip()
            if isinstance(value, dict):
                for nested_key in ("message", "text", "transcript", "result", "delta"):
                    nested = value.get(nested_key)
                    if isinstance(nested, str) and nested.strip():
                        return nested.strip()
        return ""
    @staticmethod
    def _coerce_bool(value: Any) -> Optional[bool]:
        """Interpret common bool/number/string truthiness; None if unknown."""
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return bool(value)
        if isinstance(value, str):
            normalized = value.strip().lower()
            if normalized in {"1", "true", "yes", "on"}:
                return True
            if normalized in {"0", "false", "no", "off"}:
                return False
        return None
    @staticmethod
    def _as_dict(value: Any) -> Dict[str, Any]:
        """Return a shallow copy if value is a dict, else an empty dict."""
        if isinstance(value, dict):
            return dict(value)
        return {}
    @staticmethod
    def _drain_queue(queue: "asyncio.Queue[Any]") -> None:
        """Synchronously discard every item currently queued."""
        while True:
            try:
                queue.get_nowait()
            except asyncio.QueueEmpty:
                break

View File

@@ -17,14 +17,17 @@ from runtime.ports import (
) )
from providers.asr.buffered import BufferedASRService from providers.asr.buffered import BufferedASRService
from providers.asr.dashscope import DashScopeRealtimeASRService from providers.asr.dashscope import DashScopeRealtimeASRService
from providers.asr.volcengine import VolcengineRealtimeASRService
from providers.tts.dashscope import DashScopeTTSService from providers.tts.dashscope import DashScopeTTSService
from providers.llm.openai import MockLLMService, OpenAILLMService from providers.llm.openai import MockLLMService, OpenAILLMService
from providers.asr.openai_compatible import OpenAICompatibleASRService from providers.asr.openai_compatible import OpenAICompatibleASRService
from providers.tts.openai_compatible import OpenAICompatibleTTSService from providers.tts.openai_compatible import OpenAICompatibleTTSService
from providers.tts.mock import MockTTSService from providers.tts.mock import MockTTSService
from providers.tts.volcengine import VolcengineTTSService
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"} _OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
_DASHSCOPE_PROVIDERS = {"dashscope"} _DASHSCOPE_PROVIDERS = {"dashscope"}
_VOLCENGINE_PROVIDERS = {"volcengine"}
_SUPPORTED_LLM_PROVIDERS = {"openai", *_OPENAI_COMPATIBLE_PROVIDERS} _SUPPORTED_LLM_PROVIDERS = {"openai", *_OPENAI_COMPATIBLE_PROVIDERS}
@@ -37,6 +40,10 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
_DEFAULT_DASHSCOPE_ASR_MODEL = "qwen3-asr-flash-realtime" _DEFAULT_DASHSCOPE_ASR_MODEL = "qwen3-asr-flash-realtime"
_DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL = "FunAudioLLM/CosyVoice2-0.5B" _DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
_DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall" _DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
_DEFAULT_VOLCENGINE_TTS_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional"
_DEFAULT_VOLCENGINE_TTS_RESOURCE_ID = "seed-tts-2.0"
_DEFAULT_VOLCENGINE_ASR_REALTIME_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel"
_DEFAULT_VOLCENGINE_ASR_MODEL = "bigmodel"
@staticmethod @staticmethod
def _normalize_provider(provider: Any) -> str: def _normalize_provider(provider: Any) -> str:
@@ -81,6 +88,19 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
speed=spec.speed, speed=spec.speed,
) )
if provider in _VOLCENGINE_PROVIDERS and spec.api_key:
return VolcengineTTSService(
api_key=spec.api_key,
api_url=spec.api_url or self._DEFAULT_VOLCENGINE_TTS_URL,
voice=spec.voice,
model=spec.model,
app_id=spec.app_id,
resource_id=spec.resource_id or self._DEFAULT_VOLCENGINE_TTS_RESOURCE_ID,
uid=spec.uid,
sample_rate=spec.sample_rate,
speed=spec.speed,
)
if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key: if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key:
return OpenAICompatibleTTSService( return OpenAICompatibleTTSService(
api_key=spec.api_key, api_key=spec.api_key,
@@ -110,6 +130,20 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
on_transcript=spec.on_transcript, on_transcript=spec.on_transcript,
) )
if provider in _VOLCENGINE_PROVIDERS and spec.api_key:
return VolcengineRealtimeASRService(
api_key=spec.api_key,
api_url=spec.api_url or self._DEFAULT_VOLCENGINE_ASR_REALTIME_URL,
model=spec.model or self._DEFAULT_VOLCENGINE_ASR_MODEL,
sample_rate=spec.sample_rate,
language=spec.language,
app_id=spec.app_id,
resource_id=spec.resource_id,
uid=spec.uid,
request_params=spec.request_params,
on_transcript=spec.on_transcript,
)
if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key: if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key:
return OpenAICompatibleASRService( return OpenAICompatibleASRService(
api_key=spec.api_key, api_key=spec.api_key,

View File

@@ -1 +1,5 @@
"""TTS providers.""" """TTS providers."""
from providers.tts.volcengine import VolcengineTTSService
__all__ = ["VolcengineTTSService"]

View File

@@ -0,0 +1,219 @@
"""Volcengine TTS service.
Uses Volcengine's unidirectional HTTP streaming TTS API and adapts streamed
base64 audio chunks into engine-native ``TTSChunk`` events.
"""
from __future__ import annotations
import asyncio
import base64
import codecs
import json
import os
import uuid
from typing import Any, AsyncIterator, Optional
import aiohttp
from loguru import logger
from providers.common.base import BaseTTSService, ServiceState, TTSChunk
class VolcengineTTSService(BaseTTSService):
    """Streaming TTS adapter for Volcengine's HTTP v3 API.

    Issues one POST per utterance to the unidirectional streaming endpoint
    and re-chunks the base64-encoded PCM carried in the streamed JSON
    response into fixed-size ``TTSChunk`` events, flagging the last chunk
    with ``is_final=True``.
    """

    # Service defaults; both can be overridden per-instance or via env vars.
    DEFAULT_API_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional"
    DEFAULT_RESOURCE_ID = "seed-tts-2.0"

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        voice: str = "zh_female_shuangkuaisisi_moon_bigtts",
        model: Optional[str] = None,
        app_id: Optional[str] = None,
        resource_id: Optional[str] = None,
        uid: Optional[str] = None,
        sample_rate: int = 16000,
        speed: float = 1.0,
    ) -> None:
        """Initialize the adapter; explicit arguments win over env vars."""
        super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
        self.api_key = api_key or os.getenv("VOLCENGINE_TTS_API_KEY") or os.getenv("TTS_API_KEY")
        self.api_url = api_url or os.getenv("VOLCENGINE_TTS_API_URL") or self.DEFAULT_API_URL
        # Blank model strings collapse to None so the field is omitted from
        # the request payload entirely (see synthesize_stream).
        self.model = str(model or os.getenv("VOLCENGINE_TTS_MODEL") or "").strip() or None
        self.app_id = app_id or os.getenv("VOLCENGINE_TTS_APP_ID") or os.getenv("TTS_APP_ID")
        self.resource_id = resource_id or os.getenv("VOLCENGINE_TTS_RESOURCE_ID") or self.DEFAULT_RESOURCE_ID
        self.uid = uid or os.getenv("VOLCENGINE_TTS_UID")
        # Session is created in connect(); cancel event is set by cancel()
        # and disconnect(); the lock serializes synthesize_stream runs.
        self._session: Optional[aiohttp.ClientSession] = None
        self._cancel_event = asyncio.Event()
        self._synthesis_lock = asyncio.Lock()
        # Decoded audio chunks parsed out of the stream, awaiting yield.
        self._pending_audio: list[bytes] = []

    async def connect(self) -> None:
        """Validate credentials and open the shared HTTP session.

        Raises:
            ValueError: if the API key or app_id is not configured.
        """
        if not self.api_key:
            raise ValueError("Volcengine TTS API key not provided. Configure agent.tts.api_key in YAML.")
        if not self.app_id:
            raise ValueError("Volcengine TTS app_id not provided. Configure agent.tts.app_id in YAML.")
        # No total/read timeout: the response body streams for as long as
        # synthesis runs; only the TCP connect is bounded (15 s).
        timeout = aiohttp.ClientTimeout(total=None, sock_read=None, sock_connect=15)
        self._session = aiohttp.ClientSession(timeout=timeout)
        self.state = ServiceState.CONNECTED
        logger.info(
            "Volcengine TTS service ready: speaker={}, sample_rate={}, resource_id={}",
            self.voice,
            self.sample_rate,
            self.resource_id,
        )

    async def disconnect(self) -> None:
        """Cancel any in-flight synthesis and close the HTTP session."""
        self._cancel_event.set()
        if self._session is not None:
            await self._session.close()
            self._session = None
        self.state = ServiceState.DISCONNECTED
        logger.info("Volcengine TTS service disconnected")

    async def synthesize(self, text: str) -> bytes:
        """Synthesize *text* and return the complete PCM audio as one buffer."""
        audio = b""
        async for chunk in self.synthesize_stream(text):
            audio += chunk.audio
        return audio

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """Stream *text* as fixed-size PCM ``TTSChunk`` events.

        One chunk is always held back so the final emitted chunk can carry
        ``is_final=True``. Blank input yields nothing; cancellation (via
        :meth:`cancel` or :meth:`disconnect`) stops emission early.

        Raises:
            RuntimeError: if the service is not connected or the server
                returns a non-200 status / an in-stream error code.
        """
        if not self._session:
            raise RuntimeError("Volcengine TTS service not connected")
        if not text.strip():
            return
        # NOTE(review): the lock is held for the whole synthesis run, which
        # serializes concurrent callers — presumably because _pending_audio
        # is shared instance state; confirm against the original layout.
        async with self._synthesis_lock:
            self._cancel_event.clear()
            headers = {
                "Content-Type": "application/json",
                "X-Api-App-Key": str(self.app_id),
                "X-Api-Access-Key": str(self.api_key),
                "X-Api-Resource-Id": str(self.resource_id),
                "X-Api-Request-Id": str(uuid.uuid4()),
            }
            payload = {
                "user": {
                    # Fall back to app_id when no caller uid was configured.
                    "uid": str(self.uid or self.app_id),
                },
                "req_params": {
                    "text": text,
                    "speaker": self.voice,
                    "audio_params": {
                        "format": "pcm",
                        "sample_rate": self.sample_rate,
                        "speech_rate": self._speech_rate_percent(self.speed),
                    },
                },
            }
            if self.model:
                payload["req_params"]["model"] = self.model
            # Chunk size in bytes: sample_rate * 2 bytes per sample / 10,
            # i.e. ~100 ms per chunk assuming 16-bit mono PCM — TODO confirm
            # sample width against the service's "pcm" format.
            chunk_size = max(1, self.sample_rate * 2 // 10)
            audio_buffer = b""
            # Holds the most recent full chunk so we can retroactively mark
            # the last one emitted as final.
            pending_chunk: Optional[bytes] = None
            try:
                async with self._session.post(self.api_url, headers=headers, json=payload) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise RuntimeError(f"Volcengine TTS error {response.status}: {error_text}")
                    async for audio_bytes in self._iter_audio_bytes(response):
                        if self._cancel_event.is_set():
                            logger.info("Volcengine TTS synthesis cancelled")
                            return
                        audio_buffer += audio_bytes
                        # Emit full chunks, always keeping one back as pending.
                        while len(audio_buffer) >= chunk_size:
                            emitted = audio_buffer[:chunk_size]
                            audio_buffer = audio_buffer[chunk_size:]
                            if pending_chunk is not None:
                                yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False)
                            pending_chunk = emitted
                if self._cancel_event.is_set():
                    return
                # Flush: whichever piece goes out last carries is_final=True.
                if pending_chunk is not None:
                    if audio_buffer:
                        yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False)
                        pending_chunk = None
                    else:
                        yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=True)
                        pending_chunk = None
                if audio_buffer:
                    yield TTSChunk(audio=audio_buffer, sample_rate=self.sample_rate, is_final=True)
            except asyncio.CancelledError:
                logger.info("Volcengine TTS synthesis cancelled via asyncio")
                raise
            except Exception as exc:
                logger.error("Volcengine TTS synthesis error: {}", exc)
                raise

    async def cancel(self) -> None:
        """Request cancellation of the current synthesis stream."""
        self._cancel_event.set()

    async def _iter_audio_bytes(self, response: aiohttp.ClientResponse) -> AsyncIterator[bytes]:
        """Yield decoded PCM chunks from the concatenated-JSON response body.

        The body is a stream of JSON objects with no delimiter guarantees;
        bytes are decoded incrementally (so multi-byte UTF-8 sequences may
        straddle network chunks) and parsed with a raw JSON decoder.
        """
        decoder = json.JSONDecoder()
        utf8_decoder = codecs.getincrementaldecoder("utf-8")()
        text_buffer = ""
        self._pending_audio.clear()
        async for raw_chunk in response.content.iter_any():
            text_buffer += utf8_decoder.decode(raw_chunk)
            # Parse any complete JSON objects; keep the unparsed remainder.
            text_buffer = self._yield_audio_payloads(decoder, text_buffer)
            while self._pending_audio:
                yield self._pending_audio.pop(0)
        # Flush the incremental decoder and parse whatever is left over.
        text_buffer += utf8_decoder.decode(b"", final=True)
        text_buffer = self._yield_audio_payloads(decoder, text_buffer)
        while self._pending_audio:
            yield self._pending_audio.pop(0)

    def _yield_audio_payloads(self, decoder: json.JSONDecoder, text_buffer: str) -> str:
        """Parse complete JSON objects out of *text_buffer*.

        Decoded audio is appended to ``self._pending_audio``; returns the
        unconsumed tail of the buffer (a partial JSON object, or "").
        """
        while True:
            stripped = text_buffer.lstrip()
            if not stripped:
                return ""
            if len(stripped) != len(text_buffer):
                text_buffer = stripped
            try:
                payload, idx = decoder.raw_decode(text_buffer)
            except json.JSONDecodeError:
                # Incomplete object — wait for more bytes from the stream.
                return text_buffer
            text_buffer = text_buffer[idx:]
            audio = self._extract_audio_bytes(payload)
            if audio:
                self._pending_audio.append(audio)

    def _extract_audio_bytes(self, payload: Any) -> bytes:
        """Return the PCM bytes from one stream event, or ``b""``.

        Raises:
            RuntimeError: if the event carries a non-success error code
                (20000000 is treated as success alongside 0/absent —
                presumably the service's "OK" code; verify against the API).
        """
        if not isinstance(payload, dict):
            return b""
        code = payload.get("code")
        if code not in (None, 0, 20000000):
            message = str(payload.get("message") or "unknown error")
            raise RuntimeError(f"Volcengine TTS stream error {code}: {message}")
        encoded = payload.get("data")
        if isinstance(encoded, str) and encoded.strip():
            try:
                return base64.b64decode(encoded)
            except Exception as exc:
                # A corrupt chunk is dropped rather than aborting the stream.
                logger.warning("Failed to decode Volcengine TTS audio chunk: {}", exc)
        return b""

    @staticmethod
    def _speech_rate_percent(speed: float) -> int:
        """Map a speed multiplier (clamped to 0.5–2.0) to a percent offset.

        1.0 -> 0, 0.5 -> -50, 2.0 -> +100, matching the API's speech_rate.
        """
        clamped = max(0.5, min(2.0, float(speed or 1.0)))
        return int(round((clamped - 1.0) * 100))

View File

@@ -793,6 +793,23 @@ class DuplexPipeline:
return False return False
return None return None
@staticmethod
def _coerce_json_object(value: Any) -> Optional[Dict[str, Any]]:
if isinstance(value, dict):
return dict(value)
if isinstance(value, str):
raw = value.strip()
if not raw:
return None
try:
parsed = json.loads(raw)
except json.JSONDecodeError:
logger.warning("Ignoring invalid JSON object config: {}", raw[:120])
return None
if isinstance(parsed, dict):
return parsed
return None
@staticmethod @staticmethod
def _is_dashscope_tts_provider(provider: Any) -> bool: def _is_dashscope_tts_provider(provider: Any) -> bool:
normalized = str(provider or "").strip().lower() normalized = str(provider or "").strip().lower()
@@ -804,7 +821,7 @@ class DuplexPipeline:
if normalized_mode in {"offline", "streaming"}: if normalized_mode in {"offline", "streaming"}:
return normalized_mode # type: ignore[return-value] return normalized_mode # type: ignore[return-value]
normalized_provider = str(provider or "").strip().lower() normalized_provider = str(provider or "").strip().lower()
if normalized_provider == "dashscope": if normalized_provider in {"dashscope", "volcengine"}:
return "streaming" return "streaming"
return "offline" return "offline"
@@ -963,6 +980,10 @@ class DuplexPipeline:
tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url
tts_voice = self._runtime_tts.get("voice") or settings.tts_voice tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
tts_model = self._runtime_tts.get("model") or settings.tts_model tts_model = self._runtime_tts.get("model") or settings.tts_model
tts_app_id = self._runtime_tts.get("appId") or settings.tts_app_id
tts_resource_id = self._runtime_tts.get("resourceId") or settings.tts_resource_id
tts_cluster = self._runtime_tts.get("cluster") or settings.tts_cluster
tts_uid = self._runtime_tts.get("uid") or settings.tts_uid
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed) tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
tts_mode = self._resolved_dashscope_tts_mode() tts_mode = self._resolved_dashscope_tts_mode()
runtime_mode = str(self._runtime_tts.get("mode") or "").strip() runtime_mode = str(self._runtime_tts.get("mode") or "").strip()
@@ -978,6 +999,10 @@ class DuplexPipeline:
api_url=str(tts_api_url).strip() if tts_api_url else None, api_url=str(tts_api_url).strip() if tts_api_url else None,
voice=str(tts_voice), voice=str(tts_voice),
model=str(tts_model).strip() if tts_model else None, model=str(tts_model).strip() if tts_model else None,
app_id=str(tts_app_id).strip() if tts_app_id else None,
resource_id=str(tts_resource_id).strip() if tts_resource_id else None,
cluster=str(tts_cluster).strip() if tts_cluster else None,
uid=str(tts_uid).strip() if tts_uid else None,
sample_rate=settings.sample_rate, sample_rate=settings.sample_rate,
speed=tts_speed, speed=tts_speed,
mode=str(tts_mode), mode=str(tts_mode),
@@ -1006,6 +1031,13 @@ class DuplexPipeline:
asr_api_key = self._runtime_asr.get("apiKey") asr_api_key = self._runtime_asr.get("apiKey")
asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url
asr_model = self._runtime_asr.get("model") or settings.asr_model asr_model = self._runtime_asr.get("model") or settings.asr_model
asr_app_id = self._runtime_asr.get("appId") or settings.asr_app_id
asr_resource_id = self._runtime_asr.get("resourceId") or settings.asr_resource_id
asr_cluster = self._runtime_asr.get("cluster") or settings.asr_cluster
asr_uid = self._runtime_asr.get("uid") or settings.asr_uid
asr_request_params = self._coerce_json_object(self._runtime_asr.get("requestParams"))
if asr_request_params is None:
asr_request_params = self._coerce_json_object(settings.asr_request_params_json)
asr_enable_interim = self._coerce_bool(self._runtime_asr.get("enableInterim")) asr_enable_interim = self._coerce_bool(self._runtime_asr.get("enableInterim"))
if asr_enable_interim is None: if asr_enable_interim is None:
asr_enable_interim = bool(settings.asr_enable_interim) asr_enable_interim = bool(settings.asr_enable_interim)
@@ -1022,6 +1054,11 @@ class DuplexPipeline:
api_key=str(asr_api_key).strip() if asr_api_key else None, api_key=str(asr_api_key).strip() if asr_api_key else None,
api_url=str(asr_api_url).strip() if asr_api_url else None, api_url=str(asr_api_url).strip() if asr_api_url else None,
model=str(asr_model).strip() if asr_model else None, model=str(asr_model).strip() if asr_model else None,
app_id=str(asr_app_id).strip() if asr_app_id else None,
resource_id=str(asr_resource_id).strip() if asr_resource_id else None,
cluster=str(asr_cluster).strip() if asr_cluster else None,
uid=str(asr_uid).strip() if asr_uid else None,
request_params=asr_request_params,
enable_interim=asr_enable_interim, enable_interim=asr_enable_interim,
interim_interval_ms=asr_interim_interval, interim_interval_ms=asr_interim_interval,
min_audio_for_interim_ms=asr_min_audio_ms, min_audio_for_interim_ms=asr_min_audio_ms,

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import AsyncIterator, Awaitable, Callable, Literal, Optional, Protocol from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Literal, Optional, Protocol
from providers.common.base import ASRResult from providers.common.base import ASRResult
@@ -22,6 +22,11 @@ class ASRServiceSpec:
api_key: Optional[str] = None api_key: Optional[str] = None
api_url: Optional[str] = None api_url: Optional[str] = None
model: Optional[str] = None model: Optional[str] = None
app_id: Optional[str] = None
resource_id: Optional[str] = None
cluster: Optional[str] = None
uid: Optional[str] = None
request_params: Optional[Dict[str, Any]] = None
enable_interim: bool = False enable_interim: bool = False
interim_interval_ms: int = 500 interim_interval_ms: int = 500
min_audio_for_interim_ms: int = 300 min_audio_for_interim_ms: int = 300

View File

@@ -19,6 +19,10 @@ class TTSServiceSpec:
api_key: Optional[str] = None api_key: Optional[str] = None
api_url: Optional[str] = None api_url: Optional[str] = None
model: Optional[str] = None model: Optional[str] = None
app_id: Optional[str] = None
resource_id: Optional[str] = None
cluster: Optional[str] = None
uid: Optional[str] = None
mode: str = "commit" mode: str = "commit"

View File

@@ -1,6 +1,7 @@
from providers.asr.buffered import BufferedASRService from providers.asr.buffered import BufferedASRService
from providers.asr.dashscope import DashScopeRealtimeASRService from providers.asr.dashscope import DashScopeRealtimeASRService
from providers.asr.openai_compatible import OpenAICompatibleASRService from providers.asr.openai_compatible import OpenAICompatibleASRService
from providers.asr.volcengine import VolcengineRealtimeASRService
from providers.factory.default import DefaultRealtimeServiceFactory from providers.factory.default import DefaultRealtimeServiceFactory
from runtime.ports import ASRServiceSpec from runtime.ports import ASRServiceSpec
@@ -35,6 +36,29 @@ def test_create_asr_service_openai_compatible_returns_offline_provider():
assert service.enable_interim is False assert service.enable_interim is False
def test_create_asr_service_volcengine_returns_streaming_provider():
    """Factory maps provider 'volcengine' to the native streaming ASR service
    and forwards app_id, uid and pass-through request_params unchanged."""
    factory = DefaultRealtimeServiceFactory()
    service = factory.create_asr_service(
        ASRServiceSpec(
            provider="volcengine",
            mode="streaming",
            sample_rate=16000,
            api_key="test-key",
            api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
            model="bigmodel",
            app_id="app-1",
            uid="caller-1",
            request_params={"end_window_size": 800},
        )
    )
    assert isinstance(service, VolcengineRealtimeASRService)
    assert service.mode == "streaming"
    # The openspeech host is expected to select the binary "seed" protocol.
    assert service.protocol == "seed"
    assert service.app_id == "app-1"
    assert service.uid == "caller-1"
    assert service.request_params["end_window_size"] == 800
def test_create_asr_service_fallback_buffered_for_unsupported_provider(): def test_create_asr_service_fallback_buffered_for_unsupported_provider():
factory = DefaultRealtimeServiceFactory() factory = DefaultRealtimeServiceFactory()
service = factory.create_asr_service( service = factory.create_asr_service(

View File

@@ -227,6 +227,62 @@ async def test_with_backend_url_uses_backend_for_assistant_config(monkeypatch, t
assert payload["assistant"]["systemPrompt"] == "backend prompt" assert payload["assistant"]["systemPrompt"] == "backend prompt"
def test_translate_agent_schema_maps_volcengine_fields():
    """snake_case Volcengine fields in the agent YAML schema translate to the
    camelCase runtime service config (app_id -> appId, etc.)."""
    payload = {
        "agent": {
            "tts": {
                "provider": "volcengine",
                "api_key": "tts-key",
                "api_url": "https://openspeech.bytedance.com/api/v3/tts/unidirectional",
                "app_id": "app-123",
                "resource_id": "seed-tts-2.0",
                "uid": "caller-1",
                "voice": "zh_female_shuangkuaisisi_moon_bigtts",
                "speed": 1.1,
            },
            "asr": {
                "provider": "volcengine",
                "api_key": "asr-key",
                "api_url": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
                "model": "bigmodel",
                "app_id": "app-123",
                "resource_id": "volc.bigasr.sauc.duration",
                "uid": "caller-1",
                "request_params": {
                    "end_window_size": 800,
                    "force_to_speech_time": 1000,
                },
            },
        }
    }
    translated = LocalYamlAssistantConfigAdapter._translate_agent_schema("assistant_demo", payload)
    assert translated is not None
    assert translated["services"]["tts"] == {
        "provider": "volcengine",
        "apiKey": "tts-key",
        "baseUrl": "https://openspeech.bytedance.com/api/v3/tts/unidirectional",
        "voice": "zh_female_shuangkuaisisi_moon_bigtts",
        "appId": "app-123",
        "resourceId": "seed-tts-2.0",
        "uid": "caller-1",
        "speed": 1.1,
    }
    # request_params must pass through as a nested dict, not be flattened.
    assert translated["services"]["asr"] == {
        "provider": "volcengine",
        "model": "bigmodel",
        "apiKey": "asr-key",
        "baseUrl": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
        "appId": "app-123",
        "resourceId": "volc.bigasr.sauc.duration",
        "uid": "caller-1",
        "requestParams": {
            "end_window_size": 800,
            "force_to_speech_time": 1000,
        },
    }
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_backend_mode_disabled_uses_local_assistant_config_even_with_url(monkeypatch, tmp_path): async def test_backend_mode_disabled_uses_local_assistant_config_even_with_url(monkeypatch, tmp_path):
class _FailIfCalledClientSession: class _FailIfCalledClientSession:

View File

@@ -0,0 +1,45 @@
from providers.factory.default import DefaultRealtimeServiceFactory
from providers.tts.mock import MockTTSService
from providers.tts.openai_compatible import OpenAICompatibleTTSService
from providers.tts.volcengine import VolcengineTTSService
from runtime.ports import TTSServiceSpec
def test_create_tts_service_volcengine_returns_native_provider():
    """Provider 'volcengine' with an API key yields the native TTS service."""
    factory = DefaultRealtimeServiceFactory()
    service = factory.create_tts_service(
        TTSServiceSpec(
            provider="volcengine",
            api_key="test-key",
            app_id="app-1",
            resource_id="seed-tts-2.0",
            voice="zh_female_shuangkuaisisi_moon_bigtts",
            sample_rate=16000,
        )
    )
    assert isinstance(service, VolcengineTTSService)
def test_create_tts_service_openai_compatible_returns_provider():
    """Provider 'openai_compatible' with an API key yields the OpenAI-style TTS service."""
    factory = DefaultRealtimeServiceFactory()
    service = factory.create_tts_service(
        TTSServiceSpec(
            provider="openai_compatible",
            api_key="test-key",
            voice="anna",
            sample_rate=16000,
        )
    )
    assert isinstance(service, OpenAICompatibleTTSService)
def test_create_tts_service_fallbacks_to_mock_without_key():
    """Without an API key the factory falls back to the mock TTS service,
    even when a real provider is requested."""
    factory = DefaultRealtimeServiceFactory()
    service = factory.create_tts_service(
        TTSServiceSpec(
            provider="volcengine",
            voice="anna",
            sample_rate=16000,
        )
    )
    assert isinstance(service, MockTTSService)

View File

@@ -0,0 +1,86 @@
import gzip
import json
from providers.asr.volcengine import VolcengineRealtimeASRService
def test_volcengine_seed_protocol_defaults_and_headers():
    """Openspeech URLs select the 'seed' protocol with the default BigASR
    resource id, and the WebSocket handshake headers carry the credentials."""
    service = VolcengineRealtimeASRService(
        api_key="access-token",
        api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
        app_id="app-1",
        uid="caller-1",
    )
    assert service.protocol == "seed"
    assert service.resource_id == "volc.bigasr.sauc.duration"
    headers = service._build_seed_headers("req-1")
    assert headers == {
        "X-Api-App-Key": "app-1",
        "X-Api-Access-Key": "access-token",
        "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
        "X-Api-Request-Id": "req-1",
    }
def test_volcengine_seed_start_payload_merges_request_params():
    """User-supplied request_params merge into the seed start payload:
    'request' keys overlay the generated request section and 'audio' keys
    overlay the generated audio section, without dropping defaults."""
    service = VolcengineRealtimeASRService(
        api_key="access-token",
        api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
        app_id="app-1",
        uid="caller-1",
        language="zh-CN",
        request_params={
            "request": {
                "end_window_size": 800,
                "force_to_speech_time": 1000,
                "context": "{\"hotwords\":[{\"word\":\"doubao\"}]}",
            },
            "audio": {"codec": "raw"},
        },
    )
    payload = service._build_seed_start_payload()
    assert payload["user"] == {"uid": "caller-1"}
    assert payload["audio"] == {
        "format": "pcm",
        "rate": 16000,
        "bits": 16,
        "channels": 1,
        "codec": "raw",
        "language": "zh-CN",
    }
    assert payload["request"]["model_name"] == "bigmodel"
    assert payload["request"]["end_window_size"] == 800
    assert payload["request"]["force_to_speech_time"] == 1000
    # Vendor 'context' (e.g. hotwords JSON) is passed through verbatim.
    assert payload["request"]["context"] == "{\"hotwords\":[{\"word\":\"doubao\"}]}"
def test_volcengine_seed_start_request_encodes_gzip_json_payload():
    """The binary start frame uses the seed header layout and carries the
    start payload as gzip-compressed JSON after a 4-byte big-endian length."""
    service = VolcengineRealtimeASRService(
        api_key="access-token",
        api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
        app_id="app-1",
        uid="caller-1",
    )
    frame = service._build_seed_start_request()
    # Byte 0: protocol version (high nibble) | header size in 4-byte units.
    assert frame[0] == 0x11
    # Byte 1: message type (high nibble) | message-type-specific flags.
    assert frame[1] == 0x11
    payload_length = int.from_bytes(frame[8:12], "big")
    payload = json.loads(gzip.decompress(frame[12 : 12 + payload_length]).decode("utf-8"))
    assert payload["user"]["uid"] == "caller-1"
    assert payload["request"]["model_name"] == "bigmodel"
def test_volcengine_gateway_protocol_keeps_model_query():
    """Non-openspeech (AI gateway) URLs select the 'gateway' protocol and the
    model is appended to the realtime URL as a query parameter."""
    service = VolcengineRealtimeASRService(
        api_key="access-token",
        api_url="wss://ai-gateway.vei.volces.com/v1/realtime",
        model="bigmodel",
    )
    assert service.protocol == "gateway"
    assert service.api_url == "wss://ai-gateway.vei.volces.com/v1/realtime?model=bigmodel"