From e11c3abb9efcdbd3e9e5e09f4e55f11aa8565ff7 Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Fri, 6 Mar 2026 11:44:39 +0800 Subject: [PATCH] Implement DashScope ASR provider and enhance ASR service architecture - Added DashScope ASR service implementation for real-time streaming. - Updated ASR provider logic to support DashScope alongside existing providers. - Enhanced runtime metadata resolution to include DashScope as a valid ASR provider. - Modified configuration files and documentation to reflect the addition of DashScope. - Introduced tests to validate DashScope integration and ASR service behavior. - Refactored ASR service factory to accommodate new provider options and modes. --- api/app/routers/assistants.py | 11 +- api/tests/test_assistants.py | 31 ++ docs/content/customization/asr.md | 7 +- engine/app/config.py | 2 +- engine/config/agents/example.yaml | 6 +- engine/config/agents/tools.yaml | 6 +- engine/docs/extension_ports.md | 4 +- engine/providers/asr/__init__.py | 12 + engine/providers/asr/buffered.py | 34 ++ engine/providers/asr/dashscope.py | 388 ++++++++++++++++++++ engine/providers/asr/openai_compatible.py | 1 + engine/providers/factory/default.py | 14 + engine/runtime/pipeline/duplex.py | 84 +++-- engine/runtime/ports/__init__.py | 13 +- engine/runtime/ports/asr.py | 38 +- engine/tests/test_asr_factory_modes.py | 46 +++ engine/tests/test_dashscope_asr_provider.py | 67 ++++ engine/tests/test_duplex_asr_modes.py | 196 ++++++++++ engine/tests/test_tool_call_flow.py | 24 ++ 19 files changed, 940 insertions(+), 44 deletions(-) create mode 100644 engine/providers/asr/dashscope.py create mode 100644 engine/tests/test_asr_factory_modes.py create mode 100644 engine/tests/test_dashscope_asr_provider.py create mode 100644 engine/tests/test_duplex_asr_modes.py diff --git a/api/app/routers/assistants.py b/api/app/routers/assistants.py index bf43303..b63d0ee 100644 --- a/api/app/routers/assistants.py +++ b/api/app/routers/assistants.py @@ -320,12 +320,17 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[s if assistant.asr_model_id: asr = db.query(ASRModel).filter(ASRModel.id == assistant.asr_model_id).first() if asr: - asr_provider = "openai_compatible" if _is_openai_compatible_vendor(asr.vendor) else "buffered" + if _is_dashscope_vendor(asr.vendor): + asr_provider = "dashscope" + elif _is_openai_compatible_vendor(asr.vendor): + asr_provider = "openai_compatible" + else: + asr_provider = "buffered" metadata["services"]["asr"] = { "provider": asr_provider, "model": asr.model_name or asr.name, - "apiKey": asr.api_key if asr_provider == "openai_compatible" else None, - "baseUrl": asr.base_url if asr_provider == "openai_compatible" else None, + "apiKey": asr.api_key if asr_provider in {"openai_compatible", "dashscope"} else None, + "baseUrl": asr.base_url if asr_provider in {"openai_compatible", "dashscope"} else None, } else: warnings.append(f"ASR model not found: {assistant.asr_model_id}") diff --git a/api/tests/test_assistants.py b/api/tests/test_assistants.py index 0d880ef..3828688 100644 --- a/api/tests/test_assistants.py +++ b/api/tests/test_assistants.py @@ -343,6 +343,37 @@ class TestAssistantAPI: assert tts["apiKey"] == "dashscope-key" assert tts["baseUrl"] == "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + def test_runtime_config_dashscope_asr_provider(self, client, sample_assistant_data): + """DashScope ASR models should map to dashscope asr provider in runtime metadata.""" + asr_resp = client.post("/api/asr", json={ + "name": "DashScope Realtime ASR", + "vendor": "DashScope", + "language": "zh", + "base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime", + "api_key": "dashscope-asr-key", + "model_name": "qwen3-asr-flash-realtime", + "hotwords": [], + "enable_punctuation": True, + "enable_normalization": True, + "enabled": True, + }) + assert asr_resp.status_code == 200 + asr_payload = asr_resp.json() + + sample_assistant_data.update({ + "asrModelId": asr_payload["id"], + }) + assistant_resp = client.post("/api/assistants", json=sample_assistant_data) + assert assistant_resp.status_code == 200 + assistant_id = assistant_resp.json()["id"] + + runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config") + assert runtime_resp.status_code == 200 + metadata = runtime_resp.json()["sessionStartMetadata"] + asr = metadata["services"]["asr"] + assert asr["provider"] == "dashscope" + assert asr["baseUrl"] == "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + def test_assistant_interrupt_and_generated_opener_flags(self, client, sample_assistant_data): sample_assistant_data.update({ "firstTurnMode": "user_first", diff --git a/docs/content/customization/asr.md b/docs/content/customization/asr.md index 2251804..8d73889 100644 --- a/docs/content/customization/asr.md +++ b/docs/content/customization/asr.md @@ -2,6 +2,11 @@ 语音识别(ASR)负责将用户音频实时转写为文本,供对话引擎理解。 +## 模式 + +- `offline`:引擎本地缓冲音频后触发识别(适用于 OpenAI-compatible / SiliconFlow)。 +- `streaming`:音频分片实时发送到服务端,服务端持续返回转写事件(适用于 DashScope Realtime ASR)。 + ## 配置项 | 配置项 | 说明 | @@ -17,8 +22,8 @@ - 客服场景建议开启热词并维护业务词表 - 多语言场景建议按会话入口显式指定语言 - 对延迟敏感场景优先选择流式识别模型 +- 当前支持提供商:`openai_compatible`、`siliconflow`、`dashscope`、`buffered`(回退) ## 相关文档 - [语音配置总览](voices.md) - diff --git a/engine/app/config.py b/engine/app/config.py index 233ba75..62364d1 100644 --- a/engine/app/config.py +++ b/engine/app/config.py @@ -85,7 +85,7 @@ class Settings(BaseSettings): # ASR Configuration asr_provider: str = Field( default="openai_compatible", - description="ASR provider (openai_compatible, buffered, siliconflow)" + description="ASR provider (openai_compatible, buffered, siliconflow, dashscope)" ) asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL") asr_model: Optional[str] = Field(default=None, description="ASR model name") diff --git a/engine/config/agents/example.yaml b/engine/config/agents/example.yaml index 70f4933..2e9f157 100644 --- a/engine/config/agents/example.yaml +++ b/engine/config/agents/example.yaml @@ -35,7 +35,11 @@ agent: speed: 1.0 asr: - # provider: buffered | openai_compatible | siliconflow + # provider: buffered | openai_compatible | siliconflow | dashscope + # dashscope defaults (if omitted): + # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime + # model: qwen3-asr-flash-realtime + # note: dashscope uses streaming ASR mode (chunk-by-chunk). provider: openai_compatible api_key: you_asr_api_key api_url: https://api.siliconflow.cn/v1/audio/transcriptions diff --git a/engine/config/agents/tools.yaml b/engine/config/agents/tools.yaml index e2968bb..8657080 100644 --- a/engine/config/agents/tools.yaml +++ b/engine/config/agents/tools.yaml @@ -32,7 +32,11 @@ agent: speed: 1.0 asr: - # provider: buffered | openai_compatible | siliconflow + # provider: buffered | openai_compatible | siliconflow | dashscope + # dashscope defaults (if omitted): + # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime + # model: qwen3-asr-flash-realtime + # note: dashscope uses streaming ASR mode (chunk-by-chunk). provider: openai_compatible api_key: your_asr_api_key api_url: https://api.siliconflow.cn/v1/audio/transcriptions diff --git a/engine/docs/extension_ports.md b/engine/docs/extension_ports.md index 8566194..36e2aac 100644 --- a/engine/docs/extension_ports.md +++ b/engine/docs/extension_ports.md @@ -20,7 +20,7 @@ This document defines the draft port set used to keep core runtime extensible. - `runtime/ports/asr.py` - `ASRServiceSpec` - `ASRPort` - - optional extensions: `ASRInterimControl`, `ASRBufferControl` + - explicit mode ports: `OfflineASRPort`, `StreamingASRPort` - `runtime/ports/service_factory.py` - `RealtimeServiceFactory` @@ -39,7 +39,7 @@ This document defines the draft port set used to keep core runtime extensible. - supported providers: `dashscope`, `openai_compatible`, `openai-compatible`, `siliconflow` - fallback: `MockTTSService` - ASR: - - supported providers: `openai_compatible`, `openai-compatible`, `siliconflow` + - supported providers: `openai_compatible`, `openai-compatible`, `siliconflow`, `dashscope` - fallback: `BufferedASRService` ## Notes diff --git a/engine/providers/asr/__init__.py b/engine/providers/asr/__init__.py index 2efe6a9..5e659be 100644 --- a/engine/providers/asr/__init__.py +++ b/engine/providers/asr/__init__.py @@ -1 +1,13 @@ """ASR providers.""" + +from providers.asr.buffered import BufferedASRService, MockASRService +from providers.asr.dashscope import DashScopeRealtimeASRService +from providers.asr.openai_compatible import OpenAICompatibleASRService, SiliconFlowASRService + +__all__ = [ + "BufferedASRService", + "MockASRService", + "DashScopeRealtimeASRService", + "OpenAICompatibleASRService", + "SiliconFlowASRService", +] diff --git a/engine/providers/asr/buffered.py b/engine/providers/asr/buffered.py index ce1a248..624963c 100644 --- a/engine/providers/asr/buffered.py +++ b/engine/providers/asr/buffered.py @@ -34,6 +34,7 @@ class BufferedASRService(BaseASRService): language: str = "en" ): super().__init__(sample_rate=sample_rate, language=language) + self.mode = "offline" self._audio_buffer: bytes = b"" self._current_text: str = "" @@ -86,6 +87,23 @@ class BufferedASRService(BaseASRService): self._current_text = "" self._audio_buffer = b"" return text + + async def get_final_transcription(self) -> str: + """Offline compatibility method used by DuplexPipeline.""" + return self.get_and_clear_text() + + def clear_buffer(self) -> None: + """Offline compatibility method used by DuplexPipeline.""" + self._audio_buffer = b"" + self._current_text = "" + + async def start_interim_transcription(self) -> None: + """No-op for plain buffered ASR.""" + return None + + async def stop_interim_transcription(self) -> None: + """No-op for plain buffered ASR.""" + return None def get_audio_buffer(self) -> bytes: """Get accumulated audio buffer.""" @@ -103,6 +121,7 @@ class MockASRService(BaseASRService): def __init__(self, sample_rate: int = 16000, language: str = "en"): super().__init__(sample_rate=sample_rate, language=language) + self.mode = "offline" self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue() self._mock_texts = [ "Hello, how are you?", @@ -145,3 +164,18 @@ class MockASRService(BaseASRService): continue except asyncio.CancelledError: break + + def clear_buffer(self) -> None: + return None + + async def get_final_transcription(self) -> str: + return "" + + def get_and_clear_text(self) -> str: + return "" + + async def start_interim_transcription(self) -> None: + return None + + async def stop_interim_transcription(self) -> None: + return None diff --git a/engine/providers/asr/dashscope.py b/engine/providers/asr/dashscope.py new file mode 100644 index 0000000..bed4ede --- /dev/null +++ b/engine/providers/asr/dashscope.py @@ -0,0 +1,388 @@ +"""DashScope realtime streaming ASR service. + +Uses Qwen-ASR-Realtime via DashScope Python SDK. +""" + +from __future__ import annotations + +import asyncio +import base64 +import json +import os +import sys +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Optional + +from loguru import logger + +from providers.common.base import ASRResult, BaseASRService, ServiceState + +try: + import dashscope + from dashscope.audio.qwen_omni import MultiModality, OmniRealtimeCallback, OmniRealtimeConversation + + # Some SDK builds keep TranscriptionParams under qwen_omni.omni_realtime. + try: + from dashscope.audio.qwen_omni import TranscriptionParams + except ImportError: + from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams + + DASHSCOPE_SDK_AVAILABLE = True + DASHSCOPE_IMPORT_ERROR = "" +except Exception as exc: + DASHSCOPE_IMPORT_ERROR = f"{type(exc).__name__}: {exc}" + dashscope = None # type: ignore[assignment] + MultiModality = None # type: ignore[assignment] + OmniRealtimeConversation = None # type: ignore[assignment] + TranscriptionParams = None # type: ignore[assignment] + DASHSCOPE_SDK_AVAILABLE = False + + class OmniRealtimeCallback: # type: ignore[no-redef] + """Fallback callback base when DashScope SDK is unavailable.""" + + pass + + +class _DashScopeASRCallback(OmniRealtimeCallback): + """Bridge DashScope SDK callbacks into asyncio loop-safe handlers.""" + + def __init__(self, owner: "DashScopeRealtimeASRService", loop: asyncio.AbstractEventLoop): + super().__init__() + self._owner = owner + self._loop = loop + + def _schedule(self, fn: Callable[[], None]) -> None: + try: + self._loop.call_soon_threadsafe(fn) + except RuntimeError: + return + + def on_open(self) -> None: + self._schedule(self._owner._on_ws_open) + + def on_close(self, code: int, msg: str) -> None: + self._schedule(lambda: self._owner._on_ws_close(code, msg)) + + def on_event(self, message: Any) -> None: + self._schedule(lambda: self._owner._on_ws_event(message)) + + def on_error(self, message: Any) -> None: + self._schedule(lambda: self._owner._on_ws_error(message)) + + +class DashScopeRealtimeASRService(BaseASRService): + """Realtime streaming ASR implementation for DashScope Qwen-ASR-Realtime.""" + + DEFAULT_WS_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + DEFAULT_MODEL = "qwen3-asr-flash-realtime" + DEFAULT_FINAL_TIMEOUT_MS = 800 + + def __init__( + self, + api_key: Optional[str] = None, + api_url: Optional[str] = None, + model: Optional[str] = None, + sample_rate: int = 16000, + language: str = "auto", + on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None, + ) -> None: + super().__init__(sample_rate=sample_rate, language=language) + self.mode = "streaming" + self.api_key = ( + api_key + or os.getenv("DASHSCOPE_API_KEY") + or os.getenv("ASR_API_KEY") + ) + self.api_url = api_url or os.getenv("DASHSCOPE_ASR_API_URL") or self.DEFAULT_WS_URL + self.model = model or os.getenv("DASHSCOPE_ASR_MODEL") or self.DEFAULT_MODEL + self.on_transcript = on_transcript + + self._client: Optional[Any] = None + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._callback: Optional[_DashScopeASRCallback] = None + + self._running = False + self._session_ready = asyncio.Event() + self._transcript_queue: "asyncio.Queue[ASRResult]" = asyncio.Queue() + self._final_queue: "asyncio.Queue[str]" = asyncio.Queue() + + self._utterance_active = False + self._audio_sent_in_utterance = False + self._last_interim_text = "" + self._last_error: Optional[str] = None + + async def connect(self) -> None: + if not DASHSCOPE_SDK_AVAILABLE: + py_exec = sys.executable + hint = f"`{py_exec} -m pip install dashscope>=1.25.6`" + detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else "" + raise RuntimeError( + f"dashscope SDK unavailable in interpreter {py_exec}; install with {hint}{detail}" + ) + if not self.api_key: + raise ValueError("DashScope ASR API key not provided. Configure agent.asr.api_key in YAML.") + + self._loop = asyncio.get_running_loop() + self._callback = _DashScopeASRCallback(owner=self, loop=self._loop) + + if dashscope is not None: + dashscope.api_key = self.api_key + + self._client = OmniRealtimeConversation( # type: ignore[misc] + model=self.model, + url=self.api_url, + callback=self._callback, + ) + await asyncio.to_thread(self._client.connect) + await self._configure_session() + + self._running = True + self.state = ServiceState.CONNECTED + logger.info( + "DashScope realtime ASR connected: model={}, sample_rate={}, language={}", + self.model, + self.sample_rate, + self.language, + ) + + async def disconnect(self) -> None: + self._running = False + self._utterance_active = False + self._audio_sent_in_utterance = False + self._drain_queue(self._final_queue) + self._drain_queue(self._transcript_queue) + self._session_ready.clear() + + if self._client is not None: + close_fn = getattr(self._client, "close", None) + if callable(close_fn): + await asyncio.to_thread(close_fn) + self._client = None + self.state = ServiceState.DISCONNECTED + logger.info("DashScope realtime ASR disconnected") + + async def begin_utterance(self) -> None: + self.clear_utterance() + self._utterance_active = True + + async def send_audio(self, audio: bytes) -> None: + if not self._client: + raise RuntimeError("DashScope ASR service not connected") + if not audio: + return + + if not self._utterance_active: + # Allow graceful fallback if caller sends before begin_utterance. + self._utterance_active = True + + audio_b64 = base64.b64encode(audio).decode("ascii") + append_fn = getattr(self._client, "append_audio", None) + if not callable(append_fn): + raise RuntimeError("DashScope ASR SDK missing append_audio method") + await asyncio.to_thread(append_fn, audio_b64) + self._audio_sent_in_utterance = True + + async def end_utterance(self) -> None: + if not self._client: + return + if not self._utterance_active or not self._audio_sent_in_utterance: + return + + commit_fn = getattr(self._client, "commit", None) + if not callable(commit_fn): + raise RuntimeError("DashScope ASR SDK missing commit method") + await asyncio.to_thread(commit_fn) + self._utterance_active = False + + async def wait_for_final_transcription(self, timeout_ms: int = DEFAULT_FINAL_TIMEOUT_MS) -> str: + if not self._audio_sent_in_utterance: + return "" + timeout_sec = max(0.05, float(timeout_ms) / 1000.0) + try: + text = await asyncio.wait_for(self._final_queue.get(), timeout=timeout_sec) + return str(text or "").strip() + except asyncio.TimeoutError: + logger.debug("DashScope ASR final timeout ({}ms), fallback to last interim", timeout_ms) + return str(self._last_interim_text or "").strip() + + def clear_utterance(self) -> None: + self._utterance_active = False + self._audio_sent_in_utterance = False + self._last_interim_text = "" + self._last_error = None + self._drain_queue(self._final_queue) + + async def receive_transcripts(self) -> AsyncIterator[ASRResult]: + while self._running: + try: + result = await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1) + yield result + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + async def _configure_session(self) -> None: + if not self._client: + raise RuntimeError("DashScope ASR client is not initialized") + + text_modality: Any = "text" + if MultiModality is not None and hasattr(MultiModality, "TEXT"): + text_modality = MultiModality.TEXT + + transcription_params: Optional[Any] = None + if TranscriptionParams is not None: + try: + lang = "zh" if self.language == "auto" else self.language + transcription_params = TranscriptionParams( + language=lang, + sample_rate=self.sample_rate, + input_audio_format="pcm", + ) + except Exception as exc: + logger.debug("DashScope ASR TranscriptionParams init failed: {}", exc) + transcription_params = None + + update_attempts = [ + { + "output_modalities": [text_modality], + "enable_turn_detection": False, + "enable_input_audio_transcription": True, + "transcription_params": transcription_params, + }, + { + "output_modalities": [text_modality], + "enable_turn_detection": False, + "enable_input_audio_transcription": True, + }, + { + "output_modalities": [text_modality], + }, + ] + + update_fn = getattr(self._client, "update_session", None) + if not callable(update_fn): + raise RuntimeError("DashScope ASR SDK missing update_session method") + + last_error: Optional[Exception] = None + for params in update_attempts: + if params.get("transcription_params") is None: + params = {k: v for k, v in params.items() if k != "transcription_params"} + try: + await asyncio.to_thread(update_fn, **params) + break + except TypeError as exc: + last_error = exc + continue + except Exception as exc: + last_error = exc + continue + else: + raise RuntimeError(f"DashScope ASR session.update failed: {last_error}") + + try: + await asyncio.wait_for(self._session_ready.wait(), timeout=6.0) + except asyncio.TimeoutError: + logger.debug("DashScope ASR session ready wait timeout; continuing") + + def _on_ws_open(self) -> None: + return None + + def _on_ws_close(self, code: int, msg: str) -> None: + self._last_error = f"DashScope ASR websocket closed: {code} {msg}" + logger.debug(self._last_error) + + def _on_ws_error(self, message: Any) -> None: + self._last_error = str(message) + logger.error("DashScope ASR error: {}", self._last_error) + + def _on_ws_event(self, message: Any) -> None: + payload = self._coerce_event(message) + event_type = str(payload.get("type") or "").strip() + if not event_type: + return + + if event_type in {"session.created", "session.updated"}: + self._session_ready.set() + return + if event_type == "error" or event_type.endswith(".failed"): + err_text = self._extract_text(payload, keys=("message", "error", "details")) + self._last_error = err_text or event_type + logger.error("DashScope ASR server event error: {}", self._last_error) + return + + if event_type == "conversation.item.input_audio_transcription.text": + stash_text = self._extract_text(payload, keys=("stash", "text", "transcript")) + self._emit_transcript(stash_text, is_final=False) + return + + if event_type == "conversation.item.input_audio_transcription.completed": + final_text = self._extract_text(payload, keys=("transcript", "text", "stash")) + self._emit_transcript(final_text, is_final=True) + return + + def _emit_transcript(self, text: str, *, is_final: bool) -> None: + normalized = str(text or "").strip() + if not normalized: + return + if not is_final and normalized == self._last_interim_text: + return + if not is_final: + self._last_interim_text = normalized + + if self._loop is None: + return + try: + asyncio.run_coroutine_threadsafe( + self._publish_transcript(normalized, is_final=is_final), + self._loop, + ) + except RuntimeError: + return + + async def _publish_transcript(self, text: str, *, is_final: bool) -> None: + await self._transcript_queue.put(ASRResult(text=text, is_final=is_final)) + if is_final: + await self._final_queue.put(text) + if self.on_transcript: + try: + await self.on_transcript(text, is_final) + except Exception as exc: + logger.warning("DashScope ASR transcript callback failed: {}", exc) + + @staticmethod + def _coerce_event(message: Any) -> Dict[str, Any]: + if isinstance(message, dict): + return message + if isinstance(message, str): + try: + parsed = json.loads(message) + if isinstance(parsed, dict): + return parsed + except json.JSONDecodeError: + return {"type": "raw", "text": message} + return {"type": "raw", "text": str(message)} + + def _extract_text(self, payload: Dict[str, Any], *, keys: tuple[str, ...]) -> str: + for key in keys: + value = payload.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + if isinstance(value, dict): + nested = self._extract_text(value, keys=keys) + if nested: + return nested + + for value in payload.values(): + if isinstance(value, dict): + nested = self._extract_text(value, keys=keys) + if nested: + return nested + return "" + + @staticmethod + def _drain_queue(queue: "asyncio.Queue[Any]") -> None: + while True: + try: + queue.get_nowait() + except asyncio.QueueEmpty: + break diff --git a/engine/providers/asr/openai_compatible.py b/engine/providers/asr/openai_compatible.py index 1a2083b..cbff3e5 100644 --- a/engine/providers/asr/openai_compatible.py +++ b/engine/providers/asr/openai_compatible.py @@ -71,6 +71,7 @@ class OpenAICompatibleASRService(BaseASRService): on_transcript: Callback for transcription results (text, is_final) """ super().__init__(sample_rate=sample_rate, language=language) + self.mode = "offline" if not AIOHTTP_AVAILABLE: raise RuntimeError("aiohttp is required for OpenAICompatibleASRService") diff --git a/engine/providers/factory/default.py b/engine/providers/factory/default.py index 4294d3c..0d2912e 100644 --- a/engine/providers/factory/default.py +++ b/engine/providers/factory/default.py @@ -16,6 +16,7 @@ from runtime.ports import ( TTSServiceSpec, ) from providers.asr.buffered import BufferedASRService +from providers.asr.dashscope import DashScopeRealtimeASRService from providers.tts.dashscope import DashScopeTTSService from providers.llm.openai import MockLLMService, OpenAILLMService from providers.asr.openai_compatible import OpenAICompatibleASRService @@ -23,6 +24,7 @@ from providers.tts.openai_compatible import OpenAICompatibleTTSService from providers.tts.mock import MockTTSService _OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"} +_DASHSCOPE_PROVIDERS = {"dashscope"} _SUPPORTED_LLM_PROVIDERS = {"openai", *_OPENAI_COMPATIBLE_PROVIDERS} @@ -31,6 +33,8 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory): _DEFAULT_DASHSCOPE_TTS_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" _DEFAULT_DASHSCOPE_TTS_MODEL = "qwen3-tts-flash-realtime" + _DEFAULT_DASHSCOPE_ASR_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + _DEFAULT_DASHSCOPE_ASR_MODEL = "qwen3-asr-flash-realtime" _DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL = "FunAudioLLM/CosyVoice2-0.5B" _DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall" @@ -96,6 +100,16 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory): def create_asr_service(self, spec: ASRServiceSpec) -> ASRPort: provider = self._normalize_provider(spec.provider) + if provider in _DASHSCOPE_PROVIDERS and spec.api_key: + return DashScopeRealtimeASRService( + api_key=spec.api_key, + api_url=spec.api_url or self._DEFAULT_DASHSCOPE_ASR_REALTIME_URL, + model=spec.model or self._DEFAULT_DASHSCOPE_ASR_MODEL, + sample_rate=spec.sample_rate, + language=spec.language, + on_transcript=spec.on_transcript, + ) + if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key: return OpenAICompatibleASRService( api_key=spec.api_key, diff --git a/engine/runtime/pipeline/duplex.py b/engine/runtime/pipeline/duplex.py index aacd0c7..3a6bacc 100644 --- a/engine/runtime/pipeline/duplex.py +++ b/engine/runtime/pipeline/duplex.py @@ -30,11 +30,14 @@ from providers.factory.default import DefaultRealtimeServiceFactory from runtime.conversation import ConversationManager, ConversationState from runtime.events import get_event_bus from runtime.ports import ( + ASRMode, ASRPort, ASRServiceSpec, LLMPort, LLMServiceSpec, + OfflineASRPort, RealtimeServiceFactory, + StreamingASRPort, TTSPort, TTSServiceSpec, ) @@ -77,6 +80,7 @@ class DuplexPipeline: _ASR_DELTA_THROTTLE_MS = 500 _LLM_DELTA_THROTTLE_MS = 80 _ASR_CAPTURE_MAX_MS = 15000 + _ASR_STREAM_FINAL_TIMEOUT_MS = 800 _OPENER_PRE_ROLL_MS = 180 _DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = { "current_time": { @@ -317,6 +321,10 @@ class DuplexPipeline: self.llm_service = llm_service self.tts_service = tts_service self.asr_service = asr_service # Will be initialized in start() + self._asr_mode: ASRMode = self._resolve_asr_mode( + settings.asr_provider, + getattr(asr_service, "mode", None), + ) self._service_factory = service_factory or DefaultRealtimeServiceFactory() self._knowledge_searcher = knowledge_searcher self._tool_resource_resolver = tool_resource_resolver @@ -324,6 +332,7 @@ class DuplexPipeline: # Track last sent transcript to avoid duplicates self._last_sent_transcript = "" + self._latest_asr_interim_text = "" self._pending_transcript_delta: str = "" self._last_transcript_delta_emit_ms: float = 0.0 @@ -588,6 +597,7 @@ class DuplexPipeline: }, "asr": { "provider": asr_provider, + "mode": self._resolve_asr_mode(asr_provider, self._runtime_asr.get("mode")), "model": str(self._runtime_asr.get("model") or settings.asr_model or ""), "interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms), "minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms), @@ -787,6 +797,22 @@ class DuplexPipeline: normalized = str(provider or "").strip().lower() return normalized == "dashscope" + @staticmethod + def _resolve_asr_mode(provider: Any, raw_mode: Any = None) -> ASRMode: + normalized_mode = str(raw_mode or "").strip().lower() + if normalized_mode in {"offline", "streaming"}: + return normalized_mode # type: ignore[return-value] + normalized_provider = str(provider or "").strip().lower() + if normalized_provider == "dashscope": + return "streaming" + return "offline" + + def _offline_asr(self) -> OfflineASRPort: + return self.asr_service # type: ignore[return-value] + + def _streaming_asr(self) -> StreamingASRPort: + return self.asr_service # type: ignore[return-value] + @staticmethod def _default_llm_base_url(provider: Any) -> Optional[str]: normalized = str(provider or "").strip().lower() @@ -967,11 +993,13 @@ class DuplexPipeline: asr_model = self._runtime_asr.get("model") or settings.asr_model asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms) asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms) + asr_mode = self._resolve_asr_mode(asr_provider, self._runtime_asr.get("mode")) self.asr_service = self._service_factory.create_asr_service( ASRServiceSpec( provider=asr_provider, sample_rate=settings.sample_rate, + mode=asr_mode, language="auto", api_key=str(asr_api_key).strip() if asr_api_key else None, api_url=str(asr_api_url).strip() if asr_api_url else None, @@ -981,10 +1009,14 @@ class DuplexPipeline: on_transcript=self._on_transcript_callback, ) ) + self._asr_mode = self._resolve_asr_mode( + self._runtime_asr.get("provider") or settings.asr_provider, + getattr(self.asr_service, "mode", self._runtime_asr.get("mode")), + ) await self.asr_service.connect() - logger.info("DuplexPipeline services connected") + logger.info("DuplexPipeline services connected (asr_mode={})", self._asr_mode) if not self._outbound_task or self._outbound_task.done(): self._outbound_task = asyncio.create_task(self._outbound_loop()) @@ -1457,6 +1489,7 @@ class DuplexPipeline: self._last_sent_transcript = text if is_final: + self._latest_asr_interim_text = "" self._pending_transcript_delta = "" self._last_transcript_delta_emit_ms = 0.0 await self._send_event( @@ -1472,6 +1505,7 @@ class DuplexPipeline: logger.debug(f"Sent transcript (final): {text[:50]}...") return + self._latest_asr_interim_text = text self._pending_transcript_delta = text should_emit = ( self._last_transcript_delta_emit_ms <= 0.0 @@ -1495,14 +1529,16 @@ class DuplexPipeline: await self.conversation.start_user_turn() self._audio_buffer = b"" self._last_sent_transcript = "" + self._latest_asr_interim_text = "" self.eou_detector.reset() self._asr_capture_active = False self._asr_capture_started_ms = 0.0 self._pending_speech_audio = b"" - # Clear ASR buffer. Interim starts only after ASR capture is activated. - if hasattr(self.asr_service, 'clear_buffer'): - self.asr_service.clear_buffer() + if self._asr_mode == "streaming": + self._streaming_asr().clear_utterance() + else: + self._offline_asr().clear_buffer() logger.debug("User speech started") @@ -1511,8 +1547,10 @@ class DuplexPipeline: if self._asr_capture_active: return - if hasattr(self.asr_service, 'start_interim_transcription'): - await self.asr_service.start_interim_transcription() + if self._asr_mode == "streaming": + await self._streaming_asr().begin_utterance() + else: + await self._offline_asr().start_interim_transcription() # Prime ASR with a short pre-speech context window so the utterance # start isn't lost while waiting for VAD to transition to Speech. @@ -1545,24 +1583,22 @@ class DuplexPipeline: self._pending_speech_audio = b"" return - # Add a tiny trailing silence tail to stabilize final-token decoding. - if self._asr_final_tail_bytes > 0: - final_tail = b"\x00" * self._asr_final_tail_bytes - await self.asr_service.send_audio(final_tail) - - # Stop interim transcriptions - if hasattr(self.asr_service, 'stop_interim_transcription'): - await self.asr_service.stop_interim_transcription() - - # Get final transcription from ASR service user_text = "" - - if hasattr(self.asr_service, 'get_final_transcription'): - # SiliconFlow ASR - get final transcription - user_text = await self.asr_service.get_final_transcription() - elif hasattr(self.asr_service, 'get_and_clear_text'): - # Buffered ASR - get accumulated text - user_text = self.asr_service.get_and_clear_text() + if self._asr_mode == "streaming": + streaming_asr = self._streaming_asr() + await streaming_asr.end_utterance() + user_text = await streaming_asr.wait_for_final_transcription( + timeout_ms=self._ASR_STREAM_FINAL_TIMEOUT_MS + ) + if not user_text.strip(): + user_text = self._latest_asr_interim_text + else: + # Add a tiny trailing silence tail to stabilize final-token decoding. + if self._asr_final_tail_bytes > 0: + final_tail = b"\x00" * self._asr_final_tail_bytes + await self.asr_service.send_audio(final_tail) + await self._offline_asr().stop_interim_transcription() + user_text = await self._offline_asr().get_final_transcription() # Skip if no meaningful text if not user_text or not user_text.strip(): @@ -1570,6 +1606,7 @@ class DuplexPipeline: # Reset for next utterance self._audio_buffer = b"" self._last_sent_transcript = "" + self._latest_asr_interim_text = "" self._asr_capture_active = False self._asr_capture_started_ms = 0.0 self._pending_speech_audio = b"" @@ -1594,6 +1631,7 @@ class DuplexPipeline: # Clear buffers self._audio_buffer = b"" self._last_sent_transcript = "" + self._latest_asr_interim_text = "" self._pending_transcript_delta = "" self._last_transcript_delta_emit_ms = 0.0 self._asr_capture_active = False diff --git a/engine/runtime/ports/__init__.py b/engine/runtime/ports/__init__.py index a7cbce3..26319b2 100644 --- a/engine/runtime/ports/__init__.py +++ b/engine/runtime/ports/__init__.py @@ -1,6 +1,12 @@ """Port interfaces for runtime integration boundaries.""" -from runtime.ports.asr import ASRBufferControl, ASRInterimControl, ASRPort, ASRServiceSpec +from runtime.ports.asr import ( + ASRMode, + ASRPort, + ASRServiceSpec, + OfflineASRPort, + StreamingASRPort, +) from runtime.ports.control_plane import ( AssistantRuntimeConfigProvider, ControlPlaneGateway, @@ -13,10 +19,11 @@ from runtime.ports.service_factory import RealtimeServiceFactory from runtime.ports.tts import TTSPort, TTSServiceSpec __all__ = [ + "ASRMode", "ASRPort", "ASRServiceSpec", - "ASRInterimControl", - "ASRBufferControl", + "OfflineASRPort", + "StreamingASRPort", "AssistantRuntimeConfigProvider", "ControlPlaneGateway", "ConversationHistoryStore", diff --git a/engine/runtime/ports/asr.py b/engine/runtime/ports/asr.py index 8621ed0..7da547f 100644 --- a/engine/runtime/ports/asr.py +++ b/engine/runtime/ports/asr.py @@ -3,11 +3,12 @@ from __future__ import annotations from dataclasses import dataclass -from typing import AsyncIterator, Awaitable, Callable, Optional, Protocol +from typing import AsyncIterator, Awaitable, Callable, Literal, Optional, Protocol from providers.common.base import ASRResult TranscriptCallback = Callable[[str, bool], Awaitable[None]] +ASRMode = Literal["offline", "streaming"] @dataclass(frozen=True) @@ -16,6 +17,7 @@ class ASRServiceSpec: provider: str sample_rate: int + mode: Optional[ASRMode] = None language: str = "auto" api_key: Optional[str] = None api_url: Optional[str] = None @@ -28,6 +30,8 @@ class ASRServiceSpec: class ASRPort(Protocol): """Port for speech recognition providers.""" + mode: ASRMode + async def connect(self) -> None: """Establish connection to ASR provider.""" @@ -41,18 +45,16 @@ class ASRPort(Protocol): """Stream partial/final recognition results.""" -class ASRInterimControl(Protocol): - """Optional extension for explicit interim transcription control.""" +class OfflineASRPort(ASRPort, Protocol): + """Port for offline/buffered ASR providers.""" + + mode: Literal["offline"] async def start_interim_transcription(self) -> None: - """Start interim transcription loop if supported.""" + """Start interim transcription loop.""" async def stop_interim_transcription(self) -> None: - """Stop interim transcription loop if supported.""" - - -class ASRBufferControl(Protocol): - """Optional extension for explicit ASR buffer lifecycle control.""" + """Stop interim transcription loop.""" def clear_buffer(self) -> None: """Clear provider-side ASR buffer.""" @@ -62,3 +64,21 @@ class ASRBufferControl(Protocol): def get_and_clear_text(self) -> str: """Return buffered text and clear internal state.""" + + +class StreamingASRPort(ASRPort, Protocol): + """Port for streaming ASR providers.""" + + mode: Literal["streaming"] + + async def begin_utterance(self) -> None: + """Start a new utterance stream.""" + + async def end_utterance(self) -> None: + """Signal end of current utterance stream.""" + + async def wait_for_final_transcription(self, timeout_ms: int = 800) -> str: + """Wait for final transcript after utterance end.""" + + def clear_utterance(self) -> None: + """Reset utterance-local state.""" diff --git a/engine/tests/test_asr_factory_modes.py b/engine/tests/test_asr_factory_modes.py new file mode 100644 index 0000000..c127399 --- /dev/null +++ b/engine/tests/test_asr_factory_modes.py @@ -0,0 +1,46 @@ +from providers.asr.buffered import BufferedASRService +from providers.asr.dashscope import DashScopeRealtimeASRService +from providers.asr.openai_compatible import OpenAICompatibleASRService +from providers.factory.default import DefaultRealtimeServiceFactory +from runtime.ports import ASRServiceSpec + + +def test_create_asr_service_dashscope_returns_streaming_provider(): + factory = DefaultRealtimeServiceFactory() + service = factory.create_asr_service( + ASRServiceSpec( + provider="dashscope", + mode="streaming", + sample_rate=16000, + api_key="test-key", + model="qwen3-asr-flash-realtime", + ) + ) + assert isinstance(service, DashScopeRealtimeASRService) + assert service.mode == "streaming" + + +def test_create_asr_service_openai_compatible_returns_offline_provider(): + factory = DefaultRealtimeServiceFactory() + service = factory.create_asr_service( + ASRServiceSpec( + provider="openai_compatible", + sample_rate=16000, + api_key="test-key", + model="FunAudioLLM/SenseVoiceSmall", + ) + ) + assert isinstance(service, OpenAICompatibleASRService) + assert service.mode == "offline" + + +def test_create_asr_service_fallback_buffered_for_unsupported_provider(): + factory = DefaultRealtimeServiceFactory() + service = factory.create_asr_service( + ASRServiceSpec( + provider="unknown_provider", + sample_rate=16000, + ) + ) + assert isinstance(service, BufferedASRService) + assert service.mode == "offline" diff --git a/engine/tests/test_dashscope_asr_provider.py b/engine/tests/test_dashscope_asr_provider.py new file mode 100644 index 0000000..123530a --- /dev/null +++ b/engine/tests/test_dashscope_asr_provider.py @@ -0,0 +1,67 @@ +import asyncio + +import pytest + +from providers.asr.dashscope import DashScopeRealtimeASRService + + +@pytest.mark.asyncio +async def test_dashscope_asr_interim_event_emits_interim_transcript(): + received = [] + + async def _on_transcript(text: str, is_final: bool) -> None: + received.append((text, is_final)) + + service = DashScopeRealtimeASRService(api_key="test-key", on_transcript=_on_transcript) + service._loop = asyncio.get_running_loop() + service._running = True + + service._on_ws_event( + { + "type": "conversation.item.input_audio_transcription.text", + "stash": "你好世界", + } + ) + await asyncio.sleep(0.05) + + result = service._transcript_queue.get_nowait() + assert result.text == "你好世界" + assert result.is_final is False + assert received == [("你好世界", False)] + + +@pytest.mark.asyncio +async def test_dashscope_asr_final_event_emits_final_transcript_and_final_queue(): + received = [] + + async def _on_transcript(text: str, is_final: bool) -> None: + received.append((text, is_final)) + + service = DashScopeRealtimeASRService(api_key="test-key", on_transcript=_on_transcript) + service._loop = asyncio.get_running_loop() + service._running = True + service._audio_sent_in_utterance = True + + service._on_ws_event( + { + "type": "conversation.item.input_audio_transcription.completed", + "transcript": "最终识别结果", + } + ) + await asyncio.sleep(0.05) + + result = service._transcript_queue.get_nowait() + assert result.text == "最终识别结果" + assert result.is_final is True + assert service._final_queue.get_nowait() == "最终识别结果" + assert received == [("最终识别结果", True)] + + +@pytest.mark.asyncio +async def test_dashscope_wait_for_final_falls_back_to_latest_interim_on_timeout(): + service = DashScopeRealtimeASRService(api_key="test-key") + service._audio_sent_in_utterance = True + service._last_interim_text = "部分结果" + + text = await service.wait_for_final_transcription(timeout_ms=10) + assert text == "部分结果" diff --git a/engine/tests/test_duplex_asr_modes.py b/engine/tests/test_duplex_asr_modes.py new file mode 100644 index 0000000..76af160 --- /dev/null +++ b/engine/tests/test_duplex_asr_modes.py @@ -0,0 +1,196 @@ +import asyncio +from typing import Any, Dict, List + +import pytest + +from runtime.pipeline.duplex import DuplexPipeline + + +class _DummySileroVAD: + def __init__(self, *args, **kwargs): + pass + + def process_audio(self, _pcm: bytes) -> float: + return 0.0 + + +class _DummyVADProcessor: + def __init__(self, *args, **kwargs): + pass + + def process(self, _speech_prob: float): + return "Silence", 0.0 + + +class _DummyEouDetector: + def __init__(self, *args, **kwargs): + self.is_speaking = True + + def process(self, _vad_status: str, force_eligible: bool = False) -> bool: + _ = force_eligible + return False + + def reset(self) -> None: + self.is_speaking = False + + +class _FakeTransport: + async def send_event(self, _event: Dict[str, Any]) -> None: + return None + + async def send_audio(self, _audio: bytes) -> None: + return None + + +class _FakeStreamingASR: + mode = "streaming" + + def __init__(self): + self.begin_calls = 0 + self.end_calls = 0 + self.wait_calls = 0 + self.sent_audio: List[bytes] = [] + self.wait_text = "" + + async def connect(self) -> None: + return None + + async def disconnect(self) -> None: + return None + + async def send_audio(self, audio: bytes) -> None: + self.sent_audio.append(audio) + + async def receive_transcripts(self): + if False: + yield None + + async def begin_utterance(self) -> None: + self.begin_calls += 1 + + async def end_utterance(self) -> None: + self.end_calls += 1 + + async def wait_for_final_transcription(self, timeout_ms: int = 800) -> str: + _ = timeout_ms + self.wait_calls += 1 + return self.wait_text + + def clear_utterance(self) -> None: + return None + + +class _FakeOfflineASR: + mode = "offline" + + def __init__(self): + self.start_interim_calls = 0 + self.stop_interim_calls = 0 + self.sent_audio: List[bytes] = [] + self.final_text = "offline final" + + async def connect(self) -> None: + return None + + async def disconnect(self) -> None: + return None + + async def send_audio(self, audio: bytes) -> None: + self.sent_audio.append(audio) + + async def receive_transcripts(self): + if False: + yield None + + async def start_interim_transcription(self) -> None: + self.start_interim_calls += 1 + + async def stop_interim_transcription(self) -> None: + self.stop_interim_calls += 1 + + async def get_final_transcription(self) -> str: + return self.final_text + + def clear_buffer(self) -> None: + return None + + def get_and_clear_text(self) -> str: + return self.final_text + + +def _build_pipeline(monkeypatch, asr_service): + monkeypatch.setattr("runtime.pipeline.duplex.SileroVAD", _DummySileroVAD) + monkeypatch.setattr("runtime.pipeline.duplex.VADProcessor", _DummyVADProcessor) + monkeypatch.setattr("runtime.pipeline.duplex.EouDetector", _DummyEouDetector) + return DuplexPipeline( + transport=_FakeTransport(), + session_id="asr_mode_test", + asr_service=asr_service, + ) + + +@pytest.mark.asyncio +async def test_start_asr_capture_uses_streaming_begin(monkeypatch): + asr = _FakeStreamingASR() + pipeline = _build_pipeline(monkeypatch, asr) + pipeline._asr_mode = "streaming" + pipeline._pending_speech_audio = b"\x00" * 320 + pipeline._pre_speech_buffer = b"\x00" * 640 + + await pipeline._start_asr_capture() + + assert asr.begin_calls == 1 + assert asr.sent_audio + assert pipeline._asr_capture_active is True + + +@pytest.mark.asyncio +async def test_start_asr_capture_uses_offline_interim_control(monkeypatch): + asr = _FakeOfflineASR() + pipeline = _build_pipeline(monkeypatch, asr) + pipeline._asr_mode = "offline" + pipeline._pending_speech_audio = b"\x00" * 320 + pipeline._pre_speech_buffer = b"\x00" * 640 + + await pipeline._start_asr_capture() + + assert asr.start_interim_calls == 1 + assert asr.sent_audio + assert pipeline._asr_capture_active is True + + +@pytest.mark.asyncio +async def test_streaming_eou_falls_back_to_latest_interim(monkeypatch): + asr = _FakeStreamingASR() + asr.wait_text = "" + pipeline = _build_pipeline(monkeypatch, asr) + pipeline._asr_mode = "streaming" + pipeline._asr_capture_active = True + pipeline._latest_asr_interim_text = "fallback interim text" + await pipeline.conversation.start_user_turn() + + captured_events = [] + captured_turns = [] + + async def _capture_event(event: Dict[str, Any], priority: int = 20): + _ = priority + captured_events.append(event) + + async def _noop_stop_current_speech() -> None: + return None + + async def _capture_turn(user_text: str, *args, **kwargs) -> None: + _ = (args, kwargs) + captured_turns.append(user_text) + + monkeypatch.setattr(pipeline, "_send_event", _capture_event) + monkeypatch.setattr(pipeline, "_stop_current_speech", _noop_stop_current_speech) + monkeypatch.setattr(pipeline, "_handle_turn", _capture_turn) + + await pipeline._on_end_of_utterance() + await asyncio.sleep(0.05) + + assert asr.end_calls == 1 + assert asr.wait_calls == 1 + assert captured_turns == ["fallback interim text"] + assert any(event.get("type") == "transcript.final" for event in captured_events) diff --git a/engine/tests/test_tool_call_flow.py b/engine/tests/test_tool_call_flow.py index d820643..717f96a 100644 --- a/engine/tests/test_tool_call_flow.py +++ b/engine/tests/test_tool_call_flow.py @@ -52,9 +52,33 @@ class _FakeTTS: class _FakeASR: + mode = "offline" + async def connect(self) -> None: return None + async def disconnect(self) -> None: + return None + + async def send_audio(self, _audio: bytes) -> None: + return None + + async def receive_transcripts(self): + if False: + yield None + + def clear_buffer(self) -> None: + return None + + async def start_interim_transcription(self) -> None: + return None + + async def stop_interim_transcription(self) -> None: + return None + + async def get_final_transcription(self) -> str: + return "" + class _FakeLLM: def __init__(self, rounds: List[List[LLMStreamEvent]]):