- Added DashScope ASR service implementation for real-time streaming. - Updated ASR provider logic to support DashScope alongside existing providers. - Enhanced runtime metadata resolution to include DashScope as a valid ASR provider. - Modified configuration files and documentation to reflect the addition of DashScope. - Introduced tests to validate DashScope integration and ASR service behavior. - Refactored ASR service factory to accommodate new provider options and modes.
389 lines
14 KiB
Python
389 lines
14 KiB
Python
"""DashScope realtime streaming ASR service.
|
|
|
|
Uses Qwen-ASR-Realtime via DashScope Python SDK.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import json
|
|
import os
|
|
import sys
|
|
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Optional
|
|
|
|
from loguru import logger
|
|
|
|
from providers.common.base import ASRResult, BaseASRService, ServiceState
|
|
|
|
# Import the DashScope realtime SDK lazily-tolerant: any failure here must not
# break module import, only make connect() raise a descriptive error later.
try:
    import dashscope
    from dashscope.audio.qwen_omni import MultiModality, OmniRealtimeCallback, OmniRealtimeConversation

    # TranscriptionParams lives in different modules across SDK builds and may be
    # missing entirely. It is optional (callers check it for None before use), so
    # its absence must not mark the whole SDK as unavailable.
    try:
        from dashscope.audio.qwen_omni import TranscriptionParams
    except ImportError:
        try:
            from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams
        except ImportError:
            TranscriptionParams = None  # type: ignore[assignment,misc]

    DASHSCOPE_SDK_AVAILABLE = True
    DASHSCOPE_IMPORT_ERROR = ""
except Exception as exc:  # any SDK import failure degrades to "unavailable"
    # Kept so connect() can report why the SDK could not be loaded.
    DASHSCOPE_IMPORT_ERROR = f"{type(exc).__name__}: {exc}"
    dashscope = None  # type: ignore[assignment]
    MultiModality = None  # type: ignore[assignment]
    OmniRealtimeConversation = None  # type: ignore[assignment]
    TranscriptionParams = None  # type: ignore[assignment]
    DASHSCOPE_SDK_AVAILABLE = False

    class OmniRealtimeCallback:  # type: ignore[no-redef]
        """Fallback callback base when DashScope SDK is unavailable."""

        pass
|
|
|
|
|
|
class _DashScopeASRCallback(OmniRealtimeCallback):
    """Relay DashScope SDK callback hooks onto the owning service's event loop.

    The SDK may invoke these hooks from a thread other than the event loop's, so
    every hook hands its work to ``loop.call_soon_threadsafe``. The owner's
    ``_on_ws_*`` handlers therefore always execute on the loop thread.
    """

    def __init__(self, owner: "DashScopeRealtimeASRService", loop: asyncio.AbstractEventLoop):
        super().__init__()
        self._owner = owner
        self._loop = loop

    def _schedule(self, fn: Callable[[], None]) -> None:
        """Queue ``fn`` to run on the owner's loop; drop it if the loop is closed."""
        try:
            self._loop.call_soon_threadsafe(fn)
        except RuntimeError:
            # call_soon_threadsafe raises RuntimeError on a closed loop; SDK
            # callbacks arriving after shutdown are intentionally ignored.
            pass

    def on_open(self) -> None:
        """Forward the websocket-open notification."""
        owner = self._owner
        self._schedule(owner._on_ws_open)

    def on_close(self, code: int, msg: str) -> None:
        """Forward the websocket-close notification with its code and reason."""

        def deliver() -> None:
            self._owner._on_ws_close(code, msg)

        self._schedule(deliver)

    def on_event(self, message: Any) -> None:
        """Forward a server event payload unchanged."""

        def deliver() -> None:
            self._owner._on_ws_event(message)

        self._schedule(deliver)

    def on_error(self, message: Any) -> None:
        """Forward an SDK-reported error."""

        def deliver() -> None:
            self._owner._on_ws_error(message)

        self._schedule(deliver)
