"""Volcengine realtime ASR service. Supports both: - Volcengine Edge Gateway realtime transcription websocket, and - Volcengine BigASR Seed websocket at openspeech.bytedance.com/api/v3/sauc/bigmodel. """ from __future__ import annotations import asyncio import base64 import gzip import json import os import uuid from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Literal, Optional from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse import aiohttp from loguru import logger from providers.common.base import ASRResult, BaseASRService, ServiceState VolcengineASRProtocol = Literal["gateway", "seed"] class VolcengineRealtimeASRService(BaseASRService): """Realtime streaming ASR backed by Volcengine websocket APIs.""" DEFAULT_WS_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel" DEFAULT_GATEWAY_WS_URL = "wss://ai-gateway.vei.volces.com/v1/realtime" DEFAULT_MODEL = "bigmodel" DEFAULT_FINAL_TIMEOUT_MS = 1200 DEFAULT_SEED_RESOURCE_ID = "volc.bigasr.sauc.duration" _SEED_FRAME_MS = 100 _SEED_PROTOCOL_VERSION = 0b0001 _SEED_FULL_CLIENT_REQUEST = 0b0001 _SEED_AUDIO_ONLY_REQUEST = 0b0010 _SEED_FULL_SERVER_RESPONSE = 0b1001 _SEED_SERVER_ACK = 0b1011 _SEED_SERVER_ERROR_RESPONSE = 0b1111 _SEED_NO_SEQUENCE = 0b0000 _SEED_POS_SEQUENCE = 0b0001 _SEED_NEG_WITH_SEQUENCE = 0b0011 _SEED_NO_SERIALIZATION = 0b0000 _SEED_JSON = 0b0001 _SEED_NO_COMPRESSION = 0b0000 _SEED_GZIP = 0b0001 def __init__( self, api_key: Optional[str] = None, api_url: Optional[str] = None, model: Optional[str] = None, sample_rate: int = 16000, language: str = "auto", app_id: Optional[str] = None, resource_id: Optional[str] = None, uid: Optional[str] = None, request_params: Optional[Dict[str, Any]] = None, on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None, ) -> None: super().__init__(sample_rate=sample_rate, language=language) self.mode = "streaming" self.api_key = api_key or os.getenv("VOLCENGINE_ASR_API_KEY") or os.getenv("ASR_API_KEY") self.model = str(model or os.getenv("VOLCENGINE_ASR_MODEL") or self.DEFAULT_MODEL).strip() raw_api_url = api_url or os.getenv("VOLCENGINE_ASR_API_URL") or self.DEFAULT_WS_URL self.protocol = self._detect_protocol(raw_api_url) self.api_url = self._resolve_api_url(raw_api_url, self.model, self.protocol) self.app_id = app_id or os.getenv("VOLCENGINE_ASR_APP_ID") or os.getenv("ASR_APP_ID") self.resource_id = ( resource_id or os.getenv("VOLCENGINE_ASR_RESOURCE_ID") or (self.DEFAULT_SEED_RESOURCE_ID if self.protocol == "seed" else None) ) self.uid = uid or os.getenv("VOLCENGINE_ASR_UID") self.request_params = self._load_request_params(request_params) self.on_transcript = on_transcript self._session: Optional[aiohttp.ClientSession] = None self._ws: Optional[aiohttp.ClientWebSocketResponse] = None self._reader_task: Optional[asyncio.Task[None]] = None self._running = False self._session_ready = asyncio.Event() self._transcript_queue: "asyncio.Queue[ASRResult]" = asyncio.Queue() self._final_queue: "asyncio.Queue[str]" = asyncio.Queue() self._utterance_active = False self._audio_sent_in_utterance = False self._last_interim_text = "" self._last_error: Optional[str] = None self._seed_audio_buffer = bytearray() self._seed_sequence = 1 self._seed_request_id: Optional[str] = None self._seed_frame_bytes = max(2, int((self.sample_rate * self._SEED_FRAME_MS / 1000) * 2)) @classmethod def _detect_protocol(cls, api_url: str) -> VolcengineASRProtocol: parsed = urlparse(str(api_url or "").strip()) host = parsed.netloc.lower() path = parsed.path.lower() if "openspeech.bytedance.com" in host and "/api/v3/sauc/bigmodel" in path: return "seed" return "gateway" @classmethod def _resolve_api_url(cls, api_url: str, model: str, protocol: VolcengineASRProtocol) -> str: raw = str(api_url or "").strip() if not raw: raw = cls.DEFAULT_WS_URL if protocol == "seed" else cls.DEFAULT_GATEWAY_WS_URL if protocol != "gateway": return raw parsed = urlparse(raw) query = dict(parse_qsl(parsed.query, keep_blank_values=True)) query.setdefault("model", model or cls.DEFAULT_MODEL) return urlunparse(parsed._replace(query=urlencode(query))) @staticmethod def _load_request_params(request_params: Optional[Dict[str, Any]]) -> Dict[str, Any]: if isinstance(request_params, dict): return dict(request_params) raw = os.getenv("VOLCENGINE_ASR_REQUEST_PARAMS_JSON", "").strip() if not raw: return {} try: parsed = json.loads(raw) except json.JSONDecodeError: logger.warning("Ignoring invalid VOLCENGINE_ASR_REQUEST_PARAMS_JSON") return {} if isinstance(parsed, dict): return parsed return {} async def connect(self) -> None: if not self.api_key: raise ValueError("Volcengine ASR API key not provided. Configure agent.asr.api_key in YAML.") timeout = aiohttp.ClientTimeout(total=None, sock_read=None, sock_connect=15) self._session = aiohttp.ClientSession(timeout=timeout) self._running = True if self.protocol == "gateway": await self._connect_gateway() logger.info( "Volcengine gateway ASR connected: model={}, sample_rate={}, url={}", self.model, self.sample_rate, self.api_url, ) else: if not self.app_id: raise ValueError("Volcengine ASR app_id not provided. Configure agent.asr.app_id in YAML.") logger.info( "Volcengine BigASR Seed ready: model={}, sample_rate={}, resource_id={}", self.model, self.sample_rate, self.resource_id, ) self.state = ServiceState.CONNECTED async def disconnect(self) -> None: self._running = False self._utterance_active = False self._audio_sent_in_utterance = False self._session_ready.clear() self._seed_audio_buffer = bytearray() self._drain_queue(self._final_queue) self._drain_queue(self._transcript_queue) await self._close_ws() if self._session is not None: await self._session.close() self._session = None self.state = ServiceState.DISCONNECTED logger.info("Volcengine ASR disconnected") async def begin_utterance(self) -> None: self.clear_utterance() if self.protocol == "seed": await self._open_seed_stream() self._utterance_active = True async def send_audio(self, audio: bytes) -> None: if not audio: return if self.protocol == "seed": await self._send_seed_audio(audio) return if not self._ws: raise RuntimeError("Volcengine ASR websocket is not connected") if not self._utterance_active: self._utterance_active = True await self._ws.send_json( { "type": "input_audio_buffer.append", "audio": base64.b64encode(audio).decode("ascii"), } ) self._audio_sent_in_utterance = True async def end_utterance(self) -> None: if not self._utterance_active: return if self.protocol == "seed": await self._end_seed_utterance() return if not self._ws or not self._audio_sent_in_utterance: return await self._ws.send_json({"type": "input_audio_buffer.commit"}) self._utterance_active = False async def wait_for_final_transcription(self, timeout_ms: int = DEFAULT_FINAL_TIMEOUT_MS) -> str: if not self._audio_sent_in_utterance: return "" timeout_sec = max(0.05, float(timeout_ms) / 1000.0) try: return str(await asyncio.wait_for(self._final_queue.get(), timeout=timeout_sec) or "").strip() except asyncio.TimeoutError: logger.debug("Volcengine ASR final timeout ({}ms), fallback to last interim", timeout_ms) return str(self._last_interim_text or "").strip() finally: if self.protocol == "seed": await self._close_ws() def clear_utterance(self) -> None: self._utterance_active = False self._audio_sent_in_utterance = False self._last_interim_text = "" self._last_error = None self._seed_audio_buffer = bytearray() self._seed_sequence = 1 self._seed_request_id = None self._drain_queue(self._final_queue) async def receive_transcripts(self) -> AsyncIterator[ASRResult]: while self._running: try: yield await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1) except asyncio.TimeoutError: continue except asyncio.CancelledError: break async def _connect_gateway(self) -> None: assert self._session is not None headers = {"Authorization": f"Bearer {self.api_key}"} if self.resource_id: headers["X-Api-Resource-Id"] = self.resource_id self._ws = await self._session.ws_connect(self.api_url, headers=headers, heartbeat=20) self._reader_task = asyncio.create_task(self._reader_loop()) await self._configure_gateway_session() async def _configure_gateway_session(self) -> None: if not self._ws: raise RuntimeError("Volcengine ASR websocket is not initialized") session_payload: Dict[str, Any] = { "input_audio_format": "pcm", "input_audio_codec": "raw", "input_audio_sample_rate": self.sample_rate, "input_audio_bits": 16, "input_audio_channel": 1, "result_type": 0, "input_audio_transcription": { "model": self.model, }, } await self._ws.send_json( { "type": "transcription_session.update", "session": session_payload, } ) try: await asyncio.wait_for(self._session_ready.wait(), timeout=8.0) except asyncio.TimeoutError as exc: raise RuntimeError("Volcengine ASR session update timeout") from exc async def _open_seed_stream(self) -> None: if not self._session: raise RuntimeError("Volcengine ASR session is not initialized") await self._close_ws() self._seed_request_id = uuid.uuid4().hex headers = self._build_seed_headers(self._seed_request_id) self._ws = await self._session.ws_connect( self.api_url, headers=headers, heartbeat=20, max_msg_size=1_000_000_000, ) self._reader_task = asyncio.create_task(self._reader_loop()) await self._ws.send_bytes(self._build_seed_start_request()) async def _send_seed_audio(self, audio: bytes) -> None: if not self._utterance_active: await self.begin_utterance() if not self._ws: raise RuntimeError("Volcengine BigASR websocket is not connected") self._seed_audio_buffer.extend(audio) while len(self._seed_audio_buffer) >= self._seed_frame_bytes: chunk = bytes(self._seed_audio_buffer[: self._seed_frame_bytes]) del self._seed_audio_buffer[: self._seed_frame_bytes] self._seed_sequence += 1 await self._ws.send_bytes(self._build_seed_audio_request(chunk, sequence=self._seed_sequence)) self._audio_sent_in_utterance = True async def _end_seed_utterance(self) -> None: if not self._ws: return if not self._audio_sent_in_utterance and not self._seed_audio_buffer: self._utterance_active = False return final_chunk = bytes(self._seed_audio_buffer) self._seed_audio_buffer = bytearray() self._seed_sequence += 1 await self._ws.send_bytes( self._build_seed_audio_request(final_chunk, sequence=-self._seed_sequence, is_last=True) ) self._audio_sent_in_utterance = True self._utterance_active = False async def _close_ws(self) -> None: reader_task = self._reader_task ws = self._ws self._reader_task = None self._ws = None if reader_task: reader_task.cancel() try: await reader_task except asyncio.CancelledError: pass if ws is not None: await ws.close() async def _reader_loop(self) -> None: ws = self._ws if ws is None: return try: async for msg in ws: if msg.type == aiohttp.WSMsgType.TEXT: if self.protocol == "gateway": self._handle_gateway_event(msg.data) else: self._handle_seed_text(msg.data) continue if msg.type == aiohttp.WSMsgType.BINARY: if self.protocol == "seed": self._handle_seed_binary(msg.data) continue if msg.type == aiohttp.WSMsgType.ERROR: self._last_error = str(ws.exception()) logger.error("Volcengine ASR websocket error: {}", self._last_error) break if msg.type in {aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE}: break except asyncio.CancelledError: raise except Exception as exc: self._last_error = str(exc) logger.error("Volcengine ASR reader loop failed: {}", exc) finally: if self._ws is ws: self._ws = None def _handle_gateway_event(self, message: str) -> None: payload = self._coerce_event(message) event_type = str(payload.get("type") or "").strip() if not event_type: return if event_type in {"transcription_session.created", "transcription_session.updated"}: self._session_ready.set() return if event_type == "error": self._last_error = self._extract_text(payload, ("message", "error")) logger.error("Volcengine ASR server error: {}", self._last_error or "unknown") return if event_type.endswith(".failed"): self._last_error = self._extract_text(payload, ("message", "error", "transcript")) logger.error("Volcengine ASR failed event: {}", self._last_error or event_type) return if event_type == "conversation.item.input_audio_transcription.result": transcript = self._extract_text(payload, ("transcript", "result")) self._emit_transcript_sync(transcript, is_final=False) return if event_type == "conversation.item.input_audio_transcription.delta": transcript = self._extract_text(payload, ("delta",)) self._emit_transcript_sync(transcript, is_final=False) return if event_type == "conversation.item.input_audio_transcription.completed": transcript = self._extract_text(payload, ("transcript", "result")) self._emit_transcript_sync(transcript, is_final=True) def _handle_seed_text(self, message: str) -> None: payload = self._coerce_event(message) if payload.get("type") == "error": self._last_error = self._extract_text(payload, ("message", "error")) logger.error("Volcengine BigASR error: {}", self._last_error or "unknown") def _handle_seed_binary(self, message: bytes) -> None: payload = self._parse_seed_response(message) if payload.get("code"): self._last_error = self._extract_text(payload, ("payload_msg",)) logger.error("Volcengine BigASR server error: {}", self._last_error or payload["code"]) return body = payload.get("payload_msg") if not isinstance(body, dict): return result = body.get("result") if not isinstance(result, dict): return text = str(result.get("text") or "").strip() if not text: return utterances = result.get("utterances") if not isinstance(utterances, list) or not utterances: return first_utterance = utterances[0] if isinstance(utterances[0], dict) else {} is_final = self._coerce_bool(first_utterance.get("definite")) is True self._emit_transcript_sync(text, is_final=is_final) def _emit_transcript_sync(self, text: str, *, is_final: bool) -> None: cleaned = str(text or "").strip() if not cleaned: return if not is_final: self._last_interim_text = cleaned else: self._last_interim_text = "" result = ASRResult(text=cleaned, is_final=is_final) try: self._transcript_queue.put_nowait(result) except asyncio.QueueFull: logger.debug("Volcengine ASR transcript queue full; dropping transcript") if is_final: try: self._final_queue.put_nowait(cleaned) except asyncio.QueueFull: logger.debug("Volcengine ASR final queue full; dropping transcript") if self.on_transcript: asyncio.create_task(self.on_transcript(cleaned, is_final)) def _build_seed_headers(self, request_id: str) -> Dict[str, str]: if not self.app_id: raise ValueError("Volcengine ASR app_id not provided. Configure agent.asr.app_id in YAML.") if not self.api_key: raise ValueError("Volcengine ASR api_key not provided. Configure agent.asr.api_key in YAML.") return { "X-Api-App-Key": str(self.app_id), "X-Api-Access-Key": str(self.api_key), "X-Api-Resource-Id": str(self.resource_id or self.DEFAULT_SEED_RESOURCE_ID), "X-Api-Request-Id": str(request_id), } def _build_seed_start_payload(self) -> Dict[str, Any]: user_payload: Dict[str, Any] = {"uid": str(self.uid or self._seed_request_id or self.app_id or uuid.uuid4().hex)} audio_payload: Dict[str, Any] = { "format": "pcm", "rate": self.sample_rate, "bits": 16, "channels": 1, "codec": "raw", } if self.language and self.language != "auto": audio_payload["language"] = self.language request_payload: Dict[str, Any] = { "model_name": self.model or self.DEFAULT_MODEL, "enable_itn": False, "enable_punc": True, "enable_ddc": False, "show_utterance": True, "result_type": "single", "vad_segment_duration": 3000, "end_window_size": 500, "force_to_speech_time": 1000, } extra = dict(self.request_params) user_payload.update(self._as_dict(extra.pop("user", None))) audio_payload.update(self._as_dict(extra.pop("audio", None))) request_payload.update(self._as_dict(extra.pop("request", None))) request_payload.update(extra) return { "user": user_payload, "audio": audio_payload, "request": request_payload, } def _build_seed_start_request(self) -> bytes: payload = gzip.compress(json.dumps(self._build_seed_start_payload()).encode("utf-8")) frame = bytearray( self._build_seed_header( message_type=self._SEED_FULL_CLIENT_REQUEST, message_type_specific_flags=self._SEED_POS_SEQUENCE, ) ) frame.extend((1).to_bytes(4, "big", signed=True)) frame.extend(len(payload).to_bytes(4, "big")) frame.extend(payload) return bytes(frame) def _build_seed_audio_request(self, chunk: bytes, *, sequence: int, is_last: bool = False) -> bytes: payload = gzip.compress(chunk) frame = bytearray( self._build_seed_header( message_type=self._SEED_AUDIO_ONLY_REQUEST, message_type_specific_flags=self._SEED_NEG_WITH_SEQUENCE if is_last else self._SEED_POS_SEQUENCE, ) ) frame.extend(int(sequence).to_bytes(4, "big", signed=True)) frame.extend(len(payload).to_bytes(4, "big")) frame.extend(payload) return bytes(frame) @classmethod def _build_seed_header( cls, *, message_type: int, message_type_specific_flags: int, serial_method: int = _SEED_JSON, compression_type: int = _SEED_GZIP, reserved_data: int = 0x00, ) -> bytes: header = bytearray() header.append((cls._SEED_PROTOCOL_VERSION << 4) | 0b0001) header.append((message_type << 4) | message_type_specific_flags) header.append((serial_method << 4) | compression_type) header.append(reserved_data) return bytes(header) @classmethod def _parse_seed_response(cls, response: bytes) -> Dict[str, Any]: header_size = response[0] & 0x0F message_type = response[1] >> 4 message_type_specific_flags = response[1] & 0x0F serialization_method = response[2] >> 4 compression_type = response[2] & 0x0F payload = response[header_size * 4 :] result: Dict[str, Any] = {"is_last_package": False} payload_message: Any = None if message_type_specific_flags & 0x01: result["payload_sequence"] = int.from_bytes(payload[:4], "big", signed=True) payload = payload[4:] if message_type_specific_flags & 0x02: result["is_last_package"] = True if message_type == cls._SEED_FULL_SERVER_RESPONSE: result["payload_size"] = int.from_bytes(payload[:4], "big", signed=True) payload_message = payload[4:] elif message_type == cls._SEED_SERVER_ACK: result["seq"] = int.from_bytes(payload[:4], "big", signed=True) if len(payload) >= 8: result["payload_size"] = int.from_bytes(payload[4:8], "big", signed=False) payload_message = payload[8:] elif message_type == cls._SEED_SERVER_ERROR_RESPONSE: result["code"] = int.from_bytes(payload[:4], "big", signed=False) result["payload_size"] = int.from_bytes(payload[4:8], "big", signed=False) payload_message = payload[8:] if payload_message is None: return result if compression_type == cls._SEED_GZIP: payload_message = gzip.decompress(payload_message) if serialization_method == cls._SEED_JSON: payload_message = json.loads(payload_message.decode("utf-8")) elif serialization_method != cls._SEED_NO_SERIALIZATION: payload_message = payload_message.decode("utf-8") result["payload_msg"] = payload_message return result @staticmethod def _coerce_event(message: Any) -> Dict[str, Any]: if isinstance(message, dict): return message if isinstance(message, str): try: loaded = json.loads(message) if isinstance(loaded, dict): return loaded except json.JSONDecodeError: return {"type": "raw", "message": message} return {"type": "raw", "message": str(message)} @staticmethod def _extract_text(payload: Dict[str, Any], keys: tuple[str, ...]) -> str: for key in keys: value = payload.get(key) if isinstance(value, str) and value.strip(): return value.strip() if isinstance(value, dict): for nested_key in ("message", "text", "transcript", "result", "delta"): nested = value.get(nested_key) if isinstance(nested, str) and nested.strip(): return nested.strip() return "" @staticmethod def _coerce_bool(value: Any) -> Optional[bool]: if isinstance(value, bool): return value if isinstance(value, (int, float)): return bool(value) if isinstance(value, str): normalized = value.strip().lower() if normalized in {"1", "true", "yes", "on"}: return True if normalized in {"0", "false", "no", "off"}: return False return None @staticmethod def _as_dict(value: Any) -> Dict[str, Any]: if isinstance(value, dict): return dict(value) return {} @staticmethod def _drain_queue(queue: "asyncio.Queue[Any]") -> None: while True: try: queue.get_nowait() except asyncio.QueueEmpty: break