Add Volcengine support for TTS and ASR services
- Introduced Volcengine as a new provider for both TTS and ASR services. - Updated configuration files to include Volcengine-specific parameters such as app_id, resource_id, and uid. - Enhanced the ASR service to support streaming mode with Volcengine's API. - Modified existing tests to validate the integration of Volcengine services. - Updated documentation to reflect the addition of Volcengine as a supported provider for TTS and ASR. - Refactored service factory to accommodate Volcengine alongside existing providers.
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
from providers.asr.buffered import BufferedASRService, MockASRService
|
||||
from providers.asr.dashscope import DashScopeRealtimeASRService
|
||||
from providers.asr.openai_compatible import OpenAICompatibleASRService, SiliconFlowASRService
|
||||
from providers.asr.volcengine import VolcengineRealtimeASRService
|
||||
|
||||
__all__ = [
|
||||
"BufferedASRService",
|
||||
@@ -10,4 +11,5 @@ __all__ = [
|
||||
"DashScopeRealtimeASRService",
|
||||
"OpenAICompatibleASRService",
|
||||
"SiliconFlowASRService",
|
||||
"VolcengineRealtimeASRService",
|
||||
]
|
||||
|
||||
666
engine/providers/asr/volcengine.py
Normal file
666
engine/providers/asr/volcengine.py
Normal file
@@ -0,0 +1,666 @@
|
||||
"""Volcengine realtime ASR service.
|
||||
|
||||
Supports both:
|
||||
- Volcengine Edge Gateway realtime transcription websocket, and
|
||||
- Volcengine BigASR Seed websocket at openspeech.bytedance.com/api/v3/sauc/bigmodel.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Literal, Optional
|
||||
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from providers.common.base import ASRResult, BaseASRService, ServiceState
|
||||
|
||||
VolcengineASRProtocol = Literal["gateway", "seed"]
|
||||
|
||||
|
||||
class VolcengineRealtimeASRService(BaseASRService):
|
||||
"""Realtime streaming ASR backed by Volcengine websocket APIs."""
|
||||
|
||||
DEFAULT_WS_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel"
|
||||
DEFAULT_GATEWAY_WS_URL = "wss://ai-gateway.vei.volces.com/v1/realtime"
|
||||
DEFAULT_MODEL = "bigmodel"
|
||||
DEFAULT_FINAL_TIMEOUT_MS = 1200
|
||||
DEFAULT_SEED_RESOURCE_ID = "volc.bigasr.sauc.duration"
|
||||
_SEED_FRAME_MS = 100
|
||||
_SEED_PROTOCOL_VERSION = 0b0001
|
||||
_SEED_FULL_CLIENT_REQUEST = 0b0001
|
||||
_SEED_AUDIO_ONLY_REQUEST = 0b0010
|
||||
_SEED_FULL_SERVER_RESPONSE = 0b1001
|
||||
_SEED_SERVER_ACK = 0b1011
|
||||
_SEED_SERVER_ERROR_RESPONSE = 0b1111
|
||||
_SEED_NO_SEQUENCE = 0b0000
|
||||
_SEED_POS_SEQUENCE = 0b0001
|
||||
_SEED_NEG_WITH_SEQUENCE = 0b0011
|
||||
_SEED_NO_SERIALIZATION = 0b0000
|
||||
_SEED_JSON = 0b0001
|
||||
_SEED_NO_COMPRESSION = 0b0000
|
||||
_SEED_GZIP = 0b0001
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
sample_rate: int = 16000,
|
||||
language: str = "auto",
|
||||
app_id: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
uid: Optional[str] = None,
|
||||
request_params: Optional[Dict[str, Any]] = None,
|
||||
on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None,
|
||||
) -> None:
|
||||
super().__init__(sample_rate=sample_rate, language=language)
|
||||
self.mode = "streaming"
|
||||
self.api_key = api_key or os.getenv("VOLCENGINE_ASR_API_KEY") or os.getenv("ASR_API_KEY")
|
||||
self.model = str(model or os.getenv("VOLCENGINE_ASR_MODEL") or self.DEFAULT_MODEL).strip()
|
||||
raw_api_url = api_url or os.getenv("VOLCENGINE_ASR_API_URL") or self.DEFAULT_WS_URL
|
||||
self.protocol = self._detect_protocol(raw_api_url)
|
||||
self.api_url = self._resolve_api_url(raw_api_url, self.model, self.protocol)
|
||||
self.app_id = app_id or os.getenv("VOLCENGINE_ASR_APP_ID") or os.getenv("ASR_APP_ID")
|
||||
self.resource_id = (
|
||||
resource_id
|
||||
or os.getenv("VOLCENGINE_ASR_RESOURCE_ID")
|
||||
or (self.DEFAULT_SEED_RESOURCE_ID if self.protocol == "seed" else None)
|
||||
)
|
||||
self.uid = uid or os.getenv("VOLCENGINE_ASR_UID")
|
||||
self.request_params = self._load_request_params(request_params)
|
||||
self.on_transcript = on_transcript
|
||||
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
|
||||
self._reader_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
self._running = False
|
||||
self._session_ready = asyncio.Event()
|
||||
self._transcript_queue: "asyncio.Queue[ASRResult]" = asyncio.Queue()
|
||||
self._final_queue: "asyncio.Queue[str]" = asyncio.Queue()
|
||||
|
||||
self._utterance_active = False
|
||||
self._audio_sent_in_utterance = False
|
||||
self._last_interim_text = ""
|
||||
self._last_error: Optional[str] = None
|
||||
|
||||
self._seed_audio_buffer = bytearray()
|
||||
self._seed_sequence = 1
|
||||
self._seed_request_id: Optional[str] = None
|
||||
self._seed_frame_bytes = max(2, int((self.sample_rate * self._SEED_FRAME_MS / 1000) * 2))
|
||||
|
||||
@classmethod
|
||||
def _detect_protocol(cls, api_url: str) -> VolcengineASRProtocol:
|
||||
parsed = urlparse(str(api_url or "").strip())
|
||||
host = parsed.netloc.lower()
|
||||
path = parsed.path.lower()
|
||||
if "openspeech.bytedance.com" in host and "/api/v3/sauc/bigmodel" in path:
|
||||
return "seed"
|
||||
return "gateway"
|
||||
|
||||
@classmethod
|
||||
def _resolve_api_url(cls, api_url: str, model: str, protocol: VolcengineASRProtocol) -> str:
|
||||
raw = str(api_url or "").strip()
|
||||
if not raw:
|
||||
raw = cls.DEFAULT_WS_URL if protocol == "seed" else cls.DEFAULT_GATEWAY_WS_URL
|
||||
if protocol != "gateway":
|
||||
return raw
|
||||
|
||||
parsed = urlparse(raw)
|
||||
query = dict(parse_qsl(parsed.query, keep_blank_values=True))
|
||||
query.setdefault("model", model or cls.DEFAULT_MODEL)
|
||||
return urlunparse(parsed._replace(query=urlencode(query)))
|
||||
|
||||
@staticmethod
|
||||
def _load_request_params(request_params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
if isinstance(request_params, dict):
|
||||
return dict(request_params)
|
||||
raw = os.getenv("VOLCENGINE_ASR_REQUEST_PARAMS_JSON", "").strip()
|
||||
if not raw:
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Ignoring invalid VOLCENGINE_ASR_REQUEST_PARAMS_JSON")
|
||||
return {}
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return {}
|
||||
|
||||
async def connect(self) -> None:
|
||||
if not self.api_key:
|
||||
raise ValueError("Volcengine ASR API key not provided. Configure agent.asr.api_key in YAML.")
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=None, sock_read=None, sock_connect=15)
|
||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||
self._running = True
|
||||
|
||||
if self.protocol == "gateway":
|
||||
await self._connect_gateway()
|
||||
logger.info(
|
||||
"Volcengine gateway ASR connected: model={}, sample_rate={}, url={}",
|
||||
self.model,
|
||||
self.sample_rate,
|
||||
self.api_url,
|
||||
)
|
||||
else:
|
||||
if not self.app_id:
|
||||
raise ValueError("Volcengine ASR app_id not provided. Configure agent.asr.app_id in YAML.")
|
||||
logger.info(
|
||||
"Volcengine BigASR Seed ready: model={}, sample_rate={}, resource_id={}",
|
||||
self.model,
|
||||
self.sample_rate,
|
||||
self.resource_id,
|
||||
)
|
||||
|
||||
self.state = ServiceState.CONNECTED
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._running = False
|
||||
self._utterance_active = False
|
||||
self._audio_sent_in_utterance = False
|
||||
self._session_ready.clear()
|
||||
self._seed_audio_buffer = bytearray()
|
||||
self._drain_queue(self._final_queue)
|
||||
self._drain_queue(self._transcript_queue)
|
||||
|
||||
await self._close_ws()
|
||||
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("Volcengine ASR disconnected")
|
||||
|
||||
async def begin_utterance(self) -> None:
|
||||
self.clear_utterance()
|
||||
if self.protocol == "seed":
|
||||
await self._open_seed_stream()
|
||||
self._utterance_active = True
|
||||
|
||||
async def send_audio(self, audio: bytes) -> None:
|
||||
if not audio:
|
||||
return
|
||||
|
||||
if self.protocol == "seed":
|
||||
await self._send_seed_audio(audio)
|
||||
return
|
||||
|
||||
if not self._ws:
|
||||
raise RuntimeError("Volcengine ASR websocket is not connected")
|
||||
if not self._utterance_active:
|
||||
self._utterance_active = True
|
||||
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"type": "input_audio_buffer.append",
|
||||
"audio": base64.b64encode(audio).decode("ascii"),
|
||||
}
|
||||
)
|
||||
self._audio_sent_in_utterance = True
|
||||
|
||||
async def end_utterance(self) -> None:
|
||||
if not self._utterance_active:
|
||||
return
|
||||
|
||||
if self.protocol == "seed":
|
||||
await self._end_seed_utterance()
|
||||
return
|
||||
|
||||
if not self._ws or not self._audio_sent_in_utterance:
|
||||
return
|
||||
await self._ws.send_json({"type": "input_audio_buffer.commit"})
|
||||
self._utterance_active = False
|
||||
|
||||
async def wait_for_final_transcription(self, timeout_ms: int = DEFAULT_FINAL_TIMEOUT_MS) -> str:
|
||||
if not self._audio_sent_in_utterance:
|
||||
return ""
|
||||
|
||||
timeout_sec = max(0.05, float(timeout_ms) / 1000.0)
|
||||
try:
|
||||
return str(await asyncio.wait_for(self._final_queue.get(), timeout=timeout_sec) or "").strip()
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug("Volcengine ASR final timeout ({}ms), fallback to last interim", timeout_ms)
|
||||
return str(self._last_interim_text or "").strip()
|
||||
finally:
|
||||
if self.protocol == "seed":
|
||||
await self._close_ws()
|
||||
|
||||
def clear_utterance(self) -> None:
|
||||
self._utterance_active = False
|
||||
self._audio_sent_in_utterance = False
|
||||
self._last_interim_text = ""
|
||||
self._last_error = None
|
||||
self._seed_audio_buffer = bytearray()
|
||||
self._seed_sequence = 1
|
||||
self._seed_request_id = None
|
||||
self._drain_queue(self._final_queue)
|
||||
|
||||
async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
|
||||
while self._running:
|
||||
try:
|
||||
yield await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
async def _connect_gateway(self) -> None:
|
||||
assert self._session is not None
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
if self.resource_id:
|
||||
headers["X-Api-Resource-Id"] = self.resource_id
|
||||
|
||||
self._ws = await self._session.ws_connect(self.api_url, headers=headers, heartbeat=20)
|
||||
self._reader_task = asyncio.create_task(self._reader_loop())
|
||||
await self._configure_gateway_session()
|
||||
|
||||
async def _configure_gateway_session(self) -> None:
|
||||
if not self._ws:
|
||||
raise RuntimeError("Volcengine ASR websocket is not initialized")
|
||||
|
||||
session_payload: Dict[str, Any] = {
|
||||
"input_audio_format": "pcm",
|
||||
"input_audio_codec": "raw",
|
||||
"input_audio_sample_rate": self.sample_rate,
|
||||
"input_audio_bits": 16,
|
||||
"input_audio_channel": 1,
|
||||
"result_type": 0,
|
||||
"input_audio_transcription": {
|
||||
"model": self.model,
|
||||
},
|
||||
}
|
||||
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"type": "transcription_session.update",
|
||||
"session": session_payload,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self._session_ready.wait(), timeout=8.0)
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise RuntimeError("Volcengine ASR session update timeout") from exc
|
||||
|
||||
async def _open_seed_stream(self) -> None:
|
||||
if not self._session:
|
||||
raise RuntimeError("Volcengine ASR session is not initialized")
|
||||
|
||||
await self._close_ws()
|
||||
self._seed_request_id = uuid.uuid4().hex
|
||||
headers = self._build_seed_headers(self._seed_request_id)
|
||||
self._ws = await self._session.ws_connect(
|
||||
self.api_url,
|
||||
headers=headers,
|
||||
heartbeat=20,
|
||||
max_msg_size=1_000_000_000,
|
||||
)
|
||||
self._reader_task = asyncio.create_task(self._reader_loop())
|
||||
await self._ws.send_bytes(self._build_seed_start_request())
|
||||
|
||||
async def _send_seed_audio(self, audio: bytes) -> None:
|
||||
if not self._utterance_active:
|
||||
await self.begin_utterance()
|
||||
if not self._ws:
|
||||
raise RuntimeError("Volcengine BigASR websocket is not connected")
|
||||
|
||||
self._seed_audio_buffer.extend(audio)
|
||||
while len(self._seed_audio_buffer) >= self._seed_frame_bytes:
|
||||
chunk = bytes(self._seed_audio_buffer[: self._seed_frame_bytes])
|
||||
del self._seed_audio_buffer[: self._seed_frame_bytes]
|
||||
self._seed_sequence += 1
|
||||
await self._ws.send_bytes(self._build_seed_audio_request(chunk, sequence=self._seed_sequence))
|
||||
self._audio_sent_in_utterance = True
|
||||
|
||||
async def _end_seed_utterance(self) -> None:
|
||||
if not self._ws:
|
||||
return
|
||||
if not self._audio_sent_in_utterance and not self._seed_audio_buffer:
|
||||
self._utterance_active = False
|
||||
return
|
||||
|
||||
final_chunk = bytes(self._seed_audio_buffer)
|
||||
self._seed_audio_buffer = bytearray()
|
||||
self._seed_sequence += 1
|
||||
await self._ws.send_bytes(
|
||||
self._build_seed_audio_request(final_chunk, sequence=-self._seed_sequence, is_last=True)
|
||||
)
|
||||
self._audio_sent_in_utterance = True
|
||||
self._utterance_active = False
|
||||
|
||||
async def _close_ws(self) -> None:
|
||||
reader_task = self._reader_task
|
||||
ws = self._ws
|
||||
self._reader_task = None
|
||||
self._ws = None
|
||||
|
||||
if reader_task:
|
||||
reader_task.cancel()
|
||||
try:
|
||||
await reader_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if ws is not None:
|
||||
await ws.close()
|
||||
|
||||
async def _reader_loop(self) -> None:
|
||||
ws = self._ws
|
||||
if ws is None:
|
||||
return
|
||||
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
if self.protocol == "gateway":
|
||||
self._handle_gateway_event(msg.data)
|
||||
else:
|
||||
self._handle_seed_text(msg.data)
|
||||
continue
|
||||
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||
if self.protocol == "seed":
|
||||
self._handle_seed_binary(msg.data)
|
||||
continue
|
||||
if msg.type == aiohttp.WSMsgType.ERROR:
|
||||
self._last_error = str(ws.exception())
|
||||
logger.error("Volcengine ASR websocket error: {}", self._last_error)
|
||||
break
|
||||
if msg.type in {aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE}:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
self._last_error = str(exc)
|
||||
logger.error("Volcengine ASR reader loop failed: {}", exc)
|
||||
finally:
|
||||
if self._ws is ws:
|
||||
self._ws = None
|
||||
|
||||
def _handle_gateway_event(self, message: str) -> None:
|
||||
payload = self._coerce_event(message)
|
||||
event_type = str(payload.get("type") or "").strip()
|
||||
if not event_type:
|
||||
return
|
||||
|
||||
if event_type in {"transcription_session.created", "transcription_session.updated"}:
|
||||
self._session_ready.set()
|
||||
return
|
||||
|
||||
if event_type == "error":
|
||||
self._last_error = self._extract_text(payload, ("message", "error"))
|
||||
logger.error("Volcengine ASR server error: {}", self._last_error or "unknown")
|
||||
return
|
||||
|
||||
if event_type.endswith(".failed"):
|
||||
self._last_error = self._extract_text(payload, ("message", "error", "transcript"))
|
||||
logger.error("Volcengine ASR failed event: {}", self._last_error or event_type)
|
||||
return
|
||||
|
||||
if event_type == "conversation.item.input_audio_transcription.result":
|
||||
transcript = self._extract_text(payload, ("transcript", "result"))
|
||||
self._emit_transcript_sync(transcript, is_final=False)
|
||||
return
|
||||
|
||||
if event_type == "conversation.item.input_audio_transcription.delta":
|
||||
transcript = self._extract_text(payload, ("delta",))
|
||||
self._emit_transcript_sync(transcript, is_final=False)
|
||||
return
|
||||
|
||||
if event_type == "conversation.item.input_audio_transcription.completed":
|
||||
transcript = self._extract_text(payload, ("transcript", "result"))
|
||||
self._emit_transcript_sync(transcript, is_final=True)
|
||||
|
||||
def _handle_seed_text(self, message: str) -> None:
|
||||
payload = self._coerce_event(message)
|
||||
if payload.get("type") == "error":
|
||||
self._last_error = self._extract_text(payload, ("message", "error"))
|
||||
logger.error("Volcengine BigASR error: {}", self._last_error or "unknown")
|
||||
|
||||
def _handle_seed_binary(self, message: bytes) -> None:
|
||||
payload = self._parse_seed_response(message)
|
||||
if payload.get("code"):
|
||||
self._last_error = self._extract_text(payload, ("payload_msg",))
|
||||
logger.error("Volcengine BigASR server error: {}", self._last_error or payload["code"])
|
||||
return
|
||||
|
||||
body = payload.get("payload_msg")
|
||||
if not isinstance(body, dict):
|
||||
return
|
||||
result = body.get("result")
|
||||
if not isinstance(result, dict):
|
||||
return
|
||||
|
||||
text = str(result.get("text") or "").strip()
|
||||
if not text:
|
||||
return
|
||||
|
||||
utterances = result.get("utterances")
|
||||
if not isinstance(utterances, list) or not utterances:
|
||||
return
|
||||
first_utterance = utterances[0] if isinstance(utterances[0], dict) else {}
|
||||
is_final = self._coerce_bool(first_utterance.get("definite")) is True
|
||||
self._emit_transcript_sync(text, is_final=is_final)
|
||||
|
||||
def _emit_transcript_sync(self, text: str, *, is_final: bool) -> None:
|
||||
cleaned = str(text or "").strip()
|
||||
if not cleaned:
|
||||
return
|
||||
|
||||
if not is_final:
|
||||
self._last_interim_text = cleaned
|
||||
else:
|
||||
self._last_interim_text = ""
|
||||
|
||||
result = ASRResult(text=cleaned, is_final=is_final)
|
||||
try:
|
||||
self._transcript_queue.put_nowait(result)
|
||||
except asyncio.QueueFull:
|
||||
logger.debug("Volcengine ASR transcript queue full; dropping transcript")
|
||||
|
||||
if is_final:
|
||||
try:
|
||||
self._final_queue.put_nowait(cleaned)
|
||||
except asyncio.QueueFull:
|
||||
logger.debug("Volcengine ASR final queue full; dropping transcript")
|
||||
|
||||
if self.on_transcript:
|
||||
asyncio.create_task(self.on_transcript(cleaned, is_final))
|
||||
|
||||
def _build_seed_headers(self, request_id: str) -> Dict[str, str]:
|
||||
if not self.app_id:
|
||||
raise ValueError("Volcengine ASR app_id not provided. Configure agent.asr.app_id in YAML.")
|
||||
if not self.api_key:
|
||||
raise ValueError("Volcengine ASR api_key not provided. Configure agent.asr.api_key in YAML.")
|
||||
|
||||
return {
|
||||
"X-Api-App-Key": str(self.app_id),
|
||||
"X-Api-Access-Key": str(self.api_key),
|
||||
"X-Api-Resource-Id": str(self.resource_id or self.DEFAULT_SEED_RESOURCE_ID),
|
||||
"X-Api-Request-Id": str(request_id),
|
||||
}
|
||||
|
||||
def _build_seed_start_payload(self) -> Dict[str, Any]:
|
||||
user_payload: Dict[str, Any] = {"uid": str(self.uid or self._seed_request_id or self.app_id or uuid.uuid4().hex)}
|
||||
audio_payload: Dict[str, Any] = {
|
||||
"format": "pcm",
|
||||
"rate": self.sample_rate,
|
||||
"bits": 16,
|
||||
"channels": 1,
|
||||
"codec": "raw",
|
||||
}
|
||||
if self.language and self.language != "auto":
|
||||
audio_payload["language"] = self.language
|
||||
|
||||
request_payload: Dict[str, Any] = {
|
||||
"model_name": self.model or self.DEFAULT_MODEL,
|
||||
"enable_itn": False,
|
||||
"enable_punc": True,
|
||||
"enable_ddc": False,
|
||||
"show_utterance": True,
|
||||
"result_type": "single",
|
||||
"vad_segment_duration": 3000,
|
||||
"end_window_size": 500,
|
||||
"force_to_speech_time": 1000,
|
||||
}
|
||||
|
||||
extra = dict(self.request_params)
|
||||
user_payload.update(self._as_dict(extra.pop("user", None)))
|
||||
audio_payload.update(self._as_dict(extra.pop("audio", None)))
|
||||
request_payload.update(self._as_dict(extra.pop("request", None)))
|
||||
request_payload.update(extra)
|
||||
|
||||
return {
|
||||
"user": user_payload,
|
||||
"audio": audio_payload,
|
||||
"request": request_payload,
|
||||
}
|
||||
|
||||
def _build_seed_start_request(self) -> bytes:
|
||||
payload = gzip.compress(json.dumps(self._build_seed_start_payload()).encode("utf-8"))
|
||||
frame = bytearray(
|
||||
self._build_seed_header(
|
||||
message_type=self._SEED_FULL_CLIENT_REQUEST,
|
||||
message_type_specific_flags=self._SEED_POS_SEQUENCE,
|
||||
)
|
||||
)
|
||||
frame.extend((1).to_bytes(4, "big", signed=True))
|
||||
frame.extend(len(payload).to_bytes(4, "big"))
|
||||
frame.extend(payload)
|
||||
return bytes(frame)
|
||||
|
||||
def _build_seed_audio_request(self, chunk: bytes, *, sequence: int, is_last: bool = False) -> bytes:
|
||||
payload = gzip.compress(chunk)
|
||||
frame = bytearray(
|
||||
self._build_seed_header(
|
||||
message_type=self._SEED_AUDIO_ONLY_REQUEST,
|
||||
message_type_specific_flags=self._SEED_NEG_WITH_SEQUENCE if is_last else self._SEED_POS_SEQUENCE,
|
||||
)
|
||||
)
|
||||
frame.extend(int(sequence).to_bytes(4, "big", signed=True))
|
||||
frame.extend(len(payload).to_bytes(4, "big"))
|
||||
frame.extend(payload)
|
||||
return bytes(frame)
|
||||
|
||||
@classmethod
|
||||
def _build_seed_header(
|
||||
cls,
|
||||
*,
|
||||
message_type: int,
|
||||
message_type_specific_flags: int,
|
||||
serial_method: int = _SEED_JSON,
|
||||
compression_type: int = _SEED_GZIP,
|
||||
reserved_data: int = 0x00,
|
||||
) -> bytes:
|
||||
header = bytearray()
|
||||
header.append((cls._SEED_PROTOCOL_VERSION << 4) | 0b0001)
|
||||
header.append((message_type << 4) | message_type_specific_flags)
|
||||
header.append((serial_method << 4) | compression_type)
|
||||
header.append(reserved_data)
|
||||
return bytes(header)
|
||||
|
||||
@classmethod
|
||||
def _parse_seed_response(cls, response: bytes) -> Dict[str, Any]:
|
||||
header_size = response[0] & 0x0F
|
||||
message_type = response[1] >> 4
|
||||
message_type_specific_flags = response[1] & 0x0F
|
||||
serialization_method = response[2] >> 4
|
||||
compression_type = response[2] & 0x0F
|
||||
payload = response[header_size * 4 :]
|
||||
|
||||
result: Dict[str, Any] = {"is_last_package": False}
|
||||
payload_message: Any = None
|
||||
|
||||
if message_type_specific_flags & 0x01:
|
||||
result["payload_sequence"] = int.from_bytes(payload[:4], "big", signed=True)
|
||||
payload = payload[4:]
|
||||
|
||||
if message_type_specific_flags & 0x02:
|
||||
result["is_last_package"] = True
|
||||
|
||||
if message_type == cls._SEED_FULL_SERVER_RESPONSE:
|
||||
result["payload_size"] = int.from_bytes(payload[:4], "big", signed=True)
|
||||
payload_message = payload[4:]
|
||||
elif message_type == cls._SEED_SERVER_ACK:
|
||||
result["seq"] = int.from_bytes(payload[:4], "big", signed=True)
|
||||
if len(payload) >= 8:
|
||||
result["payload_size"] = int.from_bytes(payload[4:8], "big", signed=False)
|
||||
payload_message = payload[8:]
|
||||
elif message_type == cls._SEED_SERVER_ERROR_RESPONSE:
|
||||
result["code"] = int.from_bytes(payload[:4], "big", signed=False)
|
||||
result["payload_size"] = int.from_bytes(payload[4:8], "big", signed=False)
|
||||
payload_message = payload[8:]
|
||||
|
||||
if payload_message is None:
|
||||
return result
|
||||
if compression_type == cls._SEED_GZIP:
|
||||
payload_message = gzip.decompress(payload_message)
|
||||
if serialization_method == cls._SEED_JSON:
|
||||
payload_message = json.loads(payload_message.decode("utf-8"))
|
||||
elif serialization_method != cls._SEED_NO_SERIALIZATION:
|
||||
payload_message = payload_message.decode("utf-8")
|
||||
|
||||
result["payload_msg"] = payload_message
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _coerce_event(message: Any) -> Dict[str, Any]:
|
||||
if isinstance(message, dict):
|
||||
return message
|
||||
if isinstance(message, str):
|
||||
try:
|
||||
loaded = json.loads(message)
|
||||
if isinstance(loaded, dict):
|
||||
return loaded
|
||||
except json.JSONDecodeError:
|
||||
return {"type": "raw", "message": message}
|
||||
return {"type": "raw", "message": str(message)}
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(payload: Dict[str, Any], keys: tuple[str, ...]) -> str:
|
||||
for key in keys:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
if isinstance(value, dict):
|
||||
for nested_key in ("message", "text", "transcript", "result", "delta"):
|
||||
nested = value.get(nested_key)
|
||||
if isinstance(nested, str) and nested.strip():
|
||||
return nested.strip()
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _coerce_bool(value: Any) -> Optional[bool]:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, (int, float)):
|
||||
return bool(value)
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip().lower()
|
||||
if normalized in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
if normalized in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _as_dict(value: Any) -> Dict[str, Any]:
|
||||
if isinstance(value, dict):
|
||||
return dict(value)
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _drain_queue(queue: "asyncio.Queue[Any]") -> None:
|
||||
while True:
|
||||
try:
|
||||
queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
@@ -17,14 +17,17 @@ from runtime.ports import (
|
||||
)
|
||||
from providers.asr.buffered import BufferedASRService
|
||||
from providers.asr.dashscope import DashScopeRealtimeASRService
|
||||
from providers.asr.volcengine import VolcengineRealtimeASRService
|
||||
from providers.tts.dashscope import DashScopeTTSService
|
||||
from providers.llm.openai import MockLLMService, OpenAILLMService
|
||||
from providers.asr.openai_compatible import OpenAICompatibleASRService
|
||||
from providers.tts.openai_compatible import OpenAICompatibleTTSService
|
||||
from providers.tts.mock import MockTTSService
|
||||
from providers.tts.volcengine import VolcengineTTSService
|
||||
|
||||
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
|
||||
_DASHSCOPE_PROVIDERS = {"dashscope"}
|
||||
_VOLCENGINE_PROVIDERS = {"volcengine"}
|
||||
_SUPPORTED_LLM_PROVIDERS = {"openai", *_OPENAI_COMPATIBLE_PROVIDERS}
|
||||
|
||||
|
||||
@@ -37,6 +40,10 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
|
||||
_DEFAULT_DASHSCOPE_ASR_MODEL = "qwen3-asr-flash-realtime"
|
||||
_DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
|
||||
_DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
|
||||
_DEFAULT_VOLCENGINE_TTS_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional"
|
||||
_DEFAULT_VOLCENGINE_TTS_RESOURCE_ID = "seed-tts-2.0"
|
||||
_DEFAULT_VOLCENGINE_ASR_REALTIME_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel"
|
||||
_DEFAULT_VOLCENGINE_ASR_MODEL = "bigmodel"
|
||||
|
||||
@staticmethod
|
||||
def _normalize_provider(provider: Any) -> str:
|
||||
@@ -81,6 +88,19 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
|
||||
speed=spec.speed,
|
||||
)
|
||||
|
||||
if provider in _VOLCENGINE_PROVIDERS and spec.api_key:
|
||||
return VolcengineTTSService(
|
||||
api_key=spec.api_key,
|
||||
api_url=spec.api_url or self._DEFAULT_VOLCENGINE_TTS_URL,
|
||||
voice=spec.voice,
|
||||
model=spec.model,
|
||||
app_id=spec.app_id,
|
||||
resource_id=spec.resource_id or self._DEFAULT_VOLCENGINE_TTS_RESOURCE_ID,
|
||||
uid=spec.uid,
|
||||
sample_rate=spec.sample_rate,
|
||||
speed=spec.speed,
|
||||
)
|
||||
|
||||
if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key:
|
||||
return OpenAICompatibleTTSService(
|
||||
api_key=spec.api_key,
|
||||
@@ -110,6 +130,20 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
|
||||
on_transcript=spec.on_transcript,
|
||||
)
|
||||
|
||||
if provider in _VOLCENGINE_PROVIDERS and spec.api_key:
|
||||
return VolcengineRealtimeASRService(
|
||||
api_key=spec.api_key,
|
||||
api_url=spec.api_url or self._DEFAULT_VOLCENGINE_ASR_REALTIME_URL,
|
||||
model=spec.model or self._DEFAULT_VOLCENGINE_ASR_MODEL,
|
||||
sample_rate=spec.sample_rate,
|
||||
language=spec.language,
|
||||
app_id=spec.app_id,
|
||||
resource_id=spec.resource_id,
|
||||
uid=spec.uid,
|
||||
request_params=spec.request_params,
|
||||
on_transcript=spec.on_transcript,
|
||||
)
|
||||
|
||||
if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key:
|
||||
return OpenAICompatibleASRService(
|
||||
api_key=spec.api_key,
|
||||
|
||||
@@ -1 +1,5 @@
|
||||
"""TTS providers."""
|
||||
|
||||
from providers.tts.volcengine import VolcengineTTSService
|
||||
|
||||
__all__ = ["VolcengineTTSService"]
|
||||
|
||||
219
engine/providers/tts/volcengine.py
Normal file
219
engine/providers/tts/volcengine.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""Volcengine TTS service.
|
||||
|
||||
Uses Volcengine's unidirectional HTTP streaming TTS API and adapts streamed
|
||||
base64 audio chunks into engine-native ``TTSChunk`` events.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import codecs
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, AsyncIterator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from providers.common.base import BaseTTSService, ServiceState, TTSChunk
|
||||
|
||||
|
||||
class VolcengineTTSService(BaseTTSService):
|
||||
"""Streaming TTS adapter for Volcengine's HTTP v3 API."""
|
||||
|
||||
DEFAULT_API_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional"
|
||||
DEFAULT_RESOURCE_ID = "seed-tts-2.0"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
voice: str = "zh_female_shuangkuaisisi_moon_bigtts",
|
||||
model: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
uid: Optional[str] = None,
|
||||
sample_rate: int = 16000,
|
||||
speed: float = 1.0,
|
||||
) -> None:
|
||||
super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
|
||||
self.api_key = api_key or os.getenv("VOLCENGINE_TTS_API_KEY") or os.getenv("TTS_API_KEY")
|
||||
self.api_url = api_url or os.getenv("VOLCENGINE_TTS_API_URL") or self.DEFAULT_API_URL
|
||||
self.model = str(model or os.getenv("VOLCENGINE_TTS_MODEL") or "").strip() or None
|
||||
self.app_id = app_id or os.getenv("VOLCENGINE_TTS_APP_ID") or os.getenv("TTS_APP_ID")
|
||||
self.resource_id = resource_id or os.getenv("VOLCENGINE_TTS_RESOURCE_ID") or self.DEFAULT_RESOURCE_ID
|
||||
self.uid = uid or os.getenv("VOLCENGINE_TTS_UID")
|
||||
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._cancel_event = asyncio.Event()
|
||||
self._synthesis_lock = asyncio.Lock()
|
||||
self._pending_audio: list[bytes] = []
|
||||
|
||||
async def connect(self) -> None:
|
||||
if not self.api_key:
|
||||
raise ValueError("Volcengine TTS API key not provided. Configure agent.tts.api_key in YAML.")
|
||||
if not self.app_id:
|
||||
raise ValueError("Volcengine TTS app_id not provided. Configure agent.tts.app_id in YAML.")
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=None, sock_read=None, sock_connect=15)
|
||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||
self.state = ServiceState.CONNECTED
|
||||
logger.info(
|
||||
"Volcengine TTS service ready: speaker={}, sample_rate={}, resource_id={}",
|
||||
self.voice,
|
||||
self.sample_rate,
|
||||
self.resource_id,
|
||||
)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._cancel_event.set()
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("Volcengine TTS service disconnected")
|
||||
|
||||
async def synthesize(self, text: str) -> bytes:
|
||||
audio = b""
|
||||
async for chunk in self.synthesize_stream(text):
|
||||
audio += chunk.audio
|
||||
return audio
|
||||
|
||||
async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
|
||||
if not self._session:
|
||||
raise RuntimeError("Volcengine TTS service not connected")
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
async with self._synthesis_lock:
|
||||
self._cancel_event.clear()
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-Api-App-Key": str(self.app_id),
|
||||
"X-Api-Access-Key": str(self.api_key),
|
||||
"X-Api-Resource-Id": str(self.resource_id),
|
||||
"X-Api-Request-Id": str(uuid.uuid4()),
|
||||
}
|
||||
payload = {
|
||||
"user": {
|
||||
"uid": str(self.uid or self.app_id),
|
||||
},
|
||||
"req_params": {
|
||||
"text": text,
|
||||
"speaker": self.voice,
|
||||
"audio_params": {
|
||||
"format": "pcm",
|
||||
"sample_rate": self.sample_rate,
|
||||
"speech_rate": self._speech_rate_percent(self.speed),
|
||||
},
|
||||
},
|
||||
}
|
||||
if self.model:
|
||||
payload["req_params"]["model"] = self.model
|
||||
|
||||
chunk_size = max(1, self.sample_rate * 2 // 10)
|
||||
audio_buffer = b""
|
||||
pending_chunk: Optional[bytes] = None
|
||||
|
||||
try:
|
||||
async with self._session.post(self.api_url, headers=headers, json=payload) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(f"Volcengine TTS error {response.status}: {error_text}")
|
||||
|
||||
async for audio_bytes in self._iter_audio_bytes(response):
|
||||
if self._cancel_event.is_set():
|
||||
logger.info("Volcengine TTS synthesis cancelled")
|
||||
return
|
||||
|
||||
audio_buffer += audio_bytes
|
||||
while len(audio_buffer) >= chunk_size:
|
||||
emitted = audio_buffer[:chunk_size]
|
||||
audio_buffer = audio_buffer[chunk_size:]
|
||||
if pending_chunk is not None:
|
||||
yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False)
|
||||
pending_chunk = emitted
|
||||
|
||||
if self._cancel_event.is_set():
|
||||
return
|
||||
|
||||
if pending_chunk is not None:
|
||||
if audio_buffer:
|
||||
yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False)
|
||||
pending_chunk = None
|
||||
else:
|
||||
yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=True)
|
||||
pending_chunk = None
|
||||
|
||||
if audio_buffer:
|
||||
yield TTSChunk(audio=audio_buffer, sample_rate=self.sample_rate, is_final=True)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Volcengine TTS synthesis cancelled via asyncio")
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("Volcengine TTS synthesis error: {}", exc)
|
||||
raise
|
||||
|
||||
async def cancel(self) -> None:
|
||||
self._cancel_event.set()
|
||||
|
||||
async def _iter_audio_bytes(self, response: aiohttp.ClientResponse) -> AsyncIterator[bytes]:
|
||||
decoder = json.JSONDecoder()
|
||||
utf8_decoder = codecs.getincrementaldecoder("utf-8")()
|
||||
text_buffer = ""
|
||||
self._pending_audio.clear()
|
||||
|
||||
async for raw_chunk in response.content.iter_any():
|
||||
text_buffer += utf8_decoder.decode(raw_chunk)
|
||||
text_buffer = self._yield_audio_payloads(decoder, text_buffer)
|
||||
while self._pending_audio:
|
||||
yield self._pending_audio.pop(0)
|
||||
|
||||
text_buffer += utf8_decoder.decode(b"", final=True)
|
||||
text_buffer = self._yield_audio_payloads(decoder, text_buffer)
|
||||
while self._pending_audio:
|
||||
yield self._pending_audio.pop(0)
|
||||
|
||||
def _yield_audio_payloads(self, decoder: json.JSONDecoder, text_buffer: str) -> str:
|
||||
while True:
|
||||
stripped = text_buffer.lstrip()
|
||||
if not stripped:
|
||||
return ""
|
||||
if len(stripped) != len(text_buffer):
|
||||
text_buffer = stripped
|
||||
|
||||
try:
|
||||
payload, idx = decoder.raw_decode(text_buffer)
|
||||
except json.JSONDecodeError:
|
||||
return text_buffer
|
||||
|
||||
text_buffer = text_buffer[idx:]
|
||||
audio = self._extract_audio_bytes(payload)
|
||||
if audio:
|
||||
self._pending_audio.append(audio)
|
||||
|
||||
def _extract_audio_bytes(self, payload: Any) -> bytes:
|
||||
if not isinstance(payload, dict):
|
||||
return b""
|
||||
|
||||
code = payload.get("code")
|
||||
if code not in (None, 0, 20000000):
|
||||
message = str(payload.get("message") or "unknown error")
|
||||
raise RuntimeError(f"Volcengine TTS stream error {code}: {message}")
|
||||
|
||||
encoded = payload.get("data")
|
||||
if isinstance(encoded, str) and encoded.strip():
|
||||
try:
|
||||
return base64.b64decode(encoded)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to decode Volcengine TTS audio chunk: {}", exc)
|
||||
return b""
|
||||
|
||||
@staticmethod
|
||||
def _speech_rate_percent(speed: float) -> int:
|
||||
clamped = max(0.5, min(2.0, float(speed or 1.0)))
|
||||
return int(round((clamped - 1.0) * 100))
|
||||
Reference in New Issue
Block a user