|
|
|
|
|
|
class DashScopeRealtimeASRService(BaseASRService):
    """Realtime streaming ASR implementation for DashScope Qwen-ASR-Realtime.

    Lifecycle:
    - ``connect()`` opens the SDK websocket and configures the session.
    - ``begin_utterance()`` / ``send_audio()`` / ``end_utterance()`` stream one utterance.
    - ``wait_for_final_transcription()`` returns that utterance's final text.
    - ``disconnect()`` tears everything down.

    Every transcript is also pushed to the ``receive_transcripts()`` stream and,
    when configured, passed to the ``on_transcript`` coroutine callback.
    """

    DEFAULT_WS_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
    DEFAULT_MODEL = "qwen3-asr-flash-realtime"
    DEFAULT_FINAL_TIMEOUT_MS = 800
    # How long _configure_session waits for a session.created/updated event
    # before continuing anyway.
    SESSION_READY_TIMEOUT_S = 6.0

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        model: Optional[str] = None,
        sample_rate: int = 16000,
        language: str = "auto",
        on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None,
    ) -> None:
        """Create an unconnected service.

        Args:
            api_key: DashScope API key. Falls back to ``DASHSCOPE_API_KEY``, then ``ASR_API_KEY``.
            api_url: Realtime websocket URL. Falls back to ``DASHSCOPE_ASR_API_URL``, then
                ``DEFAULT_WS_URL``.
            model: Model name. Falls back to ``DASHSCOPE_ASR_MODEL``, then ``DEFAULT_MODEL``.
            sample_rate: Sample rate of the PCM audio passed to ``send_audio``.
            language: Recognition language. ``"auto"`` is sent to the SDK as ``"zh"``.
            on_transcript: Optional coroutine called as ``(text, is_final)`` per transcript.
        """
        super().__init__(sample_rate=sample_rate, language=language)
        self.mode = "streaming"
        self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY") or os.getenv("ASR_API_KEY")
        self.api_url = api_url or os.getenv("DASHSCOPE_ASR_API_URL") or self.DEFAULT_WS_URL
        self.model = model or os.getenv("DASHSCOPE_ASR_MODEL") or self.DEFAULT_MODEL
        self.on_transcript = on_transcript

        # SDK connection objects, populated by connect().
        self._client: Optional[Any] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._callback: Optional[_DashScopeASRCallback] = None

        self._running = False
        self._session_ready = asyncio.Event()
        # Every transcript (interim and final), consumed by receive_transcripts().
        self._transcript_queue: "asyncio.Queue[ASRResult]" = asyncio.Queue()
        # Final transcripts only, consumed by wait_for_final_transcription().
        self._final_queue: "asyncio.Queue[str]" = asyncio.Queue()
        # Strong references to in-flight on_transcript tasks so the event loop's
        # weak references are not the only thing keeping them alive.
        self._callback_tasks: "set[asyncio.Task[None]]" = set()

        # Per-utterance bookkeeping.
        self._utterance_active = False
        self._audio_sent_in_utterance = False
        self._last_interim_text = ""
        self._last_error: Optional[str] = None

    async def connect(self) -> None:
        """Open the realtime websocket and configure the transcription session.

        Raises:
            RuntimeError: The DashScope SDK could not be imported, or session setup failed.
            ValueError: No API key is configured.
        """
        if not DASHSCOPE_SDK_AVAILABLE:
            py_exec = sys.executable
            # Quote the requirement: unquoted, a shell treats ">=" as output redirection.
            hint = f'`{py_exec} -m pip install "dashscope>=1.25.6"`'
            detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
            raise RuntimeError(
                f"dashscope SDK unavailable in interpreter {py_exec}; install with {hint}{detail}"
            )
        if not self.api_key:
            raise ValueError("DashScope ASR API key not provided. Configure agent.asr.api_key in YAML.")

        self._loop = asyncio.get_running_loop()
        self._callback = _DashScopeASRCallback(owner=self, loop=self._loop)

        if dashscope is not None:
            # The key is set on the SDK module; presumably OmniRealtimeConversation
            # reads it from there (it is not passed explicitly) -- confirm.
            dashscope.api_key = self.api_key

        self._client = OmniRealtimeConversation(  # type: ignore[misc]
            model=self.model,
            url=self.api_url,
            callback=self._callback,
        )
        try:
            await asyncio.to_thread(self._client.connect)
            await self._configure_session()
        except BaseException:
            # Do not leak a half-open websocket when the handshake or session
            # setup fails (or connect() is cancelled); the error still propagates.
            self._session_ready.clear()
            await self._close_client_quietly()
            raise

        self._running = True
        self.state = ServiceState.CONNECTED
        logger.info(
            "DashScope realtime ASR connected: model={}, sample_rate={}, language={}",
            self.model,
            self.sample_rate,
            self.language,
        )

    async def disconnect(self) -> None:
        """Close the websocket and reset all session and utterance state.

        Local state is reset even if the SDK's ``close`` raises; that error still
        propagates to the caller.
        """
        self._running = False
        self._utterance_active = False
        self._audio_sent_in_utterance = False
        self._drain_queue(self._final_queue)
        self._drain_queue(self._transcript_queue)
        self._session_ready.clear()

        # Detach first so concurrent send_audio() calls see "not connected"
        # instead of writing to a closing client.
        client, self._client = self._client, None
        try:
            close_fn = getattr(client, "close", None)
            if callable(close_fn):
                await asyncio.to_thread(close_fn)
        finally:
            self.state = ServiceState.DISCONNECTED
        logger.info("DashScope realtime ASR disconnected")

    async def begin_utterance(self) -> None:
        """Start a new utterance, discarding state left over from the previous one."""
        self.clear_utterance()
        self._utterance_active = True

    async def send_audio(self, audio: bytes) -> None:
        """Stream one chunk of PCM audio (sent base64-encoded) to the server.

        Empty chunks are ignored. Sending before ``begin_utterance()`` implicitly
        starts an utterance.

        Raises:
            RuntimeError: The service is not connected, or the SDK lacks ``append_audio``.
        """
        if not self._client:
            raise RuntimeError("DashScope ASR service not connected")
        if not audio:
            return

        if not self._utterance_active:
            # Allow graceful fallback if caller sends before begin_utterance.
            self._utterance_active = True

        audio_b64 = base64.b64encode(audio).decode("ascii")
        append_fn = getattr(self._client, "append_audio", None)
        if not callable(append_fn):
            raise RuntimeError("DashScope ASR SDK missing append_audio method")
        await asyncio.to_thread(append_fn, audio_b64)
        self._audio_sent_in_utterance = True

    async def end_utterance(self) -> None:
        """Commit the audio sent for the current utterance.

        Presumably this asks the server to finalize transcription of the buffered
        audio -- confirm against the SDK's ``commit`` semantics. This is a no-op when
        disconnected or when no audio was sent in the active utterance.

        Raises:
            RuntimeError: The SDK lacks ``commit``.
        """
        if not self._client:
            return
        if not self._utterance_active or not self._audio_sent_in_utterance:
            return

        commit_fn = getattr(self._client, "commit", None)
        if not callable(commit_fn):
            raise RuntimeError("DashScope ASR SDK missing commit method")
        await asyncio.to_thread(commit_fn)
        # _audio_sent_in_utterance stays True so wait_for_final_transcription() waits.
        self._utterance_active = False

    async def wait_for_final_transcription(self, timeout_ms: int = DEFAULT_FINAL_TIMEOUT_MS) -> str:
        """Return the final transcript of the current utterance.

        Waits up to ``timeout_ms`` (clamped to at least 50 ms). On timeout it falls
        back to the most recent interim text. Returns ``""`` immediately when no
        audio was sent in this utterance.
        """
        if not self._audio_sent_in_utterance:
            return ""
        timeout_sec = max(0.05, float(timeout_ms) / 1000.0)
        try:
            text = await asyncio.wait_for(self._final_queue.get(), timeout=timeout_sec)
            return str(text or "").strip()
        except asyncio.TimeoutError:
            logger.debug("DashScope ASR final timeout ({}ms), fallback to last interim", timeout_ms)
            return str(self._last_interim_text or "").strip()

    def clear_utterance(self) -> None:
        """Reset per-utterance state and drop any unconsumed final transcripts."""
        self._utterance_active = False
        self._audio_sent_in_utterance = False
        self._last_interim_text = ""
        self._last_error = None
        self._drain_queue(self._final_queue)

    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """Yield transcripts (interim and final) while the service is running.

        Polls every 100 ms so that ``disconnect()`` clearing ``_running`` ends the
        iteration promptly.
        """
        while self._running:
            try:
                result = await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
                yield result
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                # NOTE(review): this swallows cancellation of the consuming task
                # (the async-for just ends). It is kept for caller compatibility;
                # consider re-raising once callers handle CancelledError.
                break

    async def _configure_session(self) -> None:
        """Send session.update, retrying with smaller kwarg sets for older SDKs.

        Then waits up to ``SESSION_READY_TIMEOUT_S`` for a session.created or
        session.updated event.

        Raises:
            RuntimeError: No client, no ``update_session`` method, or every attempt failed.
        """
        if not self._client:
            raise RuntimeError("DashScope ASR client is not initialized")

        text_modality: Any = "text"
        if MultiModality is not None and hasattr(MultiModality, "TEXT"):
            text_modality = MultiModality.TEXT

        transcription_params = self._build_transcription_params()

        # Attempts are tried from most to least specific; the first one the SDK
        # accepts wins.
        minimal: Dict[str, Any] = {"output_modalities": [text_modality]}
        with_transcription: Dict[str, Any] = {
            **minimal,
            "enable_turn_detection": False,
            "enable_input_audio_transcription": True,
        }
        update_attempts = []
        if transcription_params is not None:
            update_attempts.append({**with_transcription, "transcription_params": transcription_params})
        update_attempts.append(with_transcription)
        update_attempts.append(minimal)

        update_fn = getattr(self._client, "update_session", None)
        if not callable(update_fn):
            raise RuntimeError("DashScope ASR SDK missing update_session method")

        last_error: Optional[Exception] = None
        for params in update_attempts:
            try:
                await asyncio.to_thread(update_fn, **params)
            except Exception as exc:
                # A TypeError presumably means this SDK build rejects one of the
                # kwargs; other errors are also retried with the next smaller set.
                last_error = exc
                logger.debug("DashScope ASR update_session attempt failed: {}", exc)
                continue
            break
        else:
            raise RuntimeError(f"DashScope ASR session.update failed: {last_error}") from last_error

        try:
            await asyncio.wait_for(self._session_ready.wait(), timeout=self.SESSION_READY_TIMEOUT_S)
        except asyncio.TimeoutError:
            logger.debug("DashScope ASR session ready wait timeout; continuing")

    def _build_transcription_params(self) -> Optional[Any]:
        """Return an SDK ``TranscriptionParams``, or None if unavailable or rejected."""
        if TranscriptionParams is None:
            return None
        # NOTE(review): "auto" is mapped to "zh" rather than omitted; confirm this
        # is intended and not a stand-in for server-side language detection.
        lang = "zh" if self.language == "auto" else self.language
        try:
            return TranscriptionParams(
                language=lang,
                sample_rate=self.sample_rate,
                input_audio_format="pcm",
            )
        except Exception as exc:
            logger.debug("DashScope ASR TranscriptionParams init failed: {}", exc)
            return None

    async def _close_client_quietly(self) -> None:
        """Best-effort client close for the connect() failure path; never raises Exception."""
        client, self._client = self._client, None
        close_fn = getattr(client, "close", None)
        if not callable(close_fn):
            return
        try:
            await asyncio.to_thread(close_fn)
        except Exception as exc:
            logger.warning("DashScope ASR client close after failed connect raised: {}", exc)

    def _on_ws_open(self) -> None:
        """Websocket opened; readiness is signalled by session events instead."""
        return None

    def _on_ws_close(self, code: int, msg: str) -> None:
        """Record the close code/reason as the last error."""
        self._last_error = f"DashScope ASR websocket closed: {code} {msg}"
        logger.debug("{}", self._last_error)

    def _on_ws_error(self, message: Any) -> None:
        """Record and log an SDK-reported error."""
        self._last_error = str(message)
        logger.error("DashScope ASR error: {}", self._last_error)

    def _on_ws_event(self, message: Any) -> None:
        """Dispatch one server event (runs on the event-loop thread).

        Handles session readiness, error events, and interim and final transcription
        events. Other event types are ignored.
        """
        payload = self._coerce_event(message)
        event_type = str(payload.get("type") or "").strip()
        if not event_type:
            return

        if event_type in {"session.created", "session.updated"}:
            self._session_ready.set()
            return
        if event_type == "error" or event_type.endswith(".failed"):
            err_text = self._extract_text(payload, keys=("message", "error", "details"))
            self._last_error = err_text or event_type
            logger.error("DashScope ASR server event error: {}", self._last_error)
            return

        if event_type == "conversation.item.input_audio_transcription.text":
            # NOTE(review): "stash" is preferred over "text" here. If the server
            # sends a confirmed "text" prefix plus an unconfirmed "stash" tail, the
            # full preview may be text + stash -- confirm against the event schema.
            stash_text = self._extract_text(payload, keys=("stash", "text", "transcript"))
            self._emit_transcript(stash_text, is_final=False)
            return

        if event_type == "conversation.item.input_audio_transcription.completed":
            final_text = self._extract_text(payload, keys=("transcript", "text", "stash"))
            self._emit_transcript(final_text, is_final=True)
            return

    def _emit_transcript(self, text: str, *, is_final: bool) -> None:
        """Publish a transcript and schedule the ``on_transcript`` callback.

        This is reached from ``_on_ws_event``, which ``_DashScopeASRCallback`` runs on
        the event-loop thread via ``call_soon_threadsafe``. The queues are therefore
        filled synchronously: a later ``clear_utterance()`` reliably drains a final
        that has already arrived. Empty text and repeated interim text are dropped.
        """
        normalized = str(text or "").strip()
        if not normalized:
            return
        if not is_final:
            if normalized == self._last_interim_text:
                return
            self._last_interim_text = normalized

        if self._loop is None:
            return

        self._enqueue_transcript(normalized, is_final=is_final)
        if self.on_transcript is None:
            return

        coro = self._invoke_transcript_callback(normalized, is_final)
        try:
            task = self._loop.create_task(coro)
        except RuntimeError:
            # Loop already closed: drop the callback without leaking the coroutine.
            coro.close()
            return
        self._callback_tasks.add(task)
        task.add_done_callback(self._callback_tasks.discard)

    def _enqueue_transcript(self, text: str, *, is_final: bool) -> None:
        """Push ``text`` onto the transcript queue (and the final queue if final)."""
        # Both queues are created without maxsize, so put_nowait never raises QueueFull.
        self._transcript_queue.put_nowait(ASRResult(text=text, is_final=is_final))
        if is_final:
            self._final_queue.put_nowait(text)

    async def _invoke_transcript_callback(self, text: str, is_final: bool) -> None:
        """Await ``on_transcript`` if set; its exceptions are logged, not raised."""
        if self.on_transcript is None:
            return
        try:
            await self.on_transcript(text, is_final)
        except Exception as exc:
            logger.warning("DashScope ASR transcript callback failed: {}", exc)

    async def _publish_transcript(self, text: str, *, is_final: bool) -> None:
        """Enqueue ``text`` and await the ``on_transcript`` callback inline."""
        self._enqueue_transcript(text, is_final=is_final)
        await self._invoke_transcript_callback(text, is_final)

    @staticmethod
    def _coerce_event(message: Any) -> Dict[str, Any]:
        """Normalize an SDK event into a dict.

        A dict passes through unchanged, and a JSON string that parses to a dict is
        decoded. Anything else becomes ``{"type": "raw", "text": ...}``.
        """
        if isinstance(message, dict):
            return message
        if isinstance(message, str):
            try:
                parsed = json.loads(message)
                if isinstance(parsed, dict):
                    return parsed
            except json.JSONDecodeError:
                return {"type": "raw", "text": message}
        return {"type": "raw", "text": str(message)}

    def _extract_text(self, payload: Dict[str, Any], *, keys: tuple[str, ...]) -> str:
        """Return the first non-blank string found under ``keys``, searching nested dicts.

        First the named keys are checked in priority order, descending into dict
        values. Then any other nested dict values are searched. Returns ``""`` when
        nothing matches.
        """
        for key in keys:
            value = payload.get(key)
            if isinstance(value, str) and value.strip():
                return value.strip()
            if isinstance(value, dict):
                nested = self._extract_text(value, keys=keys)
                if nested:
                    return nested

        for key, value in payload.items():
            # Dicts under the named keys were already searched exhaustively above.
            if key in keys or not isinstance(value, dict):
                continue
            nested = self._extract_text(value, keys=keys)
            if nested:
                return nested
        return ""

    @staticmethod
    def _drain_queue(queue: "asyncio.Queue[Any]") -> None:
        """Discard every item currently in ``queue`` without blocking."""
        while True:
            try:
                queue.get_nowait()
            except asyncio.QueueEmpty:
                break
|