Implement DashScope ASR provider and enhance ASR service architecture

- Added DashScope ASR service implementation for real-time streaming.
- Updated ASR provider logic to support DashScope alongside existing providers.
- Enhanced runtime metadata resolution to include DashScope as a valid ASR provider.
- Modified configuration files and documentation to reflect the addition of DashScope.
- Introduced tests to validate DashScope integration and ASR service behavior.
- Refactored ASR service factory to accommodate new provider options and modes.
This commit is contained in:
Xin Wang
2026-03-06 11:44:39 +08:00
parent 7e0b777923
commit e11c3abb9e
19 changed files with 940 additions and 44 deletions

View File

@@ -320,12 +320,17 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[s
if assistant.asr_model_id:
asr = db.query(ASRModel).filter(ASRModel.id == assistant.asr_model_id).first()
if asr:
asr_provider = "openai_compatible" if _is_openai_compatible_vendor(asr.vendor) else "buffered"
if _is_dashscope_vendor(asr.vendor):
asr_provider = "dashscope"
elif _is_openai_compatible_vendor(asr.vendor):
asr_provider = "openai_compatible"
else:
asr_provider = "buffered"
metadata["services"]["asr"] = {
"provider": asr_provider,
"model": asr.model_name or asr.name,
"apiKey": asr.api_key if asr_provider == "openai_compatible" else None,
"baseUrl": asr.base_url if asr_provider == "openai_compatible" else None,
"apiKey": asr.api_key if asr_provider in {"openai_compatible", "dashscope"} else None,
"baseUrl": asr.base_url if asr_provider in {"openai_compatible", "dashscope"} else None,
}
else:
warnings.append(f"ASR model not found: {assistant.asr_model_id}")

View File

@@ -343,6 +343,37 @@ class TestAssistantAPI:
assert tts["apiKey"] == "dashscope-key"
assert tts["baseUrl"] == "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
def test_runtime_config_dashscope_asr_provider(self, client, sample_assistant_data):
    """DashScope ASR models should map to dashscope asr provider in runtime metadata."""
    # Step 1: create a DashScope-vendor ASR model via the API.
    asr_resp = client.post("/api/asr", json={
        "name": "DashScope Realtime ASR",
        "vendor": "DashScope",
        "language": "zh",
        "base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
        "api_key": "dashscope-asr-key",
        "model_name": "qwen3-asr-flash-realtime",
        "hotwords": [],
        "enable_punctuation": True,
        "enable_normalization": True,
        "enabled": True,
    })
    assert asr_resp.status_code == 200
    asr_payload = asr_resp.json()
    # Step 2: attach the ASR model to a new assistant.
    sample_assistant_data.update({
        "asrModelId": asr_payload["id"],
    })
    assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
    assert assistant_resp.status_code == 200
    assistant_id = assistant_resp.json()["id"]
    # Step 3: runtime metadata must resolve provider to "dashscope" and keep the base URL.
    runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
    assert runtime_resp.status_code == 200
    metadata = runtime_resp.json()["sessionStartMetadata"]
    asr = metadata["services"]["asr"]
    assert asr["provider"] == "dashscope"
    assert asr["baseUrl"] == "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
def test_assistant_interrupt_and_generated_opener_flags(self, client, sample_assistant_data):
sample_assistant_data.update({
"firstTurnMode": "user_first",

View File

@@ -2,6 +2,11 @@
语音识别ASR负责将用户音频实时转写为文本供对话引擎理解。
## 模式
- `offline`:引擎本地缓冲音频后触发识别(适用于 OpenAI-compatible / SiliconFlow)
- `streaming`:音频分片实时发送到服务端,服务端持续返回转写事件(适用于 DashScope Realtime ASR)
## 配置项
| 配置项 | 说明 |
@@ -17,8 +22,8 @@
- 客服场景建议开启热词并维护业务词表
- 多语言场景建议按会话入口显式指定语言
- 对延迟敏感场景优先选择流式识别模型
- 当前支持提供商:`openai_compatible`、`siliconflow`、`dashscope`、`buffered`(回退)
## 相关文档
- [语音配置总览](voices.md)

View File

@@ -85,7 +85,7 @@ class Settings(BaseSettings):
# ASR Configuration
asr_provider: str = Field(
default="openai_compatible",
description="ASR provider (openai_compatible, buffered, siliconflow)"
description="ASR provider (openai_compatible, buffered, siliconflow, dashscope)"
)
asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL")
asr_model: Optional[str] = Field(default=None, description="ASR model name")

View File

@@ -35,7 +35,11 @@ agent:
speed: 1.0
asr:
# provider: buffered | openai_compatible | siliconflow
# provider: buffered | openai_compatible | siliconflow | dashscope
# dashscope defaults (if omitted):
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
# model: qwen3-asr-flash-realtime
# note: dashscope uses streaming ASR mode (chunk-by-chunk).
provider: openai_compatible
api_key: your_asr_api_key
api_url: https://api.siliconflow.cn/v1/audio/transcriptions

View File

@@ -32,7 +32,11 @@ agent:
speed: 1.0
asr:
# provider: buffered | openai_compatible | siliconflow
# provider: buffered | openai_compatible | siliconflow | dashscope
# dashscope defaults (if omitted):
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
# model: qwen3-asr-flash-realtime
# note: dashscope uses streaming ASR mode (chunk-by-chunk).
provider: openai_compatible
api_key: your_asr_api_key
api_url: https://api.siliconflow.cn/v1/audio/transcriptions

View File

@@ -20,7 +20,7 @@ This document defines the draft port set used to keep core runtime extensible.
- `runtime/ports/asr.py`
- `ASRServiceSpec`
- `ASRPort`
- optional extensions: `ASRInterimControl`, `ASRBufferControl`
- explicit mode ports: `OfflineASRPort`, `StreamingASRPort`
- `runtime/ports/service_factory.py`
- `RealtimeServiceFactory`
@@ -39,7 +39,7 @@ This document defines the draft port set used to keep core runtime extensible.
- supported providers: `dashscope`, `openai_compatible`, `openai-compatible`, `siliconflow`
- fallback: `MockTTSService`
- ASR:
- supported providers: `openai_compatible`, `openai-compatible`, `siliconflow`
- supported providers: `openai_compatible`, `openai-compatible`, `siliconflow`, `dashscope`
- fallback: `BufferedASRService`
## Notes

View File

@@ -1 +1,13 @@
"""ASR providers."""
from providers.asr.buffered import BufferedASRService, MockASRService
from providers.asr.dashscope import DashScopeRealtimeASRService
from providers.asr.openai_compatible import OpenAICompatibleASRService, SiliconFlowASRService
__all__ = [
"BufferedASRService",
"MockASRService",
"DashScopeRealtimeASRService",
"OpenAICompatibleASRService",
"SiliconFlowASRService",
]

View File

@@ -34,6 +34,7 @@ class BufferedASRService(BaseASRService):
language: str = "en"
):
super().__init__(sample_rate=sample_rate, language=language)
self.mode = "offline"
self._audio_buffer: bytes = b""
self._current_text: str = ""
@@ -86,6 +87,23 @@ class BufferedASRService(BaseASRService):
self._current_text = ""
self._audio_buffer = b""
return text
async def get_final_transcription(self) -> str:
    """Offline-port compatibility shim used by DuplexPipeline.

    Returns the text accumulated so far and clears the internal buffers.
    """
    return self.get_and_clear_text()
def clear_buffer(self) -> None:
    """Offline-port compatibility shim used by DuplexPipeline.

    Drops both the buffered audio and any accumulated transcript text.
    """
    self._audio_buffer = b""
    self._current_text = ""
async def start_interim_transcription(self) -> None:
    """No-op for plain buffered ASR: it produces no interim transcripts."""
    return None
async def stop_interim_transcription(self) -> None:
    """No-op for plain buffered ASR: there is no interim loop to stop."""
    return None
def get_audio_buffer(self) -> bytes:
"""Get accumulated audio buffer."""
@@ -103,6 +121,7 @@ class MockASRService(BaseASRService):
def __init__(self, sample_rate: int = 16000, language: str = "en"):
super().__init__(sample_rate=sample_rate, language=language)
self.mode = "offline"
self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()
self._mock_texts = [
"Hello, how are you?",
@@ -145,3 +164,18 @@ class MockASRService(BaseASRService):
continue
except asyncio.CancelledError:
break
def clear_buffer(self) -> None:
    """Offline-port shim: the mock keeps no audio buffer, nothing to clear."""
    return None
async def get_final_transcription(self) -> str:
    """Offline-port shim: mock transcripts come via the queue, so no final text here."""
    return ""
def get_and_clear_text(self) -> str:
    """Offline-port shim: the mock accumulates no text."""
    return ""
async def start_interim_transcription(self) -> None:
    """No-op: the mock does not run an interim transcription loop."""
    return None
async def stop_interim_transcription(self) -> None:
    """No-op: the mock does not run an interim transcription loop."""
    return None

View File

@@ -0,0 +1,388 @@
"""DashScope realtime streaming ASR service.
Uses Qwen-ASR-Realtime via DashScope Python SDK.
"""
from __future__ import annotations
import asyncio
import base64
import json
import os
import sys
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Optional
from loguru import logger
from providers.common.base import ASRResult, BaseASRService, ServiceState
# Optional dependency: the DashScope SDK may not be installed in every
# deployment. Import failures are captured (not raised) so the module stays
# importable; connect() reports DASHSCOPE_IMPORT_ERROR with an install hint.
try:
    import dashscope
    from dashscope.audio.qwen_omni import MultiModality, OmniRealtimeCallback, OmniRealtimeConversation
    # Some SDK builds keep TranscriptionParams under qwen_omni.omni_realtime.
    try:
        from dashscope.audio.qwen_omni import TranscriptionParams
    except ImportError:
        from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams
    DASHSCOPE_SDK_AVAILABLE = True
    DASHSCOPE_IMPORT_ERROR = ""
except Exception as exc:
    # Record the failure and null out the SDK symbols so later isinstance /
    # attribute checks degrade gracefully.
    DASHSCOPE_IMPORT_ERROR = f"{type(exc).__name__}: {exc}"
    dashscope = None  # type: ignore[assignment]
    MultiModality = None  # type: ignore[assignment]
    OmniRealtimeConversation = None  # type: ignore[assignment]
    TranscriptionParams = None  # type: ignore[assignment]
    DASHSCOPE_SDK_AVAILABLE = False

    class OmniRealtimeCallback:  # type: ignore[no-redef]
        """Fallback callback base when DashScope SDK is unavailable."""
        pass
class _DashScopeASRCallback(OmniRealtimeCallback):
    """Adapter that forwards DashScope SDK thread callbacks onto the asyncio loop.

    The SDK fires these hooks from its own worker thread; each hook is
    marshalled to the owning service's event loop via ``call_soon_threadsafe``
    so the service only ever touches its state on the loop thread.
    """

    def __init__(self, owner: "DashScopeRealtimeASRService", loop: asyncio.AbstractEventLoop):
        super().__init__()
        self._service = owner
        self._event_loop = loop

    def _post(self, fn: Callable[[], None]) -> None:
        """Hand *fn* to the event loop; silently drop it if the loop is closed."""
        try:
            self._event_loop.call_soon_threadsafe(fn)
        except RuntimeError:
            # Loop stopped/closed during shutdown — nothing left to notify.
            pass

    def on_open(self) -> None:
        self._post(self._service._on_ws_open)

    def on_close(self, code: int, msg: str) -> None:
        self._post(lambda: self._service._on_ws_close(code, msg))

    def on_event(self, message: Any) -> None:
        self._post(lambda: self._service._on_ws_event(message))

    def on_error(self, message: Any) -> None:
        self._post(lambda: self._service._on_ws_error(message))
class DashScopeRealtimeASRService(BaseASRService):
    """Realtime streaming ASR implementation for DashScope Qwen-ASR-Realtime.

    Audio is appended chunk-by-chunk over a websocket session; transcript
    events arrive asynchronously via SDK callbacks (bridged to the asyncio
    loop by ``_DashScopeASRCallback``). Interim results are published to
    ``_transcript_queue`` and the optional ``on_transcript`` callback; final
    results are additionally pushed to ``_final_queue``, which
    ``wait_for_final_transcription`` consumes.
    """

    # Fallback endpoint/model when neither ctor args nor env vars provide them.
    DEFAULT_WS_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
    DEFAULT_MODEL = "qwen3-asr-flash-realtime"
    # Default wait for the server's final transcript after committing audio.
    DEFAULT_FINAL_TIMEOUT_MS = 800

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        model: Optional[str] = None,
        sample_rate: int = 16000,
        language: str = "auto",
        on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None,
    ) -> None:
        """Initialize the service.

        Args:
            api_key: DashScope key; falls back to DASHSCOPE_API_KEY, then
                ASR_API_KEY environment variables.
            api_url: Realtime websocket URL; falls back to
                DASHSCOPE_ASR_API_URL, then DEFAULT_WS_URL.
            model: ASR model name; falls back to DASHSCOPE_ASR_MODEL, then
                DEFAULT_MODEL.
            sample_rate: PCM input sample rate in Hz.
            language: Language hint; "auto" is mapped to "zh" when the
                session is configured (see _configure_session).
            on_transcript: Optional awaitable called as (text, is_final)
                for every published transcript.
        """
        super().__init__(sample_rate=sample_rate, language=language)
        self.mode = "streaming"  # chunk-by-chunk streaming (vs. offline/buffered)
        self.api_key = (
            api_key
            or os.getenv("DASHSCOPE_API_KEY")
            or os.getenv("ASR_API_KEY")
        )
        self.api_url = api_url or os.getenv("DASHSCOPE_ASR_API_URL") or self.DEFAULT_WS_URL
        self.model = model or os.getenv("DASHSCOPE_ASR_MODEL") or self.DEFAULT_MODEL
        self.on_transcript = on_transcript
        self._client: Optional[Any] = None  # OmniRealtimeConversation instance
        self._loop: Optional[asyncio.AbstractEventLoop] = None  # captured in connect()
        self._callback: Optional[_DashScopeASRCallback] = None
        self._running = False
        self._session_ready = asyncio.Event()  # set on session.created/updated
        self._transcript_queue: "asyncio.Queue[ASRResult]" = asyncio.Queue()
        self._final_queue: "asyncio.Queue[str]" = asyncio.Queue()
        # Per-utterance bookkeeping.
        self._utterance_active = False
        self._audio_sent_in_utterance = False
        self._last_interim_text = ""
        self._last_error: Optional[str] = None

    async def connect(self) -> None:
        """Open the realtime websocket session and configure transcription.

        Raises:
            RuntimeError: if the dashscope SDK could not be imported.
            ValueError: if no API key could be resolved.
        """
        if not DASHSCOPE_SDK_AVAILABLE:
            py_exec = sys.executable
            hint = f"`{py_exec} -m pip install dashscope>=1.25.6`"
            detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
            raise RuntimeError(
                f"dashscope SDK unavailable in interpreter {py_exec}; install with {hint}{detail}"
            )
        if not self.api_key:
            raise ValueError("DashScope ASR API key not provided. Configure agent.asr.api_key in YAML.")
        # Capture the running loop so SDK-thread callbacks can be marshalled back.
        self._loop = asyncio.get_running_loop()
        self._callback = _DashScopeASRCallback(owner=self, loop=self._loop)
        if dashscope is not None:
            dashscope.api_key = self.api_key
        self._client = OmniRealtimeConversation(  # type: ignore[misc]
            model=self.model,
            url=self.api_url,
            callback=self._callback,
        )
        # SDK connect() blocks; run it off the event loop.
        await asyncio.to_thread(self._client.connect)
        await self._configure_session()
        self._running = True
        self.state = ServiceState.CONNECTED
        logger.info(
            "DashScope realtime ASR connected: model={}, sample_rate={}, language={}",
            self.model,
            self.sample_rate,
            self.language,
        )

    async def disconnect(self) -> None:
        """Tear down the websocket session and reset per-connection state."""
        self._running = False
        self._utterance_active = False
        self._audio_sent_in_utterance = False
        self._drain_queue(self._final_queue)
        self._drain_queue(self._transcript_queue)
        self._session_ready.clear()
        if self._client is not None:
            close_fn = getattr(self._client, "close", None)
            if callable(close_fn):
                # close() is blocking; keep the loop responsive.
                await asyncio.to_thread(close_fn)
            self._client = None
        self.state = ServiceState.DISCONNECTED
        logger.info("DashScope realtime ASR disconnected")

    async def begin_utterance(self) -> None:
        """Reset utterance-local state and mark a new utterance as active."""
        self.clear_utterance()
        self._utterance_active = True

    async def send_audio(self, audio: bytes) -> None:
        """Append one PCM chunk (base64-encoded) to the active utterance.

        Raises:
            RuntimeError: if not connected, or the SDK lacks ``append_audio``.
        """
        if not self._client:
            raise RuntimeError("DashScope ASR service not connected")
        if not audio:
            return
        if not self._utterance_active:
            # Allow graceful fallback if caller sends before begin_utterance.
            self._utterance_active = True
        audio_b64 = base64.b64encode(audio).decode("ascii")
        append_fn = getattr(self._client, "append_audio", None)
        if not callable(append_fn):
            raise RuntimeError("DashScope ASR SDK missing append_audio method")
        await asyncio.to_thread(append_fn, audio_b64)
        self._audio_sent_in_utterance = True

    async def end_utterance(self) -> None:
        """Commit the utterance so the server can emit its final transcript.

        No-op when disconnected, or when no audio was sent for the utterance.
        """
        if not self._client:
            return
        if not self._utterance_active or not self._audio_sent_in_utterance:
            return
        commit_fn = getattr(self._client, "commit", None)
        if not callable(commit_fn):
            raise RuntimeError("DashScope ASR SDK missing commit method")
        await asyncio.to_thread(commit_fn)
        self._utterance_active = False

    async def wait_for_final_transcription(self, timeout_ms: int = DEFAULT_FINAL_TIMEOUT_MS) -> str:
        """Wait up to *timeout_ms* for the utterance's final transcript.

        Falls back to the last interim text if the final event does not arrive
        in time; returns "" when no audio was sent this utterance.
        """
        if not self._audio_sent_in_utterance:
            return ""
        timeout_sec = max(0.05, float(timeout_ms) / 1000.0)
        try:
            text = await asyncio.wait_for(self._final_queue.get(), timeout=timeout_sec)
            return str(text or "").strip()
        except asyncio.TimeoutError:
            logger.debug("DashScope ASR final timeout ({}ms), fallback to last interim", timeout_ms)
            return str(self._last_interim_text or "").strip()

    def clear_utterance(self) -> None:
        """Reset utterance-local state (flags, interim cache, pending finals)."""
        self._utterance_active = False
        self._audio_sent_in_utterance = False
        self._last_interim_text = ""
        self._last_error = None
        # Drop stale finals so a new utterance cannot consume an old result.
        self._drain_queue(self._final_queue)

    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """Yield interim/final ASRResults while the service is running.

        Polls with a short timeout so the loop notices ``_running`` flipping
        to False during shutdown.
        """
        while self._running:
            try:
                result = await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
                yield result
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break

    async def _configure_session(self) -> None:
        """Send session.update, trying progressively simpler parameter sets.

        Different dashscope SDK builds accept different keyword sets, so the
        richest configuration is attempted first, falling back on failure.
        """
        if not self._client:
            raise RuntimeError("DashScope ASR client is not initialized")
        text_modality: Any = "text"
        if MultiModality is not None and hasattr(MultiModality, "TEXT"):
            text_modality = MultiModality.TEXT
        transcription_params: Optional[Any] = None
        if TranscriptionParams is not None:
            try:
                # DashScope expects a concrete language; map "auto" to "zh".
                lang = "zh" if self.language == "auto" else self.language
                transcription_params = TranscriptionParams(
                    language=lang,
                    sample_rate=self.sample_rate,
                    input_audio_format="pcm",
                )
            except Exception as exc:
                logger.debug("DashScope ASR TranscriptionParams init failed: {}", exc)
                transcription_params = None
        # Ordered richest-first; each entry is one session.update attempt.
        update_attempts = [
            {
                "output_modalities": [text_modality],
                "enable_turn_detection": False,
                "enable_input_audio_transcription": True,
                "transcription_params": transcription_params,
            },
            {
                "output_modalities": [text_modality],
                "enable_turn_detection": False,
                "enable_input_audio_transcription": True,
            },
            {
                "output_modalities": [text_modality],
            },
        ]
        update_fn = getattr(self._client, "update_session", None)
        if not callable(update_fn):
            raise RuntimeError("DashScope ASR SDK missing update_session method")
        last_error: Optional[Exception] = None
        for params in update_attempts:
            if params.get("transcription_params") is None:
                # Never pass an explicit None; drop the key instead.
                params = {k: v for k, v in params.items() if k != "transcription_params"}
            try:
                await asyncio.to_thread(update_fn, **params)
                break
            except TypeError as exc:
                # Signature mismatch for this SDK build; try a simpler set.
                last_error = exc
                continue
            except Exception as exc:
                last_error = exc
                continue
        else:
            # Every attempt failed.
            raise RuntimeError(f"DashScope ASR session.update failed: {last_error}")
        try:
            # Best-effort wait for the session.created/updated acknowledgement.
            await asyncio.wait_for(self._session_ready.wait(), timeout=6.0)
        except asyncio.TimeoutError:
            logger.debug("DashScope ASR session ready wait timeout; continuing")

    def _on_ws_open(self) -> None:
        """Websocket opened (scheduled on the event loop); nothing to do."""
        return None

    def _on_ws_close(self, code: int, msg: str) -> None:
        """Record the close reason for later diagnostics."""
        self._last_error = f"DashScope ASR websocket closed: {code} {msg}"
        logger.debug(self._last_error)

    def _on_ws_error(self, message: Any) -> None:
        """Record and log a transport-level error."""
        self._last_error = str(message)
        logger.error("DashScope ASR error: {}", self._last_error)

    def _on_ws_event(self, message: Any) -> None:
        """Dispatch one server event (runs on the event loop thread).

        Handles session readiness, server-side errors, and interim/final
        transcription events; any other event type is ignored.
        """
        payload = self._coerce_event(message)
        event_type = str(payload.get("type") or "").strip()
        if not event_type:
            return
        if event_type in {"session.created", "session.updated"}:
            self._session_ready.set()
            return
        if event_type == "error" or event_type.endswith(".failed"):
            err_text = self._extract_text(payload, keys=("message", "error", "details"))
            self._last_error = err_text or event_type
            logger.error("DashScope ASR server event error: {}", self._last_error)
            return
        if event_type == "conversation.item.input_audio_transcription.text":
            # Interim update; the running text is carried in "stash".
            stash_text = self._extract_text(payload, keys=("stash", "text", "transcript"))
            self._emit_transcript(stash_text, is_final=False)
            return
        if event_type == "conversation.item.input_audio_transcription.completed":
            final_text = self._extract_text(payload, keys=("transcript", "text", "stash"))
            self._emit_transcript(final_text, is_final=True)
            return

    def _emit_transcript(self, text: str, *, is_final: bool) -> None:
        """Normalize and publish a transcript; de-dupes repeated interims."""
        normalized = str(text or "").strip()
        if not normalized:
            return
        if not is_final and normalized == self._last_interim_text:
            return  # unchanged interim; skip duplicate event
        if not is_final:
            self._last_interim_text = normalized
        if self._loop is None:
            return
        try:
            # Thread-safe scheduling; also valid when already on the loop thread.
            asyncio.run_coroutine_threadsafe(
                self._publish_transcript(normalized, is_final=is_final),
                self._loop,
            )
        except RuntimeError:
            # Loop closed during shutdown; drop the transcript.
            return

    async def _publish_transcript(self, text: str, *, is_final: bool) -> None:
        """Fan a transcript out to the queues and the optional callback."""
        await self._transcript_queue.put(ASRResult(text=text, is_final=is_final))
        if is_final:
            await self._final_queue.put(text)
        if self.on_transcript:
            try:
                await self.on_transcript(text, is_final)
            except Exception as exc:
                # Callback failures must not break the transcript pipeline.
                logger.warning("DashScope ASR transcript callback failed: {}", exc)

    @staticmethod
    def _coerce_event(message: Any) -> Dict[str, Any]:
        """Normalize an SDK event (dict, JSON string, or other) into a dict."""
        if isinstance(message, dict):
            return message
        if isinstance(message, str):
            try:
                parsed = json.loads(message)
                if isinstance(parsed, dict):
                    return parsed
            except json.JSONDecodeError:
                return {"type": "raw", "text": message}
        return {"type": "raw", "text": str(message)}

    def _extract_text(self, payload: Dict[str, Any], *, keys: tuple[str, ...]) -> str:
        """Depth-first search for the first non-empty string under any of *keys*.

        Checks the requested keys at the current level first, then recurses
        into any nested dict values.
        """
        for key in keys:
            value = payload.get(key)
            if isinstance(value, str) and value.strip():
                return value.strip()
            if isinstance(value, dict):
                nested = self._extract_text(value, keys=keys)
                if nested:
                    return nested
        for value in payload.values():
            if isinstance(value, dict):
                nested = self._extract_text(value, keys=keys)
                if nested:
                    return nested
        return ""

    @staticmethod
    def _drain_queue(queue: "asyncio.Queue[Any]") -> None:
        """Remove all pending items from *queue* without blocking."""
        while True:
            try:
                queue.get_nowait()
            except asyncio.QueueEmpty:
                break

View File

@@ -71,6 +71,7 @@ class OpenAICompatibleASRService(BaseASRService):
on_transcript: Callback for transcription results (text, is_final)
"""
super().__init__(sample_rate=sample_rate, language=language)
self.mode = "offline"
if not AIOHTTP_AVAILABLE:
raise RuntimeError("aiohttp is required for OpenAICompatibleASRService")

View File

@@ -16,6 +16,7 @@ from runtime.ports import (
TTSServiceSpec,
)
from providers.asr.buffered import BufferedASRService
from providers.asr.dashscope import DashScopeRealtimeASRService
from providers.tts.dashscope import DashScopeTTSService
from providers.llm.openai import MockLLMService, OpenAILLMService
from providers.asr.openai_compatible import OpenAICompatibleASRService
@@ -23,6 +24,7 @@ from providers.tts.openai_compatible import OpenAICompatibleTTSService
from providers.tts.mock import MockTTSService
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
_DASHSCOPE_PROVIDERS = {"dashscope"}
_SUPPORTED_LLM_PROVIDERS = {"openai", *_OPENAI_COMPATIBLE_PROVIDERS}
@@ -31,6 +33,8 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
_DEFAULT_DASHSCOPE_TTS_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
_DEFAULT_DASHSCOPE_TTS_MODEL = "qwen3-tts-flash-realtime"
_DEFAULT_DASHSCOPE_ASR_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
_DEFAULT_DASHSCOPE_ASR_MODEL = "qwen3-asr-flash-realtime"
_DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
_DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
@@ -96,6 +100,16 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
def create_asr_service(self, spec: ASRServiceSpec) -> ASRPort:
provider = self._normalize_provider(spec.provider)
if provider in _DASHSCOPE_PROVIDERS and spec.api_key:
return DashScopeRealtimeASRService(
api_key=spec.api_key,
api_url=spec.api_url or self._DEFAULT_DASHSCOPE_ASR_REALTIME_URL,
model=spec.model or self._DEFAULT_DASHSCOPE_ASR_MODEL,
sample_rate=spec.sample_rate,
language=spec.language,
on_transcript=spec.on_transcript,
)
if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key:
return OpenAICompatibleASRService(
api_key=spec.api_key,

View File

@@ -30,11 +30,14 @@ from providers.factory.default import DefaultRealtimeServiceFactory
from runtime.conversation import ConversationManager, ConversationState
from runtime.events import get_event_bus
from runtime.ports import (
ASRMode,
ASRPort,
ASRServiceSpec,
LLMPort,
LLMServiceSpec,
OfflineASRPort,
RealtimeServiceFactory,
StreamingASRPort,
TTSPort,
TTSServiceSpec,
)
@@ -77,6 +80,7 @@ class DuplexPipeline:
_ASR_DELTA_THROTTLE_MS = 500
_LLM_DELTA_THROTTLE_MS = 80
_ASR_CAPTURE_MAX_MS = 15000
_ASR_STREAM_FINAL_TIMEOUT_MS = 800
_OPENER_PRE_ROLL_MS = 180
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
"current_time": {
@@ -317,6 +321,10 @@ class DuplexPipeline:
self.llm_service = llm_service
self.tts_service = tts_service
self.asr_service = asr_service # Will be initialized in start()
self._asr_mode: ASRMode = self._resolve_asr_mode(
settings.asr_provider,
getattr(asr_service, "mode", None),
)
self._service_factory = service_factory or DefaultRealtimeServiceFactory()
self._knowledge_searcher = knowledge_searcher
self._tool_resource_resolver = tool_resource_resolver
@@ -324,6 +332,7 @@ class DuplexPipeline:
# Track last sent transcript to avoid duplicates
self._last_sent_transcript = ""
self._latest_asr_interim_text = ""
self._pending_transcript_delta: str = ""
self._last_transcript_delta_emit_ms: float = 0.0
@@ -588,6 +597,7 @@ class DuplexPipeline:
},
"asr": {
"provider": asr_provider,
"mode": self._resolve_asr_mode(asr_provider, self._runtime_asr.get("mode")),
"model": str(self._runtime_asr.get("model") or settings.asr_model or ""),
"interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
"minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
@@ -787,6 +797,22 @@ class DuplexPipeline:
normalized = str(provider or "").strip().lower()
return normalized == "dashscope"
@staticmethod
def _resolve_asr_mode(provider: Any, raw_mode: Any = None) -> ASRMode:
    """Pick the ASR interaction mode for a provider/mode pair.

    An explicit, recognized ``raw_mode`` ("offline"/"streaming", any case or
    padding) wins; otherwise the DashScope provider implies streaming and
    every other provider falls back to offline.
    """
    mode = str(raw_mode or "").strip().lower()
    if mode == "offline" or mode == "streaming":
        return mode  # type: ignore[return-value]
    if str(provider or "").strip().lower() == "dashscope":
        return "streaming"
    return "offline"
def _offline_asr(self) -> OfflineASRPort:
    """Narrow the ASR service to the offline port (valid only when _asr_mode == "offline")."""
    return self.asr_service  # type: ignore[return-value]
def _streaming_asr(self) -> StreamingASRPort:
    """Narrow the ASR service to the streaming port (valid only when _asr_mode == "streaming")."""
    return self.asr_service  # type: ignore[return-value]
@staticmethod
def _default_llm_base_url(provider: Any) -> Optional[str]:
normalized = str(provider or "").strip().lower()
@@ -967,11 +993,13 @@ class DuplexPipeline:
asr_model = self._runtime_asr.get("model") or settings.asr_model
asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
asr_mode = self._resolve_asr_mode(asr_provider, self._runtime_asr.get("mode"))
self.asr_service = self._service_factory.create_asr_service(
ASRServiceSpec(
provider=asr_provider,
sample_rate=settings.sample_rate,
mode=asr_mode,
language="auto",
api_key=str(asr_api_key).strip() if asr_api_key else None,
api_url=str(asr_api_url).strip() if asr_api_url else None,
@@ -981,10 +1009,14 @@ class DuplexPipeline:
on_transcript=self._on_transcript_callback,
)
)
self._asr_mode = self._resolve_asr_mode(
self._runtime_asr.get("provider") or settings.asr_provider,
getattr(self.asr_service, "mode", self._runtime_asr.get("mode")),
)
await self.asr_service.connect()
logger.info("DuplexPipeline services connected")
logger.info("DuplexPipeline services connected (asr_mode={})", self._asr_mode)
if not self._outbound_task or self._outbound_task.done():
self._outbound_task = asyncio.create_task(self._outbound_loop())
@@ -1457,6 +1489,7 @@ class DuplexPipeline:
self._last_sent_transcript = text
if is_final:
self._latest_asr_interim_text = ""
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = 0.0
await self._send_event(
@@ -1472,6 +1505,7 @@ class DuplexPipeline:
logger.debug(f"Sent transcript (final): {text[:50]}...")
return
self._latest_asr_interim_text = text
self._pending_transcript_delta = text
should_emit = (
self._last_transcript_delta_emit_ms <= 0.0
@@ -1495,14 +1529,16 @@ class DuplexPipeline:
await self.conversation.start_user_turn()
self._audio_buffer = b""
self._last_sent_transcript = ""
self._latest_asr_interim_text = ""
self.eou_detector.reset()
self._asr_capture_active = False
self._asr_capture_started_ms = 0.0
self._pending_speech_audio = b""
# Clear ASR buffer. Interim starts only after ASR capture is activated.
if hasattr(self.asr_service, 'clear_buffer'):
self.asr_service.clear_buffer()
if self._asr_mode == "streaming":
self._streaming_asr().clear_utterance()
else:
self._offline_asr().clear_buffer()
logger.debug("User speech started")
@@ -1511,8 +1547,10 @@ class DuplexPipeline:
if self._asr_capture_active:
return
if hasattr(self.asr_service, 'start_interim_transcription'):
await self.asr_service.start_interim_transcription()
if self._asr_mode == "streaming":
await self._streaming_asr().begin_utterance()
else:
await self._offline_asr().start_interim_transcription()
# Prime ASR with a short pre-speech context window so the utterance
# start isn't lost while waiting for VAD to transition to Speech.
@@ -1545,24 +1583,22 @@ class DuplexPipeline:
self._pending_speech_audio = b""
return
# Add a tiny trailing silence tail to stabilize final-token decoding.
if self._asr_final_tail_bytes > 0:
final_tail = b"\x00" * self._asr_final_tail_bytes
await self.asr_service.send_audio(final_tail)
# Stop interim transcriptions
if hasattr(self.asr_service, 'stop_interim_transcription'):
await self.asr_service.stop_interim_transcription()
# Get final transcription from ASR service
user_text = ""
if hasattr(self.asr_service, 'get_final_transcription'):
# SiliconFlow ASR - get final transcription
user_text = await self.asr_service.get_final_transcription()
elif hasattr(self.asr_service, 'get_and_clear_text'):
# Buffered ASR - get accumulated text
user_text = self.asr_service.get_and_clear_text()
if self._asr_mode == "streaming":
streaming_asr = self._streaming_asr()
await streaming_asr.end_utterance()
user_text = await streaming_asr.wait_for_final_transcription(
timeout_ms=self._ASR_STREAM_FINAL_TIMEOUT_MS
)
if not user_text.strip():
user_text = self._latest_asr_interim_text
else:
# Add a tiny trailing silence tail to stabilize final-token decoding.
if self._asr_final_tail_bytes > 0:
final_tail = b"\x00" * self._asr_final_tail_bytes
await self.asr_service.send_audio(final_tail)
await self._offline_asr().stop_interim_transcription()
user_text = await self._offline_asr().get_final_transcription()
# Skip if no meaningful text
if not user_text or not user_text.strip():
@@ -1570,6 +1606,7 @@ class DuplexPipeline:
# Reset for next utterance
self._audio_buffer = b""
self._last_sent_transcript = ""
self._latest_asr_interim_text = ""
self._asr_capture_active = False
self._asr_capture_started_ms = 0.0
self._pending_speech_audio = b""
@@ -1594,6 +1631,7 @@ class DuplexPipeline:
# Clear buffers
self._audio_buffer = b""
self._last_sent_transcript = ""
self._latest_asr_interim_text = ""
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = 0.0
self._asr_capture_active = False

View File

@@ -1,6 +1,12 @@
"""Port interfaces for runtime integration boundaries."""
from runtime.ports.asr import ASRBufferControl, ASRInterimControl, ASRPort, ASRServiceSpec
from runtime.ports.asr import (
ASRMode,
ASRPort,
ASRServiceSpec,
OfflineASRPort,
StreamingASRPort,
)
from runtime.ports.control_plane import (
AssistantRuntimeConfigProvider,
ControlPlaneGateway,
@@ -13,10 +19,11 @@ from runtime.ports.service_factory import RealtimeServiceFactory
from runtime.ports.tts import TTSPort, TTSServiceSpec
__all__ = [
"ASRMode",
"ASRPort",
"ASRServiceSpec",
"ASRInterimControl",
"ASRBufferControl",
"OfflineASRPort",
"StreamingASRPort",
"AssistantRuntimeConfigProvider",
"ControlPlaneGateway",
"ConversationHistoryStore",

View File

@@ -3,11 +3,12 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import AsyncIterator, Awaitable, Callable, Optional, Protocol
from typing import AsyncIterator, Awaitable, Callable, Literal, Optional, Protocol
from providers.common.base import ASRResult
TranscriptCallback = Callable[[str, bool], Awaitable[None]]
ASRMode = Literal["offline", "streaming"]
@dataclass(frozen=True)
@@ -16,6 +17,7 @@ class ASRServiceSpec:
provider: str
sample_rate: int
mode: Optional[ASRMode] = None
language: str = "auto"
api_key: Optional[str] = None
api_url: Optional[str] = None
@@ -28,6 +30,8 @@ class ASRServiceSpec:
class ASRPort(Protocol):
"""Port for speech recognition providers."""
mode: ASRMode
async def connect(self) -> None:
"""Establish connection to ASR provider."""
@@ -41,18 +45,16 @@ class ASRPort(Protocol):
"""Stream partial/final recognition results."""
class ASRInterimControl(Protocol):
"""Optional extension for explicit interim transcription control."""
class OfflineASRPort(ASRPort, Protocol):
"""Port for offline/buffered ASR providers."""
mode: Literal["offline"]
async def start_interim_transcription(self) -> None:
"""Start interim transcription loop if supported."""
"""Start interim transcription loop."""
async def stop_interim_transcription(self) -> None:
"""Stop interim transcription loop if supported."""
class ASRBufferControl(Protocol):
"""Optional extension for explicit ASR buffer lifecycle control."""
"""Stop interim transcription loop."""
def clear_buffer(self) -> None:
"""Clear provider-side ASR buffer."""
@@ -62,3 +64,21 @@ class ASRBufferControl(Protocol):
def get_and_clear_text(self) -> str:
"""Return buffered text and clear internal state."""
class StreamingASRPort(ASRPort, Protocol):
"""Port for streaming ASR providers."""
mode: Literal["streaming"]
async def begin_utterance(self) -> None:
"""Start a new utterance stream."""
async def end_utterance(self) -> None:
"""Signal end of current utterance stream."""
async def wait_for_final_transcription(self, timeout_ms: int = 800) -> str:
"""Wait for final transcript after utterance end."""
def clear_utterance(self) -> None:
"""Reset utterance-local state."""

View File

@@ -0,0 +1,46 @@
from providers.asr.buffered import BufferedASRService
from providers.asr.dashscope import DashScopeRealtimeASRService
from providers.asr.openai_compatible import OpenAICompatibleASRService
from providers.factory.default import DefaultRealtimeServiceFactory
from runtime.ports import ASRServiceSpec
def test_create_asr_service_dashscope_returns_streaming_provider():
    """provider=dashscope must yield the DashScope realtime (streaming) service."""
    spec = ASRServiceSpec(
        provider="dashscope",
        mode="streaming",
        sample_rate=16000,
        api_key="test-key",
        model="qwen3-asr-flash-realtime",
    )

    service = DefaultRealtimeServiceFactory().create_asr_service(spec)

    assert isinstance(service, DashScopeRealtimeASRService)
    assert service.mode == "streaming"
def test_create_asr_service_openai_compatible_returns_offline_provider():
    """provider=openai_compatible must yield the offline HTTP-based service."""
    spec = ASRServiceSpec(
        provider="openai_compatible",
        sample_rate=16000,
        api_key="test-key",
        model="FunAudioLLM/SenseVoiceSmall",
    )

    service = DefaultRealtimeServiceFactory().create_asr_service(spec)

    assert isinstance(service, OpenAICompatibleASRService)
    assert service.mode == "offline"
def test_create_asr_service_fallback_buffered_for_unsupported_provider():
    """An unrecognized provider name must fall back to the local buffered service."""
    spec = ASRServiceSpec(
        provider="unknown_provider",
        sample_rate=16000,
    )

    service = DefaultRealtimeServiceFactory().create_asr_service(spec)

    assert isinstance(service, BufferedASRService)
    assert service.mode == "offline"

View File

@@ -0,0 +1,67 @@
import asyncio
import pytest
from providers.asr.dashscope import DashScopeRealtimeASRService
@pytest.mark.asyncio
async def test_dashscope_asr_interim_event_emits_interim_transcript():
    """A streaming `.text` websocket event should surface as a non-final transcript."""
    seen = []

    async def _collect(text: str, is_final: bool) -> None:
        seen.append((text, is_final))

    service = DashScopeRealtimeASRService(api_key="test-key", on_transcript=_collect)
    service._loop = asyncio.get_running_loop()
    service._running = True

    # Simulate the provider pushing an interim ("stash") text event.
    event = {
        "type": "conversation.item.input_audio_transcription.text",
        "stash": "你好世界",
    }
    service._on_ws_event(event)
    await asyncio.sleep(0.05)

    item = service._transcript_queue.get_nowait()
    assert item.text == "你好世界"
    assert item.is_final is False
    assert seen == [("你好世界", False)]
@pytest.mark.asyncio
async def test_dashscope_asr_final_event_emits_final_transcript_and_final_queue():
    """A `.completed` event should emit a final transcript and fill the final queue."""
    seen = []

    async def _collect(text: str, is_final: bool) -> None:
        seen.append((text, is_final))

    service = DashScopeRealtimeASRService(api_key="test-key", on_transcript=_collect)
    service._loop = asyncio.get_running_loop()
    service._running = True
    service._audio_sent_in_utterance = True

    # Simulate the provider completing transcription for the utterance.
    event = {
        "type": "conversation.item.input_audio_transcription.completed",
        "transcript": "最终识别结果",
    }
    service._on_ws_event(event)
    await asyncio.sleep(0.05)

    item = service._transcript_queue.get_nowait()
    assert item.text == "最终识别结果"
    assert item.is_final is True
    assert service._final_queue.get_nowait() == "最终识别结果"
    assert seen == [("最终识别结果", True)]
@pytest.mark.asyncio
async def test_dashscope_wait_for_final_falls_back_to_latest_interim_on_timeout():
    """If no final transcript arrives in time, the latest interim text is returned."""
    service = DashScopeRealtimeASRService(api_key="test-key")
    service._audio_sent_in_utterance = True
    service._last_interim_text = "部分结果"

    result = await service.wait_for_final_transcription(timeout_ms=10)

    assert result == "部分结果"

View File

@@ -0,0 +1,196 @@
import asyncio
from typing import Any, Dict, List
import pytest
from runtime.pipeline.duplex import DuplexPipeline
class _DummySileroVAD:
def __init__(self, *args, **kwargs):
pass
def process_audio(self, _pcm: bytes) -> float:
return 0.0
class _DummyVADProcessor:
def __init__(self, *args, **kwargs):
pass
def process(self, _speech_prob: float):
return "Silence", 0.0
class _DummyEouDetector:
def __init__(self, *args, **kwargs):
self.is_speaking = True
def process(self, _vad_status: str, force_eligible: bool = False) -> bool:
_ = force_eligible
return False
def reset(self) -> None:
self.is_speaking = False
class _FakeTransport:
async def send_event(self, _event: Dict[str, Any]) -> None:
return None
async def send_audio(self, _audio: bytes) -> None:
return None
class _FakeStreamingASR:
mode = "streaming"
def __init__(self):
self.begin_calls = 0
self.end_calls = 0
self.wait_calls = 0
self.sent_audio: List[bytes] = []
self.wait_text = ""
async def connect(self) -> None:
return None
async def disconnect(self) -> None:
return None
async def send_audio(self, audio: bytes) -> None:
self.sent_audio.append(audio)
async def receive_transcripts(self):
if False:
yield None
async def begin_utterance(self) -> None:
self.begin_calls += 1
async def end_utterance(self) -> None:
self.end_calls += 1
async def wait_for_final_transcription(self, timeout_ms: int = 800) -> str:
_ = timeout_ms
self.wait_calls += 1
return self.wait_text
def clear_utterance(self) -> None:
return None
class _FakeOfflineASR:
mode = "offline"
def __init__(self):
self.start_interim_calls = 0
self.stop_interim_calls = 0
self.sent_audio: List[bytes] = []
self.final_text = "offline final"
async def connect(self) -> None:
return None
async def disconnect(self) -> None:
return None
async def send_audio(self, audio: bytes) -> None:
self.sent_audio.append(audio)
async def receive_transcripts(self):
if False:
yield None
async def start_interim_transcription(self) -> None:
self.start_interim_calls += 1
async def stop_interim_transcription(self) -> None:
self.stop_interim_calls += 1
async def get_final_transcription(self) -> str:
return self.final_text
def clear_buffer(self) -> None:
return None
def get_and_clear_text(self) -> str:
return self.final_text
def _build_pipeline(monkeypatch, asr_service):
    """Build a DuplexPipeline with VAD/EOU components replaced by inert doubles."""
    doubles = (
        ("runtime.pipeline.duplex.SileroVAD", _DummySileroVAD),
        ("runtime.pipeline.duplex.VADProcessor", _DummyVADProcessor),
        ("runtime.pipeline.duplex.EouDetector", _DummyEouDetector),
    )
    for target, replacement in doubles:
        monkeypatch.setattr(target, replacement)
    return DuplexPipeline(
        transport=_FakeTransport(),
        session_id="asr_mode_test",
        asr_service=asr_service,
    )
@pytest.mark.asyncio
async def test_start_asr_capture_uses_streaming_begin(monkeypatch):
    """Streaming mode: capture start opens an utterance and flushes buffered audio."""
    streaming_asr = _FakeStreamingASR()
    pipeline = _build_pipeline(monkeypatch, streaming_asr)
    pipeline._asr_mode = "streaming"
    pipeline._pending_speech_audio = b"\x00" * 320
    pipeline._pre_speech_buffer = b"\x00" * 640

    await pipeline._start_asr_capture()

    assert streaming_asr.begin_calls == 1
    assert streaming_asr.sent_audio
    assert pipeline._asr_capture_active is True
@pytest.mark.asyncio
async def test_start_asr_capture_uses_offline_interim_control(monkeypatch):
    """Offline mode: capture start kicks off the interim transcription loop."""
    offline_asr = _FakeOfflineASR()
    pipeline = _build_pipeline(monkeypatch, offline_asr)
    pipeline._asr_mode = "offline"
    pipeline._pending_speech_audio = b"\x00" * 320
    pipeline._pre_speech_buffer = b"\x00" * 640

    await pipeline._start_asr_capture()

    assert offline_asr.start_interim_calls == 1
    assert offline_asr.sent_audio
    assert pipeline._asr_capture_active is True
@pytest.mark.asyncio
async def test_streaming_eou_falls_back_to_latest_interim(monkeypatch):
    """When the provider yields no final text, the latest interim becomes the turn."""
    streaming_asr = _FakeStreamingASR()
    streaming_asr.wait_text = ""  # force the empty-final fallback path
    pipeline = _build_pipeline(monkeypatch, streaming_asr)
    pipeline._asr_mode = "streaming"
    pipeline._asr_capture_active = True
    pipeline._latest_asr_interim_text = "fallback interim text"
    await pipeline.conversation.start_user_turn()

    events: List[Dict[str, Any]] = []
    turns: List[str] = []

    async def _record_event(event: Dict[str, Any], priority: int = 20):
        del priority
        events.append(event)

    async def _noop_stop() -> None:
        pass

    async def _record_turn(user_text: str, *args, **kwargs) -> None:
        del args, kwargs
        turns.append(user_text)

    monkeypatch.setattr(pipeline, "_send_event", _record_event)
    monkeypatch.setattr(pipeline, "_stop_current_speech", _noop_stop)
    monkeypatch.setattr(pipeline, "_handle_turn", _record_turn)

    await pipeline._on_end_of_utterance()
    await asyncio.sleep(0.05)

    assert streaming_asr.end_calls == 1
    assert streaming_asr.wait_calls == 1
    assert turns == ["fallback interim text"]
    assert any(evt.get("type") == "transcript.final" for evt in events)

View File

@@ -52,9 +52,33 @@ class _FakeTTS:
class _FakeASR:
mode = "offline"
async def connect(self) -> None:
return None
async def disconnect(self) -> None:
return None
async def send_audio(self, _audio: bytes) -> None:
return None
async def receive_transcripts(self):
if False:
yield None
def clear_buffer(self) -> None:
return None
async def start_interim_transcription(self) -> None:
return None
async def stop_interim_transcription(self) -> None:
return None
async def get_final_transcription(self) -> str:
return ""
class _FakeLLM:
def __init__(self, rounds: List[List[LLMStreamEvent]]):