Implement DashScope ASR provider and enhance ASR service architecture
- Added DashScope ASR service implementation for real-time streaming. - Updated ASR provider logic to support DashScope alongside existing providers. - Enhanced runtime metadata resolution to include DashScope as a valid ASR provider. - Modified configuration files and documentation to reflect the addition of DashScope. - Introduced tests to validate DashScope integration and ASR service behavior. - Refactored ASR service factory to accommodate new provider options and modes.
This commit is contained in:
@@ -1 +1,13 @@
|
||||
"""ASR providers."""
|
||||
|
||||
from providers.asr.buffered import BufferedASRService, MockASRService
|
||||
from providers.asr.dashscope import DashScopeRealtimeASRService
|
||||
from providers.asr.openai_compatible import OpenAICompatibleASRService, SiliconFlowASRService
|
||||
|
||||
__all__ = [
|
||||
"BufferedASRService",
|
||||
"MockASRService",
|
||||
"DashScopeRealtimeASRService",
|
||||
"OpenAICompatibleASRService",
|
||||
"SiliconFlowASRService",
|
||||
]
|
||||
|
||||
@@ -34,6 +34,7 @@ class BufferedASRService(BaseASRService):
|
||||
language: str = "en"
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, language=language)
|
||||
self.mode = "offline"
|
||||
|
||||
self._audio_buffer: bytes = b""
|
||||
self._current_text: str = ""
|
||||
@@ -86,6 +87,23 @@ class BufferedASRService(BaseASRService):
|
||||
self._current_text = ""
|
||||
self._audio_buffer = b""
|
||||
return text
|
||||
|
||||
async def get_final_transcription(self) -> str:
|
||||
"""Offline compatibility method used by DuplexPipeline."""
|
||||
return self.get_and_clear_text()
|
||||
|
||||
def clear_buffer(self) -> None:
|
||||
"""Offline compatibility method used by DuplexPipeline."""
|
||||
self._audio_buffer = b""
|
||||
self._current_text = ""
|
||||
|
||||
async def start_interim_transcription(self) -> None:
|
||||
"""No-op for plain buffered ASR."""
|
||||
return None
|
||||
|
||||
async def stop_interim_transcription(self) -> None:
|
||||
"""No-op for plain buffered ASR."""
|
||||
return None
|
||||
|
||||
def get_audio_buffer(self) -> bytes:
|
||||
"""Get accumulated audio buffer."""
|
||||
@@ -103,6 +121,7 @@ class MockASRService(BaseASRService):
|
||||
|
||||
def __init__(self, sample_rate: int = 16000, language: str = "en"):
|
||||
super().__init__(sample_rate=sample_rate, language=language)
|
||||
self.mode = "offline"
|
||||
self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()
|
||||
self._mock_texts = [
|
||||
"Hello, how are you?",
|
||||
@@ -145,3 +164,18 @@ class MockASRService(BaseASRService):
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
def clear_buffer(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_final_transcription(self) -> str:
|
||||
return ""
|
||||
|
||||
def get_and_clear_text(self) -> str:
|
||||
return ""
|
||||
|
||||
async def start_interim_transcription(self) -> None:
|
||||
return None
|
||||
|
||||
async def stop_interim_transcription(self) -> None:
|
||||
return None
|
||||
|
||||
388
engine/providers/asr/dashscope.py
Normal file
388
engine/providers/asr/dashscope.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""DashScope realtime streaming ASR service.
|
||||
|
||||
Uses Qwen-ASR-Realtime via DashScope Python SDK.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from providers.common.base import ASRResult, BaseASRService, ServiceState
|
||||
|
||||
try:
|
||||
import dashscope
|
||||
from dashscope.audio.qwen_omni import MultiModality, OmniRealtimeCallback, OmniRealtimeConversation
|
||||
|
||||
# Some SDK builds keep TranscriptionParams under qwen_omni.omni_realtime.
|
||||
try:
|
||||
from dashscope.audio.qwen_omni import TranscriptionParams
|
||||
except ImportError:
|
||||
from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams
|
||||
|
||||
DASHSCOPE_SDK_AVAILABLE = True
|
||||
DASHSCOPE_IMPORT_ERROR = ""
|
||||
except Exception as exc:
|
||||
DASHSCOPE_IMPORT_ERROR = f"{type(exc).__name__}: {exc}"
|
||||
dashscope = None # type: ignore[assignment]
|
||||
MultiModality = None # type: ignore[assignment]
|
||||
OmniRealtimeConversation = None # type: ignore[assignment]
|
||||
TranscriptionParams = None # type: ignore[assignment]
|
||||
DASHSCOPE_SDK_AVAILABLE = False
|
||||
|
||||
class OmniRealtimeCallback: # type: ignore[no-redef]
|
||||
"""Fallback callback base when DashScope SDK is unavailable."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class _DashScopeASRCallback(OmniRealtimeCallback):
|
||||
"""Bridge DashScope SDK callbacks into asyncio loop-safe handlers."""
|
||||
|
||||
def __init__(self, owner: "DashScopeRealtimeASRService", loop: asyncio.AbstractEventLoop):
|
||||
super().__init__()
|
||||
self._owner = owner
|
||||
self._loop = loop
|
||||
|
||||
def _schedule(self, fn: Callable[[], None]) -> None:
|
||||
try:
|
||||
self._loop.call_soon_threadsafe(fn)
|
||||
except RuntimeError:
|
||||
return
|
||||
|
||||
def on_open(self) -> None:
|
||||
self._schedule(self._owner._on_ws_open)
|
||||
|
||||
def on_close(self, code: int, msg: str) -> None:
|
||||
self._schedule(lambda: self._owner._on_ws_close(code, msg))
|
||||
|
||||
def on_event(self, message: Any) -> None:
|
||||
self._schedule(lambda: self._owner._on_ws_event(message))
|
||||
|
||||
def on_error(self, message: Any) -> None:
|
||||
self._schedule(lambda: self._owner._on_ws_error(message))
|
||||
|
||||
|
||||
class DashScopeRealtimeASRService(BaseASRService):
|
||||
"""Realtime streaming ASR implementation for DashScope Qwen-ASR-Realtime."""
|
||||
|
||||
DEFAULT_WS_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
|
||||
DEFAULT_MODEL = "qwen3-asr-flash-realtime"
|
||||
DEFAULT_FINAL_TIMEOUT_MS = 800
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
sample_rate: int = 16000,
|
||||
language: str = "auto",
|
||||
on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None,
|
||||
) -> None:
|
||||
super().__init__(sample_rate=sample_rate, language=language)
|
||||
self.mode = "streaming"
|
||||
self.api_key = (
|
||||
api_key
|
||||
or os.getenv("DASHSCOPE_API_KEY")
|
||||
or os.getenv("ASR_API_KEY")
|
||||
)
|
||||
self.api_url = api_url or os.getenv("DASHSCOPE_ASR_API_URL") or self.DEFAULT_WS_URL
|
||||
self.model = model or os.getenv("DASHSCOPE_ASR_MODEL") or self.DEFAULT_MODEL
|
||||
self.on_transcript = on_transcript
|
||||
|
||||
self._client: Optional[Any] = None
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._callback: Optional[_DashScopeASRCallback] = None
|
||||
|
||||
self._running = False
|
||||
self._session_ready = asyncio.Event()
|
||||
self._transcript_queue: "asyncio.Queue[ASRResult]" = asyncio.Queue()
|
||||
self._final_queue: "asyncio.Queue[str]" = asyncio.Queue()
|
||||
|
||||
self._utterance_active = False
|
||||
self._audio_sent_in_utterance = False
|
||||
self._last_interim_text = ""
|
||||
self._last_error: Optional[str] = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
if not DASHSCOPE_SDK_AVAILABLE:
|
||||
py_exec = sys.executable
|
||||
hint = f"`{py_exec} -m pip install dashscope>=1.25.6`"
|
||||
detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
|
||||
raise RuntimeError(
|
||||
f"dashscope SDK unavailable in interpreter {py_exec}; install with {hint}{detail}"
|
||||
)
|
||||
if not self.api_key:
|
||||
raise ValueError("DashScope ASR API key not provided. Configure agent.asr.api_key in YAML.")
|
||||
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._callback = _DashScopeASRCallback(owner=self, loop=self._loop)
|
||||
|
||||
if dashscope is not None:
|
||||
dashscope.api_key = self.api_key
|
||||
|
||||
self._client = OmniRealtimeConversation( # type: ignore[misc]
|
||||
model=self.model,
|
||||
url=self.api_url,
|
||||
callback=self._callback,
|
||||
)
|
||||
await asyncio.to_thread(self._client.connect)
|
||||
await self._configure_session()
|
||||
|
||||
self._running = True
|
||||
self.state = ServiceState.CONNECTED
|
||||
logger.info(
|
||||
"DashScope realtime ASR connected: model={}, sample_rate={}, language={}",
|
||||
self.model,
|
||||
self.sample_rate,
|
||||
self.language,
|
||||
)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._running = False
|
||||
self._utterance_active = False
|
||||
self._audio_sent_in_utterance = False
|
||||
self._drain_queue(self._final_queue)
|
||||
self._drain_queue(self._transcript_queue)
|
||||
self._session_ready.clear()
|
||||
|
||||
if self._client is not None:
|
||||
close_fn = getattr(self._client, "close", None)
|
||||
if callable(close_fn):
|
||||
await asyncio.to_thread(close_fn)
|
||||
self._client = None
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("DashScope realtime ASR disconnected")
|
||||
|
||||
async def begin_utterance(self) -> None:
|
||||
self.clear_utterance()
|
||||
self._utterance_active = True
|
||||
|
||||
async def send_audio(self, audio: bytes) -> None:
|
||||
if not self._client:
|
||||
raise RuntimeError("DashScope ASR service not connected")
|
||||
if not audio:
|
||||
return
|
||||
|
||||
if not self._utterance_active:
|
||||
# Allow graceful fallback if caller sends before begin_utterance.
|
||||
self._utterance_active = True
|
||||
|
||||
audio_b64 = base64.b64encode(audio).decode("ascii")
|
||||
append_fn = getattr(self._client, "append_audio", None)
|
||||
if not callable(append_fn):
|
||||
raise RuntimeError("DashScope ASR SDK missing append_audio method")
|
||||
await asyncio.to_thread(append_fn, audio_b64)
|
||||
self._audio_sent_in_utterance = True
|
||||
|
||||
async def end_utterance(self) -> None:
|
||||
if not self._client:
|
||||
return
|
||||
if not self._utterance_active or not self._audio_sent_in_utterance:
|
||||
return
|
||||
|
||||
commit_fn = getattr(self._client, "commit", None)
|
||||
if not callable(commit_fn):
|
||||
raise RuntimeError("DashScope ASR SDK missing commit method")
|
||||
await asyncio.to_thread(commit_fn)
|
||||
self._utterance_active = False
|
||||
|
||||
async def wait_for_final_transcription(self, timeout_ms: int = DEFAULT_FINAL_TIMEOUT_MS) -> str:
|
||||
if not self._audio_sent_in_utterance:
|
||||
return ""
|
||||
timeout_sec = max(0.05, float(timeout_ms) / 1000.0)
|
||||
try:
|
||||
text = await asyncio.wait_for(self._final_queue.get(), timeout=timeout_sec)
|
||||
return str(text or "").strip()
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug("DashScope ASR final timeout ({}ms), fallback to last interim", timeout_ms)
|
||||
return str(self._last_interim_text or "").strip()
|
||||
|
||||
def clear_utterance(self) -> None:
|
||||
self._utterance_active = False
|
||||
self._audio_sent_in_utterance = False
|
||||
self._last_interim_text = ""
|
||||
self._last_error = None
|
||||
self._drain_queue(self._final_queue)
|
||||
|
||||
async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
|
||||
while self._running:
|
||||
try:
|
||||
result = await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
|
||||
yield result
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
async def _configure_session(self) -> None:
|
||||
if not self._client:
|
||||
raise RuntimeError("DashScope ASR client is not initialized")
|
||||
|
||||
text_modality: Any = "text"
|
||||
if MultiModality is not None and hasattr(MultiModality, "TEXT"):
|
||||
text_modality = MultiModality.TEXT
|
||||
|
||||
transcription_params: Optional[Any] = None
|
||||
if TranscriptionParams is not None:
|
||||
try:
|
||||
lang = "zh" if self.language == "auto" else self.language
|
||||
transcription_params = TranscriptionParams(
|
||||
language=lang,
|
||||
sample_rate=self.sample_rate,
|
||||
input_audio_format="pcm",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("DashScope ASR TranscriptionParams init failed: {}", exc)
|
||||
transcription_params = None
|
||||
|
||||
update_attempts = [
|
||||
{
|
||||
"output_modalities": [text_modality],
|
||||
"enable_turn_detection": False,
|
||||
"enable_input_audio_transcription": True,
|
||||
"transcription_params": transcription_params,
|
||||
},
|
||||
{
|
||||
"output_modalities": [text_modality],
|
||||
"enable_turn_detection": False,
|
||||
"enable_input_audio_transcription": True,
|
||||
},
|
||||
{
|
||||
"output_modalities": [text_modality],
|
||||
},
|
||||
]
|
||||
|
||||
update_fn = getattr(self._client, "update_session", None)
|
||||
if not callable(update_fn):
|
||||
raise RuntimeError("DashScope ASR SDK missing update_session method")
|
||||
|
||||
last_error: Optional[Exception] = None
|
||||
for params in update_attempts:
|
||||
if params.get("transcription_params") is None:
|
||||
params = {k: v for k, v in params.items() if k != "transcription_params"}
|
||||
try:
|
||||
await asyncio.to_thread(update_fn, **params)
|
||||
break
|
||||
except TypeError as exc:
|
||||
last_error = exc
|
||||
continue
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"DashScope ASR session.update failed: {last_error}")
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self._session_ready.wait(), timeout=6.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug("DashScope ASR session ready wait timeout; continuing")
|
||||
|
||||
def _on_ws_open(self) -> None:
|
||||
return None
|
||||
|
||||
def _on_ws_close(self, code: int, msg: str) -> None:
|
||||
self._last_error = f"DashScope ASR websocket closed: {code} {msg}"
|
||||
logger.debug(self._last_error)
|
||||
|
||||
def _on_ws_error(self, message: Any) -> None:
|
||||
self._last_error = str(message)
|
||||
logger.error("DashScope ASR error: {}", self._last_error)
|
||||
|
||||
def _on_ws_event(self, message: Any) -> None:
|
||||
payload = self._coerce_event(message)
|
||||
event_type = str(payload.get("type") or "").strip()
|
||||
if not event_type:
|
||||
return
|
||||
|
||||
if event_type in {"session.created", "session.updated"}:
|
||||
self._session_ready.set()
|
||||
return
|
||||
if event_type == "error" or event_type.endswith(".failed"):
|
||||
err_text = self._extract_text(payload, keys=("message", "error", "details"))
|
||||
self._last_error = err_text or event_type
|
||||
logger.error("DashScope ASR server event error: {}", self._last_error)
|
||||
return
|
||||
|
||||
if event_type == "conversation.item.input_audio_transcription.text":
|
||||
stash_text = self._extract_text(payload, keys=("stash", "text", "transcript"))
|
||||
self._emit_transcript(stash_text, is_final=False)
|
||||
return
|
||||
|
||||
if event_type == "conversation.item.input_audio_transcription.completed":
|
||||
final_text = self._extract_text(payload, keys=("transcript", "text", "stash"))
|
||||
self._emit_transcript(final_text, is_final=True)
|
||||
return
|
||||
|
||||
def _emit_transcript(self, text: str, *, is_final: bool) -> None:
|
||||
normalized = str(text or "").strip()
|
||||
if not normalized:
|
||||
return
|
||||
if not is_final and normalized == self._last_interim_text:
|
||||
return
|
||||
if not is_final:
|
||||
self._last_interim_text = normalized
|
||||
|
||||
if self._loop is None:
|
||||
return
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._publish_transcript(normalized, is_final=is_final),
|
||||
self._loop,
|
||||
)
|
||||
except RuntimeError:
|
||||
return
|
||||
|
||||
async def _publish_transcript(self, text: str, *, is_final: bool) -> None:
|
||||
await self._transcript_queue.put(ASRResult(text=text, is_final=is_final))
|
||||
if is_final:
|
||||
await self._final_queue.put(text)
|
||||
if self.on_transcript:
|
||||
try:
|
||||
await self.on_transcript(text, is_final)
|
||||
except Exception as exc:
|
||||
logger.warning("DashScope ASR transcript callback failed: {}", exc)
|
||||
|
||||
@staticmethod
|
||||
def _coerce_event(message: Any) -> Dict[str, Any]:
|
||||
if isinstance(message, dict):
|
||||
return message
|
||||
if isinstance(message, str):
|
||||
try:
|
||||
parsed = json.loads(message)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except json.JSONDecodeError:
|
||||
return {"type": "raw", "text": message}
|
||||
return {"type": "raw", "text": str(message)}
|
||||
|
||||
def _extract_text(self, payload: Dict[str, Any], *, keys: tuple[str, ...]) -> str:
|
||||
for key in keys:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
if isinstance(value, dict):
|
||||
nested = self._extract_text(value, keys=keys)
|
||||
if nested:
|
||||
return nested
|
||||
|
||||
for value in payload.values():
|
||||
if isinstance(value, dict):
|
||||
nested = self._extract_text(value, keys=keys)
|
||||
if nested:
|
||||
return nested
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _drain_queue(queue: "asyncio.Queue[Any]") -> None:
|
||||
while True:
|
||||
try:
|
||||
queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
@@ -71,6 +71,7 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
on_transcript: Callback for transcription results (text, is_final)
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, language=language)
|
||||
self.mode = "offline"
|
||||
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
raise RuntimeError("aiohttp is required for OpenAICompatibleASRService")
|
||||
|
||||
@@ -16,6 +16,7 @@ from runtime.ports import (
|
||||
TTSServiceSpec,
|
||||
)
|
||||
from providers.asr.buffered import BufferedASRService
|
||||
from providers.asr.dashscope import DashScopeRealtimeASRService
|
||||
from providers.tts.dashscope import DashScopeTTSService
|
||||
from providers.llm.openai import MockLLMService, OpenAILLMService
|
||||
from providers.asr.openai_compatible import OpenAICompatibleASRService
|
||||
@@ -23,6 +24,7 @@ from providers.tts.openai_compatible import OpenAICompatibleTTSService
|
||||
from providers.tts.mock import MockTTSService
|
||||
|
||||
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
|
||||
_DASHSCOPE_PROVIDERS = {"dashscope"}
|
||||
_SUPPORTED_LLM_PROVIDERS = {"openai", *_OPENAI_COMPATIBLE_PROVIDERS}
|
||||
|
||||
|
||||
@@ -31,6 +33,8 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
|
||||
|
||||
_DEFAULT_DASHSCOPE_TTS_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
|
||||
_DEFAULT_DASHSCOPE_TTS_MODEL = "qwen3-tts-flash-realtime"
|
||||
_DEFAULT_DASHSCOPE_ASR_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
|
||||
_DEFAULT_DASHSCOPE_ASR_MODEL = "qwen3-asr-flash-realtime"
|
||||
_DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
|
||||
_DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
|
||||
|
||||
@@ -96,6 +100,16 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
|
||||
def create_asr_service(self, spec: ASRServiceSpec) -> ASRPort:
|
||||
provider = self._normalize_provider(spec.provider)
|
||||
|
||||
if provider in _DASHSCOPE_PROVIDERS and spec.api_key:
|
||||
return DashScopeRealtimeASRService(
|
||||
api_key=spec.api_key,
|
||||
api_url=spec.api_url or self._DEFAULT_DASHSCOPE_ASR_REALTIME_URL,
|
||||
model=spec.model or self._DEFAULT_DASHSCOPE_ASR_MODEL,
|
||||
sample_rate=spec.sample_rate,
|
||||
language=spec.language,
|
||||
on_transcript=spec.on_transcript,
|
||||
)
|
||||
|
||||
if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key:
|
||||
return OpenAICompatibleASRService(
|
||||
api_key=spec.api_key,
|
||||
|
||||
Reference in New Issue
Block a user