Implement DashScope ASR provider and enhance ASR service architecture

- Added DashScope ASR service implementation for real-time streaming.
- Updated ASR provider logic to support DashScope alongside existing providers.
- Enhanced runtime metadata resolution to include DashScope as a valid ASR provider.
- Modified configuration files and documentation to reflect the addition of DashScope.
- Introduced tests to validate DashScope integration and ASR service behavior.
- Refactored ASR service factory to accommodate new provider options and modes.
This commit is contained in:
Xin Wang
2026-03-06 11:44:39 +08:00
parent 7e0b777923
commit e11c3abb9e
19 changed files with 940 additions and 44 deletions

View File

@@ -1 +1,13 @@
"""ASR providers."""
from providers.asr.buffered import BufferedASRService, MockASRService
from providers.asr.dashscope import DashScopeRealtimeASRService
from providers.asr.openai_compatible import OpenAICompatibleASRService, SiliconFlowASRService
__all__ = [
"BufferedASRService",
"MockASRService",
"DashScopeRealtimeASRService",
"OpenAICompatibleASRService",
"SiliconFlowASRService",
]

View File

@@ -34,6 +34,7 @@ class BufferedASRService(BaseASRService):
language: str = "en"
):
super().__init__(sample_rate=sample_rate, language=language)
self.mode = "offline"
self._audio_buffer: bytes = b""
self._current_text: str = ""
@@ -86,6 +87,23 @@ class BufferedASRService(BaseASRService):
self._current_text = ""
self._audio_buffer = b""
return text
async def get_final_transcription(self) -> str:
"""Offline compatibility method used by DuplexPipeline."""
return self.get_and_clear_text()
def clear_buffer(self) -> None:
"""Offline compatibility method used by DuplexPipeline."""
self._audio_buffer = b""
self._current_text = ""
async def start_interim_transcription(self) -> None:
"""No-op for plain buffered ASR."""
return None
async def stop_interim_transcription(self) -> None:
"""No-op for plain buffered ASR."""
return None
def get_audio_buffer(self) -> bytes:
"""Get accumulated audio buffer."""
@@ -103,6 +121,7 @@ class MockASRService(BaseASRService):
def __init__(self, sample_rate: int = 16000, language: str = "en"):
super().__init__(sample_rate=sample_rate, language=language)
self.mode = "offline"
self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()
self._mock_texts = [
"Hello, how are you?",
@@ -145,3 +164,18 @@ class MockASRService(BaseASRService):
continue
except asyncio.CancelledError:
break
def clear_buffer(self) -> None:
return None
async def get_final_transcription(self) -> str:
return ""
def get_and_clear_text(self) -> str:
return ""
async def start_interim_transcription(self) -> None:
return None
async def stop_interim_transcription(self) -> None:
return None

View File

@@ -0,0 +1,388 @@
"""DashScope realtime streaming ASR service.
Uses Qwen-ASR-Realtime via DashScope Python SDK.
"""
from __future__ import annotations
import asyncio
import base64
import json
import os
import sys
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Optional
from loguru import logger
from providers.common.base import ASRResult, BaseASRService, ServiceState
try:
import dashscope
from dashscope.audio.qwen_omni import MultiModality, OmniRealtimeCallback, OmniRealtimeConversation
# Some SDK builds keep TranscriptionParams under qwen_omni.omni_realtime.
try:
from dashscope.audio.qwen_omni import TranscriptionParams
except ImportError:
from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams
DASHSCOPE_SDK_AVAILABLE = True
DASHSCOPE_IMPORT_ERROR = ""
except Exception as exc:
DASHSCOPE_IMPORT_ERROR = f"{type(exc).__name__}: {exc}"
dashscope = None # type: ignore[assignment]
MultiModality = None # type: ignore[assignment]
OmniRealtimeConversation = None # type: ignore[assignment]
TranscriptionParams = None # type: ignore[assignment]
DASHSCOPE_SDK_AVAILABLE = False
class OmniRealtimeCallback: # type: ignore[no-redef]
"""Fallback callback base when DashScope SDK is unavailable."""
pass
class _DashScopeASRCallback(OmniRealtimeCallback):
"""Bridge DashScope SDK callbacks into asyncio loop-safe handlers."""
def __init__(self, owner: "DashScopeRealtimeASRService", loop: asyncio.AbstractEventLoop):
super().__init__()
self._owner = owner
self._loop = loop
def _schedule(self, fn: Callable[[], None]) -> None:
try:
self._loop.call_soon_threadsafe(fn)
except RuntimeError:
return
def on_open(self) -> None:
self._schedule(self._owner._on_ws_open)
def on_close(self, code: int, msg: str) -> None:
self._schedule(lambda: self._owner._on_ws_close(code, msg))
def on_event(self, message: Any) -> None:
self._schedule(lambda: self._owner._on_ws_event(message))
def on_error(self, message: Any) -> None:
self._schedule(lambda: self._owner._on_ws_error(message))
class DashScopeRealtimeASRService(BaseASRService):
"""Realtime streaming ASR implementation for DashScope Qwen-ASR-Realtime."""
DEFAULT_WS_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
DEFAULT_MODEL = "qwen3-asr-flash-realtime"
DEFAULT_FINAL_TIMEOUT_MS = 800
def __init__(
self,
api_key: Optional[str] = None,
api_url: Optional[str] = None,
model: Optional[str] = None,
sample_rate: int = 16000,
language: str = "auto",
on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None,
) -> None:
super().__init__(sample_rate=sample_rate, language=language)
self.mode = "streaming"
self.api_key = (
api_key
or os.getenv("DASHSCOPE_API_KEY")
or os.getenv("ASR_API_KEY")
)
self.api_url = api_url or os.getenv("DASHSCOPE_ASR_API_URL") or self.DEFAULT_WS_URL
self.model = model or os.getenv("DASHSCOPE_ASR_MODEL") or self.DEFAULT_MODEL
self.on_transcript = on_transcript
self._client: Optional[Any] = None
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._callback: Optional[_DashScopeASRCallback] = None
self._running = False
self._session_ready = asyncio.Event()
self._transcript_queue: "asyncio.Queue[ASRResult]" = asyncio.Queue()
self._final_queue: "asyncio.Queue[str]" = asyncio.Queue()
self._utterance_active = False
self._audio_sent_in_utterance = False
self._last_interim_text = ""
self._last_error: Optional[str] = None
async def connect(self) -> None:
if not DASHSCOPE_SDK_AVAILABLE:
py_exec = sys.executable
hint = f"`{py_exec} -m pip install dashscope>=1.25.6`"
detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
raise RuntimeError(
f"dashscope SDK unavailable in interpreter {py_exec}; install with {hint}{detail}"
)
if not self.api_key:
raise ValueError("DashScope ASR API key not provided. Configure agent.asr.api_key in YAML.")
self._loop = asyncio.get_running_loop()
self._callback = _DashScopeASRCallback(owner=self, loop=self._loop)
if dashscope is not None:
dashscope.api_key = self.api_key
self._client = OmniRealtimeConversation( # type: ignore[misc]
model=self.model,
url=self.api_url,
callback=self._callback,
)
await asyncio.to_thread(self._client.connect)
await self._configure_session()
self._running = True
self.state = ServiceState.CONNECTED
logger.info(
"DashScope realtime ASR connected: model={}, sample_rate={}, language={}",
self.model,
self.sample_rate,
self.language,
)
async def disconnect(self) -> None:
self._running = False
self._utterance_active = False
self._audio_sent_in_utterance = False
self._drain_queue(self._final_queue)
self._drain_queue(self._transcript_queue)
self._session_ready.clear()
if self._client is not None:
close_fn = getattr(self._client, "close", None)
if callable(close_fn):
await asyncio.to_thread(close_fn)
self._client = None
self.state = ServiceState.DISCONNECTED
logger.info("DashScope realtime ASR disconnected")
async def begin_utterance(self) -> None:
self.clear_utterance()
self._utterance_active = True
async def send_audio(self, audio: bytes) -> None:
if not self._client:
raise RuntimeError("DashScope ASR service not connected")
if not audio:
return
if not self._utterance_active:
# Allow graceful fallback if caller sends before begin_utterance.
self._utterance_active = True
audio_b64 = base64.b64encode(audio).decode("ascii")
append_fn = getattr(self._client, "append_audio", None)
if not callable(append_fn):
raise RuntimeError("DashScope ASR SDK missing append_audio method")
await asyncio.to_thread(append_fn, audio_b64)
self._audio_sent_in_utterance = True
async def end_utterance(self) -> None:
if not self._client:
return
if not self._utterance_active or not self._audio_sent_in_utterance:
return
commit_fn = getattr(self._client, "commit", None)
if not callable(commit_fn):
raise RuntimeError("DashScope ASR SDK missing commit method")
await asyncio.to_thread(commit_fn)
self._utterance_active = False
async def wait_for_final_transcription(self, timeout_ms: int = DEFAULT_FINAL_TIMEOUT_MS) -> str:
if not self._audio_sent_in_utterance:
return ""
timeout_sec = max(0.05, float(timeout_ms) / 1000.0)
try:
text = await asyncio.wait_for(self._final_queue.get(), timeout=timeout_sec)
return str(text or "").strip()
except asyncio.TimeoutError:
logger.debug("DashScope ASR final timeout ({}ms), fallback to last interim", timeout_ms)
return str(self._last_interim_text or "").strip()
def clear_utterance(self) -> None:
self._utterance_active = False
self._audio_sent_in_utterance = False
self._last_interim_text = ""
self._last_error = None
self._drain_queue(self._final_queue)
async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
while self._running:
try:
result = await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
yield result
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
break
async def _configure_session(self) -> None:
if not self._client:
raise RuntimeError("DashScope ASR client is not initialized")
text_modality: Any = "text"
if MultiModality is not None and hasattr(MultiModality, "TEXT"):
text_modality = MultiModality.TEXT
transcription_params: Optional[Any] = None
if TranscriptionParams is not None:
try:
lang = "zh" if self.language == "auto" else self.language
transcription_params = TranscriptionParams(
language=lang,
sample_rate=self.sample_rate,
input_audio_format="pcm",
)
except Exception as exc:
logger.debug("DashScope ASR TranscriptionParams init failed: {}", exc)
transcription_params = None
update_attempts = [
{
"output_modalities": [text_modality],
"enable_turn_detection": False,
"enable_input_audio_transcription": True,
"transcription_params": transcription_params,
},
{
"output_modalities": [text_modality],
"enable_turn_detection": False,
"enable_input_audio_transcription": True,
},
{
"output_modalities": [text_modality],
},
]
update_fn = getattr(self._client, "update_session", None)
if not callable(update_fn):
raise RuntimeError("DashScope ASR SDK missing update_session method")
last_error: Optional[Exception] = None
for params in update_attempts:
if params.get("transcription_params") is None:
params = {k: v for k, v in params.items() if k != "transcription_params"}
try:
await asyncio.to_thread(update_fn, **params)
break
except TypeError as exc:
last_error = exc
continue
except Exception as exc:
last_error = exc
continue
else:
raise RuntimeError(f"DashScope ASR session.update failed: {last_error}")
try:
await asyncio.wait_for(self._session_ready.wait(), timeout=6.0)
except asyncio.TimeoutError:
logger.debug("DashScope ASR session ready wait timeout; continuing")
def _on_ws_open(self) -> None:
return None
def _on_ws_close(self, code: int, msg: str) -> None:
self._last_error = f"DashScope ASR websocket closed: {code} {msg}"
logger.debug(self._last_error)
def _on_ws_error(self, message: Any) -> None:
self._last_error = str(message)
logger.error("DashScope ASR error: {}", self._last_error)
def _on_ws_event(self, message: Any) -> None:
payload = self._coerce_event(message)
event_type = str(payload.get("type") or "").strip()
if not event_type:
return
if event_type in {"session.created", "session.updated"}:
self._session_ready.set()
return
if event_type == "error" or event_type.endswith(".failed"):
err_text = self._extract_text(payload, keys=("message", "error", "details"))
self._last_error = err_text or event_type
logger.error("DashScope ASR server event error: {}", self._last_error)
return
if event_type == "conversation.item.input_audio_transcription.text":
stash_text = self._extract_text(payload, keys=("stash", "text", "transcript"))
self._emit_transcript(stash_text, is_final=False)
return
if event_type == "conversation.item.input_audio_transcription.completed":
final_text = self._extract_text(payload, keys=("transcript", "text", "stash"))
self._emit_transcript(final_text, is_final=True)
return
def _emit_transcript(self, text: str, *, is_final: bool) -> None:
normalized = str(text or "").strip()
if not normalized:
return
if not is_final and normalized == self._last_interim_text:
return
if not is_final:
self._last_interim_text = normalized
if self._loop is None:
return
try:
asyncio.run_coroutine_threadsafe(
self._publish_transcript(normalized, is_final=is_final),
self._loop,
)
except RuntimeError:
return
async def _publish_transcript(self, text: str, *, is_final: bool) -> None:
await self._transcript_queue.put(ASRResult(text=text, is_final=is_final))
if is_final:
await self._final_queue.put(text)
if self.on_transcript:
try:
await self.on_transcript(text, is_final)
except Exception as exc:
logger.warning("DashScope ASR transcript callback failed: {}", exc)
@staticmethod
def _coerce_event(message: Any) -> Dict[str, Any]:
if isinstance(message, dict):
return message
if isinstance(message, str):
try:
parsed = json.loads(message)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
return {"type": "raw", "text": message}
return {"type": "raw", "text": str(message)}
def _extract_text(self, payload: Dict[str, Any], *, keys: tuple[str, ...]) -> str:
for key in keys:
value = payload.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
if isinstance(value, dict):
nested = self._extract_text(value, keys=keys)
if nested:
return nested
for value in payload.values():
if isinstance(value, dict):
nested = self._extract_text(value, keys=keys)
if nested:
return nested
return ""
@staticmethod
def _drain_queue(queue: "asyncio.Queue[Any]") -> None:
while True:
try:
queue.get_nowait()
except asyncio.QueueEmpty:
break

View File

@@ -71,6 +71,7 @@ class OpenAICompatibleASRService(BaseASRService):
on_transcript: Callback for transcription results (text, is_final)
"""
super().__init__(sample_rate=sample_rate, language=language)
self.mode = "offline"
if not AIOHTTP_AVAILABLE:
raise RuntimeError("aiohttp is required for OpenAICompatibleASRService")

View File

@@ -16,6 +16,7 @@ from runtime.ports import (
TTSServiceSpec,
)
from providers.asr.buffered import BufferedASRService
from providers.asr.dashscope import DashScopeRealtimeASRService
from providers.tts.dashscope import DashScopeTTSService
from providers.llm.openai import MockLLMService, OpenAILLMService
from providers.asr.openai_compatible import OpenAICompatibleASRService
@@ -23,6 +24,7 @@ from providers.tts.openai_compatible import OpenAICompatibleTTSService
from providers.tts.mock import MockTTSService
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
_DASHSCOPE_PROVIDERS = {"dashscope"}
_SUPPORTED_LLM_PROVIDERS = {"openai", *_OPENAI_COMPATIBLE_PROVIDERS}
@@ -31,6 +33,8 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
_DEFAULT_DASHSCOPE_TTS_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
_DEFAULT_DASHSCOPE_TTS_MODEL = "qwen3-tts-flash-realtime"
_DEFAULT_DASHSCOPE_ASR_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
_DEFAULT_DASHSCOPE_ASR_MODEL = "qwen3-asr-flash-realtime"
_DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
_DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
@@ -96,6 +100,16 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
def create_asr_service(self, spec: ASRServiceSpec) -> ASRPort:
provider = self._normalize_provider(spec.provider)
if provider in _DASHSCOPE_PROVIDERS and spec.api_key:
return DashScopeRealtimeASRService(
api_key=spec.api_key,
api_url=spec.api_url or self._DEFAULT_DASHSCOPE_ASR_REALTIME_URL,
model=spec.model or self._DEFAULT_DASHSCOPE_ASR_MODEL,
sample_rate=spec.sample_rate,
language=spec.language,
on_transcript=spec.on_transcript,
)
if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key:
return OpenAICompatibleASRService(
api_key=spec.api_key,