diff --git a/api/app/models.py b/api/app/models.py index 29579f2..265f83d 100644 --- a/api/app/models.py +++ b/api/app/models.py @@ -127,6 +127,7 @@ class Assistant(Base): speed: Mapped[float] = mapped_column(Float, default=1.0) hotwords: Mapped[dict] = mapped_column(JSON, default=list) tools: Mapped[dict] = mapped_column(JSON, default=list) + asr_interim_enabled: Mapped[bool] = mapped_column(default=False) bot_cannot_be_interrupted: Mapped[bool] = mapped_column(default=False) interruption_sensitivity: Mapped[int] = mapped_column(Integer, default=500) config_mode: Mapped[str] = mapped_column(String(32), default="platform") diff --git a/api/app/routers/assistants.py b/api/app/routers/assistants.py index bf43303..f517458 100644 --- a/api/app/routers/assistants.py +++ b/api/app/routers/assistants.py @@ -126,6 +126,9 @@ def _ensure_assistant_schema(db: Session) -> None: if "manual_opener_tool_calls" not in columns: db.execute(text("ALTER TABLE assistants ADD COLUMN manual_opener_tool_calls JSON")) altered = True + if "asr_interim_enabled" not in columns: + db.execute(text("ALTER TABLE assistants ADD COLUMN asr_interim_enabled BOOLEAN DEFAULT 0")) + altered = True if altered: db.commit() @@ -317,18 +320,27 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[s else: warnings.append(f"LLM model not found: {assistant.llm_model_id}") + asr_runtime: Dict[str, Any] = { + "enableInterim": bool(assistant.asr_interim_enabled), + } if assistant.asr_model_id: asr = db.query(ASRModel).filter(ASRModel.id == assistant.asr_model_id).first() if asr: - asr_provider = "openai_compatible" if _is_openai_compatible_vendor(asr.vendor) else "buffered" - metadata["services"]["asr"] = { + if _is_dashscope_vendor(asr.vendor): + asr_provider = "dashscope" + elif _is_openai_compatible_vendor(asr.vendor): + asr_provider = "openai_compatible" + else: + asr_provider = "buffered" + asr_runtime.update({ "provider": asr_provider, "model": asr.model_name or asr.name, - 
"apiKey": asr.api_key if asr_provider == "openai_compatible" else None, - "baseUrl": asr.base_url if asr_provider == "openai_compatible" else None, - } + "apiKey": asr.api_key if asr_provider in {"openai_compatible", "dashscope"} else None, + "baseUrl": asr.base_url if asr_provider in {"openai_compatible", "dashscope"} else None, + }) else: warnings.append(f"ASR model not found: {assistant.asr_model_id}") + metadata["services"]["asr"] = asr_runtime if not assistant.voice_output_enabled: metadata["services"]["tts"] = {"enabled": False} @@ -432,6 +444,7 @@ def assistant_to_dict(assistant: Assistant) -> dict: "speed": assistant.speed, "hotwords": assistant.hotwords or [], "tools": _normalize_assistant_tool_ids(assistant.tools), + "asrInterimEnabled": bool(assistant.asr_interim_enabled), "botCannotBeInterrupted": bool(assistant.bot_cannot_be_interrupted), "interruptionSensitivity": assistant.interruption_sensitivity, "configMode": assistant.config_mode, @@ -452,6 +465,7 @@ def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None: "firstTurnMode": "first_turn_mode", "manualOpenerToolCalls": "manual_opener_tool_calls", "interruptionSensitivity": "interruption_sensitivity", + "asrInterimEnabled": "asr_interim_enabled", "botCannotBeInterrupted": "bot_cannot_be_interrupted", "configMode": "config_mode", "voiceOutputEnabled": "voice_output_enabled", @@ -646,6 +660,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)): speed=data.speed, hotwords=data.hotwords, tools=_normalize_assistant_tool_ids(data.tools), + asr_interim_enabled=data.asrInterimEnabled, bot_cannot_be_interrupted=data.botCannotBeInterrupted, interruption_sensitivity=data.interruptionSensitivity, config_mode=data.configMode, diff --git a/api/app/schemas.py b/api/app/schemas.py index 9bf2274..f0ad0c3 100644 --- a/api/app/schemas.py +++ b/api/app/schemas.py @@ -291,6 +291,7 @@ class AssistantBase(BaseModel): speed: float = 1.0 hotwords: List[str] = [] tools: List[str] 
= [] + asrInterimEnabled: bool = False botCannotBeInterrupted: bool = False interruptionSensitivity: int = 500 configMode: str = "platform" @@ -322,6 +323,7 @@ class AssistantUpdate(BaseModel): speed: Optional[float] = None hotwords: Optional[List[str]] = None tools: Optional[List[str]] = None + asrInterimEnabled: Optional[bool] = None botCannotBeInterrupted: Optional[bool] = None interruptionSensitivity: Optional[int] = None configMode: Optional[str] = None diff --git a/api/tests/test_assistants.py b/api/tests/test_assistants.py index 0d880ef..7acbd30 100644 --- a/api/tests/test_assistants.py +++ b/api/tests/test_assistants.py @@ -27,6 +27,7 @@ class TestAssistantAPI: assert data["voiceOutputEnabled"] is True assert data["firstTurnMode"] == "bot_first" assert data["generatedOpenerEnabled"] is False + assert data["asrInterimEnabled"] is False assert data["botCannotBeInterrupted"] is False assert "id" in data assert data["callCount"] == 0 @@ -37,6 +38,7 @@ class TestAssistantAPI: response = client.post("/api/assistants", json=data) assert response.status_code == 200 assert response.json()["name"] == "Minimal Assistant" + assert response.json()["asrInterimEnabled"] is False def test_get_assistant_by_id(self, client, sample_assistant_data): """Test getting a specific assistant by ID""" @@ -68,6 +70,7 @@ class TestAssistantAPI: "prompt": "You are an updated assistant.", "speed": 1.5, "voiceOutputEnabled": False, + "asrInterimEnabled": True, "manualOpenerToolCalls": [ {"toolName": "text_msg_prompt", "arguments": {"msg": "请选择服务类型"}} ], @@ -79,6 +82,7 @@ class TestAssistantAPI: assert data["prompt"] == "You are an updated assistant." 
assert data["speed"] == 1.5 assert data["voiceOutputEnabled"] is False + assert data["asrInterimEnabled"] is True assert data["manualOpenerToolCalls"] == [ {"toolName": "text_msg_prompt", "arguments": {"msg": "请选择服务类型"}} ] @@ -213,6 +217,7 @@ class TestAssistantAPI: "prompt": "runtime prompt", "opener": "runtime opener", "manualOpenerToolCalls": [{"toolName": "text_msg_prompt", "arguments": {"msg": "欢迎"}}], + "asrInterimEnabled": True, "speed": 1.1, }) assistant_resp = client.post("/api/assistants", json=sample_assistant_data) @@ -232,6 +237,7 @@ class TestAssistantAPI: assert metadata["services"]["llm"]["model"] == sample_llm_model_data["model_name"] assert metadata["services"]["asr"]["model"] == sample_asr_model_data["model_name"] assert metadata["services"]["asr"]["baseUrl"] == sample_asr_model_data["base_url"] + assert metadata["services"]["asr"]["enableInterim"] is True expected_tts_voice = f"{sample_voice_data['model']}:{sample_voice_data['voice_key']}" assert metadata["services"]["tts"]["voice"] == expected_tts_voice assert metadata["services"]["tts"]["baseUrl"] == sample_voice_data["base_url"] @@ -309,6 +315,7 @@ class TestAssistantAPI: assert runtime_resp.status_code == 200 metadata = runtime_resp.json()["sessionStartMetadata"] assert metadata["output"]["mode"] == "text" + assert metadata["services"]["asr"]["enableInterim"] is False assert metadata["services"]["tts"]["enabled"] is False def test_runtime_config_dashscope_voice_provider(self, client, sample_assistant_data): @@ -343,6 +350,48 @@ class TestAssistantAPI: assert tts["apiKey"] == "dashscope-key" assert tts["baseUrl"] == "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + def test_runtime_config_dashscope_asr_provider(self, client, sample_assistant_data): + """DashScope ASR models should map to dashscope asr provider in runtime metadata.""" + asr_resp = client.post("/api/asr", json={ + "name": "DashScope Realtime ASR", + "vendor": "DashScope", + "language": "zh", + "base_url": 
"wss://dashscope.aliyuncs.com/api-ws/v1/realtime", + "api_key": "dashscope-asr-key", + "model_name": "qwen3-asr-flash-realtime", + "hotwords": [], + "enable_punctuation": True, + "enable_normalization": True, + "enabled": True, + }) + assert asr_resp.status_code == 200 + asr_payload = asr_resp.json() + + sample_assistant_data.update({ + "asrModelId": asr_payload["id"], + }) + assistant_resp = client.post("/api/assistants", json=sample_assistant_data) + assert assistant_resp.status_code == 200 + assistant_id = assistant_resp.json()["id"] + + runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config") + assert runtime_resp.status_code == 200 + metadata = runtime_resp.json()["sessionStartMetadata"] + asr = metadata["services"]["asr"] + assert asr["provider"] == "dashscope" + assert asr["baseUrl"] == "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + assert asr["enableInterim"] is False + + def test_runtime_config_defaults_asr_interim_disabled_without_asr_model(self, client, sample_assistant_data): + assistant_resp = client.post("/api/assistants", json=sample_assistant_data) + assert assistant_resp.status_code == 200 + assistant_id = assistant_resp.json()["id"] + + runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config") + assert runtime_resp.status_code == 200 + metadata = runtime_resp.json()["sessionStartMetadata"] + assert metadata["services"]["asr"]["enableInterim"] is False + def test_assistant_interrupt_and_generated_opener_flags(self, client, sample_assistant_data): sample_assistant_data.update({ "firstTurnMode": "user_first", diff --git a/docs/content/customization/asr.md b/docs/content/customization/asr.md index 56f51cc..74e1097 100644 --- a/docs/content/customization/asr.md +++ b/docs/content/customization/asr.md @@ -12,6 +12,24 @@ | **热词** | 提高业务词汇、品牌词、专有名词识别率 | | **标点与规范化** | 自动补全标点、规范数字和日期等 | +## 模式 + +- `offline`:引擎本地缓冲音频后触发识别(适用于 OpenAI-compatible / SiliconFlow)。 +- `streaming`:音频分片实时发送到服务端,服务端持续返回转写事件(适用于 DashScope 
Realtime ASR、Volcengine BigASR)。 + +## 配置项 + +| 配置项 | 说明 | +|---|---| +| ASR 引擎 | 选择语音识别服务提供商 | +| 模型 | 识别模型名称 | +| `enable_interim` | 是否开启离线 ASR 中间结果(默认 `false`,仅离线模式生效) | +| `app_id` / `resource_id` | Volcengine 等厂商的应用标识与资源标识 | +| `request_params` | 厂商原生请求参数透传,例如 `end_window_size`、`force_to_speech_time`、`context` | +| 语言 | 中文/英文/多语言 | +| 热词 | 提升特定词汇识别准确率 | +| 标点与规范化 | 是否自动补全标点、文本规范化 | + ## 选择建议 - 客服、外呼等业务场景建议维护热词表,并按业务线持续更新 @@ -29,3 +47,7 @@ - [声音资源](voices.md) - 完整语音输入输出链路中的 TTS 侧配置 - [快速开始](../quickstart/index.md) - 以任务路径接入第一个 ASR 资源 +- 客服场景建议开启热词并维护业务词表 +- 多语言场景建议按会话入口显式指定语言 +- 对延迟敏感场景优先选择流式识别模型 +- 当前支持提供商:`openai_compatible`、`siliconflow`、`dashscope`、`volcengine`、`buffered`(回退) diff --git a/engine/adapters/control_plane/backend.py b/engine/adapters/control_plane/backend.py index 087f744..9f8914d 100644 --- a/engine/adapters/control_plane/backend.py +++ b/engine/adapters/control_plane/backend.py @@ -230,6 +230,14 @@ class LocalYamlAssistantConfigAdapter(NullBackendAdapter): tts_runtime["baseUrl"] = cls._as_str(tts.get("api_url")) if cls._as_str(tts.get("voice")): tts_runtime["voice"] = cls._as_str(tts.get("voice")) + if cls._as_str(tts.get("app_id")): + tts_runtime["appId"] = cls._as_str(tts.get("app_id")) + if cls._as_str(tts.get("resource_id")): + tts_runtime["resourceId"] = cls._as_str(tts.get("resource_id")) + if cls._as_str(tts.get("cluster")): + tts_runtime["cluster"] = cls._as_str(tts.get("cluster")) + if cls._as_str(tts.get("uid")): + tts_runtime["uid"] = cls._as_str(tts.get("uid")) if tts.get("speed") is not None: tts_runtime["speed"] = tts.get("speed") dashscope_mode = cls._as_str(tts.get("dashscope_mode")) or cls._as_str(tts.get("mode")) @@ -249,6 +257,18 @@ class LocalYamlAssistantConfigAdapter(NullBackendAdapter): asr_runtime["apiKey"] = cls._as_str(asr.get("api_key")) if cls._as_str(asr.get("api_url")): asr_runtime["baseUrl"] = cls._as_str(asr.get("api_url")) + if cls._as_str(asr.get("app_id")): + asr_runtime["appId"] = 
cls._as_str(asr.get("app_id")) + if cls._as_str(asr.get("resource_id")): + asr_runtime["resourceId"] = cls._as_str(asr.get("resource_id")) + if cls._as_str(asr.get("cluster")): + asr_runtime["cluster"] = cls._as_str(asr.get("cluster")) + if cls._as_str(asr.get("uid")): + asr_runtime["uid"] = cls._as_str(asr.get("uid")) + if isinstance(asr.get("request_params"), dict): + asr_runtime["requestParams"] = dict(asr.get("request_params") or {}) + if asr.get("enable_interim") is not None: + asr_runtime["enableInterim"] = asr.get("enable_interim") if asr.get("interim_interval_ms") is not None: asr_runtime["interimIntervalMs"] = asr.get("interim_interval_ms") if asr.get("min_audio_ms") is not None: diff --git a/engine/app/config.py b/engine/app/config.py index 233ba75..8edf7ce 100644 --- a/engine/app/config.py +++ b/engine/app/config.py @@ -71,11 +71,15 @@ class Settings(BaseSettings): # TTS Configuration tts_provider: str = Field( default="openai_compatible", - description="TTS provider (openai_compatible, siliconflow, dashscope)" + description="TTS provider (openai_compatible, siliconflow, dashscope, volcengine)" ) tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL") tts_model: Optional[str] = Field(default=None, description="TTS model name") tts_voice: str = Field(default="anna", description="TTS voice name") + tts_app_id: Optional[str] = Field(default=None, description="Provider-specific TTS app ID") + tts_resource_id: Optional[str] = Field(default=None, description="Provider-specific TTS resource ID") + tts_cluster: Optional[str] = Field(default=None, description="Provider-specific TTS cluster") + tts_uid: Optional[str] = Field(default=None, description="Provider-specific TTS user ID") tts_mode: str = Field( default="commit", description="DashScope-only TTS mode (commit, server_commit). Ignored for non-dashscope providers." 
@@ -85,10 +89,19 @@ class Settings(BaseSettings): # ASR Configuration asr_provider: str = Field( default="openai_compatible", - description="ASR provider (openai_compatible, buffered, siliconflow)" + description="ASR provider (openai_compatible, buffered, siliconflow, dashscope, volcengine)" ) asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL") asr_model: Optional[str] = Field(default=None, description="ASR model name") + asr_app_id: Optional[str] = Field(default=None, description="Provider-specific ASR app ID") + asr_resource_id: Optional[str] = Field(default=None, description="Provider-specific ASR resource ID") + asr_cluster: Optional[str] = Field(default=None, description="Provider-specific ASR cluster") + asr_uid: Optional[str] = Field(default=None, description="Provider-specific ASR user ID") + asr_request_params_json: Optional[str] = Field( + default=None, + description="Provider-specific ASR request params as JSON string" + ) + asr_enable_interim: bool = Field(default=False, description="Enable interim transcripts for offline ASR") asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms") asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result") asr_start_min_speech_ms: int = Field( diff --git a/engine/config/agents/dashscope.yaml b/engine/config/agents/dashscope.yaml new file mode 100644 index 0000000..3491d68 --- /dev/null +++ b/engine/config/agents/dashscope.yaml @@ -0,0 +1,47 @@ +# Agent behavior configuration for DashScope realtime ASR/TTS. +# This file only controls agent-side behavior (VAD/LLM/TTS/ASR providers). +# Infra/server/network settings should stay in .env. 
+ +agent: + vad: + type: silero + model_path: data/vad/silero_vad.onnx + threshold: 0.5 + min_speech_duration_ms: 100 + eou_threshold_ms: 800 + + llm: + # provider: openai | openai_compatible | siliconflow + provider: openai_compatible + model: deepseek-v3 + temperature: 0.7 + api_key: your_llm_api_key + api_url: https://api.qnaigc.com/v1 + + tts: + provider: dashscope + api_key: your_tts_api_key + api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime + model: qwen3-tts-flash-realtime + voice: Cherry + dashscope_mode: commit + speed: 1.0 + + asr: + provider: dashscope + api_key: your_asr_api_key + api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime + model: qwen3-asr-flash-realtime + interim_interval_ms: 500 + min_audio_ms: 300 + start_min_speech_ms: 160 + pre_speech_ms: 240 + final_tail_ms: 120 + + duplex: + enabled: true + system_prompt: You are a helpful, friendly voice assistant. Keep your responses concise and conversational. + + barge_in: + min_duration_ms: 200 + silence_tolerance_ms: 60 diff --git a/engine/config/agents/example.yaml b/engine/config/agents/example.yaml index 70f4933..e68b6f3 100644 --- a/engine/config/agents/example.yaml +++ b/engine/config/agents/example.yaml @@ -21,12 +21,17 @@ agent: api_url: https://api.qnaigc.com/v1 tts: - # provider: openai_compatible | siliconflow | dashscope + # provider: openai_compatible | siliconflow | dashscope | volcengine # dashscope defaults (if omitted): # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime # model: qwen3-tts-flash-realtime # dashscope_mode: commit (engine splits) | server_commit (dashscope splits) # note: dashscope_mode/mode is ONLY used when provider=dashscope. 
+ # volcengine defaults (if omitted): + # api_url: https://openspeech.bytedance.com/api/v3/tts/unidirectional + # resource_id: seed-tts-2.0 + # app_id: your volcengine app key + # api_key: your volcengine access key provider: openai_compatible api_key: your_tts_api_key api_url: https://api.siliconflow.cn/v1/audio/speech @@ -35,11 +40,26 @@ agent: speed: 1.0 asr: - # provider: buffered | openai_compatible | siliconflow + # provider: buffered | openai_compatible | siliconflow | dashscope | volcengine + # dashscope defaults (if omitted): + # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime + # model: qwen3-asr-flash-realtime + # note: dashscope uses streaming ASR mode (chunk-by-chunk). + # volcengine defaults (if omitted): + # api_url: wss://openspeech.bytedance.com/api/v3/sauc/bigmodel + # model: bigmodel + # resource_id: volc.bigasr.sauc.duration + # app_id: your volcengine app key + # api_key: your volcengine access key + # request_params: + # end_window_size: 800 + # force_to_speech_time: 1000 + # note: volcengine uses streaming ASR mode (chunk-by-chunk). provider: openai_compatible api_key: you_asr_api_key api_url: https://api.siliconflow.cn/v1/audio/transcriptions model: FunAudioLLM/SenseVoiceSmall + enable_interim: false interim_interval_ms: 500 min_audio_ms: 300 start_min_speech_ms: 160 diff --git a/engine/config/agents/tools.yaml b/engine/config/agents/tools.yaml index e2968bb..11cd7c3 100644 --- a/engine/config/agents/tools.yaml +++ b/engine/config/agents/tools.yaml @@ -18,12 +18,17 @@ agent: api_url: https://api.qnaigc.com/v1 tts: - # provider: openai_compatible | siliconflow | dashscope + # provider: openai_compatible | siliconflow | dashscope | volcengine # dashscope defaults (if omitted): # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime # model: qwen3-tts-flash-realtime # dashscope_mode: commit (engine splits) | server_commit (dashscope splits) # note: dashscope_mode/mode is ONLY used when provider=dashscope. 
+ # volcengine defaults (if omitted): + # api_url: https://openspeech.bytedance.com/api/v3/tts/unidirectional + # resource_id: seed-tts-2.0 + # app_id: your volcengine app key + # api_key: your volcengine access key provider: openai_compatible api_key: your_tts_api_key api_url: https://api.siliconflow.cn/v1/audio/speech @@ -32,11 +37,26 @@ agent: speed: 1.0 asr: - # provider: buffered | openai_compatible | siliconflow + # provider: buffered | openai_compatible | siliconflow | dashscope | volcengine + # dashscope defaults (if omitted): + # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime + # model: qwen3-asr-flash-realtime + # note: dashscope uses streaming ASR mode (chunk-by-chunk). + # volcengine defaults (if omitted): + # api_url: wss://openspeech.bytedance.com/api/v3/sauc/bigmodel + # model: bigmodel + # resource_id: volc.bigasr.sauc.duration + # app_id: your volcengine app key + # api_key: your volcengine access key + # request_params: + # end_window_size: 800 + # force_to_speech_time: 1000 + # note: volcengine uses streaming ASR mode (chunk-by-chunk). provider: openai_compatible api_key: your_asr_api_key api_url: https://api.siliconflow.cn/v1/audio/transcriptions model: FunAudioLLM/SenseVoiceSmall + enable_interim: false interim_interval_ms: 500 min_audio_ms: 300 start_min_speech_ms: 160 diff --git a/engine/config/agents/volcengine.yaml b/engine/config/agents/volcengine.yaml new file mode 100644 index 0000000..acd66b3 --- /dev/null +++ b/engine/config/agents/volcengine.yaml @@ -0,0 +1,68 @@ +# Agent behavior configuration (safe to edit per profile) +# This file only controls agent-side behavior (VAD/LLM/TTS/ASR providers). +# Infra/server/network settings should stay in .env. 
+ +agent: + vad: + type: silero + model_path: data/vad/silero_vad.onnx + threshold: 0.5 + min_speech_duration_ms: 100 + eou_threshold_ms: 800 + + llm: + # provider: openai | openai_compatible | siliconflow + provider: openai_compatible + model: deepseek-v3 + temperature: 0.7 + # Required: no fallback. You can still reference env explicitly. + api_key: your_llm_api_key + # Optional for OpenAI-compatible endpoints: + api_url: https://api.qnaigc.com/v1 + + tts: + # provider: openai_compatible | siliconflow | dashscope | volcengine + # dashscope defaults (if omitted): + # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime + # model: qwen3-tts-flash-realtime + # dashscope_mode: commit (engine splits) | server_commit (dashscope splits) + # note: dashscope_mode/mode is ONLY used when provider=dashscope. + # volcengine defaults (if omitted): + provider: volcengine + api_url: https://openspeech.bytedance.com/api/v3/tts/unidirectional + resource_id: seed-tts-2.0 + app_id: your_tts_app_id + api_key: your_tts_api_key + speed: 1.1 + voice: zh_female_vv_uranus_bigtts + + asr: + # note: volcengine uses streaming ASR mode (chunk-by-chunk). + provider: volcengine + api_url: wss://openspeech.bytedance.com/api/v3/sauc/bigmodel + app_id: your_asr_app_id + api_key: your_asr_api_key + resource_id: volc.bigasr.sauc.duration + uid: caller-1 + model: bigmodel + request_params: + end_window_size: 800 + force_to_speech_time: 1000 + enable_punc: true + enable_itn: false + enable_ddc: false + show_utterance: true + result_type: single + interim_interval_ms: 500 + min_audio_ms: 300 + start_min_speech_ms: 160 + pre_speech_ms: 240 + final_tail_ms: 120 + + duplex: + enabled: true + system_prompt: 你是一个人工智能助手,你用简答语句回答,避免使用标点符号和emoji。 + + barge_in: + min_duration_ms: 200 + silence_tolerance_ms: 60 diff --git a/engine/data/audio_examples/single_utterance_16k.wav b/engine/data/audio_examples/single_utterance_16k.wav deleted file mode 100644 index 8c7bbe5..0000000 Binary files a/engine/data/audio_examples/single_utterance_16k.wav and /dev/null differ diff
--git a/engine/data/audio_examples/three_utterances.wav b/engine/data/audio_examples/three_utterances.wav deleted file mode 100644 index c2dca2f..0000000 Binary files a/engine/data/audio_examples/three_utterances.wav and /dev/null differ diff --git a/engine/data/audio_examples/two_utterances.wav b/engine/data/audio_examples/two_utterances.wav deleted file mode 100644 index 5c66f70..0000000 Binary files a/engine/data/audio_examples/two_utterances.wav and /dev/null differ diff --git a/engine/docs/extension_ports.md b/engine/docs/extension_ports.md index 8566194..c0f65f6 100644 --- a/engine/docs/extension_ports.md +++ b/engine/docs/extension_ports.md @@ -20,7 +20,7 @@ This document defines the draft port set used to keep core runtime extensible. - `runtime/ports/asr.py` - `ASRServiceSpec` - `ASRPort` - - optional extensions: `ASRInterimControl`, `ASRBufferControl` + - explicit mode ports: `OfflineASRPort`, `StreamingASRPort` - `runtime/ports/service_factory.py` - `RealtimeServiceFactory` @@ -36,10 +36,10 @@ This document defines the draft port set used to keep core runtime extensible. 
- supported providers: `openai`, `openai_compatible`, `openai-compatible`, `siliconflow` - fallback: `MockLLMService` - TTS: - - supported providers: `dashscope`, `openai_compatible`, `openai-compatible`, `siliconflow` + - supported providers: `dashscope`, `volcengine`, `openai_compatible`, `openai-compatible`, `siliconflow` - fallback: `MockTTSService` - ASR: - - supported providers: `openai_compatible`, `openai-compatible`, `siliconflow` + - supported providers: `openai_compatible`, `openai-compatible`, `siliconflow`, `dashscope`, `volcengine` - fallback: `BufferedASRService` ## Notes diff --git a/engine/providers/asr/__init__.py b/engine/providers/asr/__init__.py index 2efe6a9..5e5dc29 100644 --- a/engine/providers/asr/__init__.py +++ b/engine/providers/asr/__init__.py @@ -1 +1,15 @@ """ASR providers.""" + +from providers.asr.buffered import BufferedASRService, MockASRService +from providers.asr.dashscope import DashScopeRealtimeASRService +from providers.asr.openai_compatible import OpenAICompatibleASRService, SiliconFlowASRService +from providers.asr.volcengine import VolcengineRealtimeASRService + +__all__ = [ + "BufferedASRService", + "MockASRService", + "DashScopeRealtimeASRService", + "OpenAICompatibleASRService", + "SiliconFlowASRService", + "VolcengineRealtimeASRService", +] diff --git a/engine/providers/asr/buffered.py b/engine/providers/asr/buffered.py index ce1a248..624963c 100644 --- a/engine/providers/asr/buffered.py +++ b/engine/providers/asr/buffered.py @@ -34,6 +34,7 @@ class BufferedASRService(BaseASRService): language: str = "en" ): super().__init__(sample_rate=sample_rate, language=language) + self.mode = "offline" self._audio_buffer: bytes = b"" self._current_text: str = "" @@ -86,6 +87,23 @@ class BufferedASRService(BaseASRService): self._current_text = "" self._audio_buffer = b"" return text + + async def get_final_transcription(self) -> str: + """Offline compatibility method used by DuplexPipeline.""" + return self.get_and_clear_text() + + 
def clear_buffer(self) -> None: + """Offline compatibility method used by DuplexPipeline.""" + self._audio_buffer = b"" + self._current_text = "" + + async def start_interim_transcription(self) -> None: + """No-op for plain buffered ASR.""" + return None + + async def stop_interim_transcription(self) -> None: + """No-op for plain buffered ASR.""" + return None def get_audio_buffer(self) -> bytes: """Get accumulated audio buffer.""" @@ -103,6 +121,7 @@ class MockASRService(BaseASRService): def __init__(self, sample_rate: int = 16000, language: str = "en"): super().__init__(sample_rate=sample_rate, language=language) + self.mode = "offline" self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue() self._mock_texts = [ "Hello, how are you?", @@ -145,3 +164,18 @@ class MockASRService(BaseASRService): continue except asyncio.CancelledError: break + + def clear_buffer(self) -> None: + return None + + async def get_final_transcription(self) -> str: + return "" + + def get_and_clear_text(self) -> str: + return "" + + async def start_interim_transcription(self) -> None: + return None + + async def stop_interim_transcription(self) -> None: + return None diff --git a/engine/providers/asr/dashscope.py b/engine/providers/asr/dashscope.py new file mode 100644 index 0000000..bed4ede --- /dev/null +++ b/engine/providers/asr/dashscope.py @@ -0,0 +1,388 @@ +"""DashScope realtime streaming ASR service. + +Uses Qwen-ASR-Realtime via DashScope Python SDK. +""" + +from __future__ import annotations + +import asyncio +import base64 +import json +import os +import sys +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Optional + +from loguru import logger + +from providers.common.base import ASRResult, BaseASRService, ServiceState + +try: + import dashscope + from dashscope.audio.qwen_omni import MultiModality, OmniRealtimeCallback, OmniRealtimeConversation + + # Some SDK builds keep TranscriptionParams under qwen_omni.omni_realtime. 
+ try: + from dashscope.audio.qwen_omni import TranscriptionParams + except ImportError: + from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams + + DASHSCOPE_SDK_AVAILABLE = True + DASHSCOPE_IMPORT_ERROR = "" +except Exception as exc: + DASHSCOPE_IMPORT_ERROR = f"{type(exc).__name__}: {exc}" + dashscope = None # type: ignore[assignment] + MultiModality = None # type: ignore[assignment] + OmniRealtimeConversation = None # type: ignore[assignment] + TranscriptionParams = None # type: ignore[assignment] + DASHSCOPE_SDK_AVAILABLE = False + + class OmniRealtimeCallback: # type: ignore[no-redef] + """Fallback callback base when DashScope SDK is unavailable.""" + + pass + + +class _DashScopeASRCallback(OmniRealtimeCallback): + """Bridge DashScope SDK callbacks into asyncio loop-safe handlers.""" + + def __init__(self, owner: "DashScopeRealtimeASRService", loop: asyncio.AbstractEventLoop): + super().__init__() + self._owner = owner + self._loop = loop + + def _schedule(self, fn: Callable[[], None]) -> None: + try: + self._loop.call_soon_threadsafe(fn) + except RuntimeError: + return + + def on_open(self) -> None: + self._schedule(self._owner._on_ws_open) + + def on_close(self, code: int, msg: str) -> None: + self._schedule(lambda: self._owner._on_ws_close(code, msg)) + + def on_event(self, message: Any) -> None: + self._schedule(lambda: self._owner._on_ws_event(message)) + + def on_error(self, message: Any) -> None: + self._schedule(lambda: self._owner._on_ws_error(message)) + + +class DashScopeRealtimeASRService(BaseASRService): + """Realtime streaming ASR implementation for DashScope Qwen-ASR-Realtime.""" + + DEFAULT_WS_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + DEFAULT_MODEL = "qwen3-asr-flash-realtime" + DEFAULT_FINAL_TIMEOUT_MS = 800 + + def __init__( + self, + api_key: Optional[str] = None, + api_url: Optional[str] = None, + model: Optional[str] = None, + sample_rate: int = 16000, + language: str = "auto", + on_transcript: 
Optional[Callable[[str, bool], Awaitable[None]]] = None, + ) -> None: + super().__init__(sample_rate=sample_rate, language=language) + self.mode = "streaming" + self.api_key = ( + api_key + or os.getenv("DASHSCOPE_API_KEY") + or os.getenv("ASR_API_KEY") + ) + self.api_url = api_url or os.getenv("DASHSCOPE_ASR_API_URL") or self.DEFAULT_WS_URL + self.model = model or os.getenv("DASHSCOPE_ASR_MODEL") or self.DEFAULT_MODEL + self.on_transcript = on_transcript + + self._client: Optional[Any] = None + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._callback: Optional[_DashScopeASRCallback] = None + + self._running = False + self._session_ready = asyncio.Event() + self._transcript_queue: "asyncio.Queue[ASRResult]" = asyncio.Queue() + self._final_queue: "asyncio.Queue[str]" = asyncio.Queue() + + self._utterance_active = False + self._audio_sent_in_utterance = False + self._last_interim_text = "" + self._last_error: Optional[str] = None + + async def connect(self) -> None: + if not DASHSCOPE_SDK_AVAILABLE: + py_exec = sys.executable + hint = f"`{py_exec} -m pip install dashscope>=1.25.6`" + detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else "" + raise RuntimeError( + f"dashscope SDK unavailable in interpreter {py_exec}; install with {hint}{detail}" + ) + if not self.api_key: + raise ValueError("DashScope ASR API key not provided. 
Configure agent.asr.api_key in YAML.") + + self._loop = asyncio.get_running_loop() + self._callback = _DashScopeASRCallback(owner=self, loop=self._loop) + + if dashscope is not None: + dashscope.api_key = self.api_key + + self._client = OmniRealtimeConversation( # type: ignore[misc] + model=self.model, + url=self.api_url, + callback=self._callback, + ) + await asyncio.to_thread(self._client.connect) + await self._configure_session() + + self._running = True + self.state = ServiceState.CONNECTED + logger.info( + "DashScope realtime ASR connected: model={}, sample_rate={}, language={}", + self.model, + self.sample_rate, + self.language, + ) + + async def disconnect(self) -> None: + self._running = False + self._utterance_active = False + self._audio_sent_in_utterance = False + self._drain_queue(self._final_queue) + self._drain_queue(self._transcript_queue) + self._session_ready.clear() + + if self._client is not None: + close_fn = getattr(self._client, "close", None) + if callable(close_fn): + await asyncio.to_thread(close_fn) + self._client = None + self.state = ServiceState.DISCONNECTED + logger.info("DashScope realtime ASR disconnected") + + async def begin_utterance(self) -> None: + self.clear_utterance() + self._utterance_active = True + + async def send_audio(self, audio: bytes) -> None: + if not self._client: + raise RuntimeError("DashScope ASR service not connected") + if not audio: + return + + if not self._utterance_active: + # Allow graceful fallback if caller sends before begin_utterance. 
+ self._utterance_active = True + + audio_b64 = base64.b64encode(audio).decode("ascii") + append_fn = getattr(self._client, "append_audio", None) + if not callable(append_fn): + raise RuntimeError("DashScope ASR SDK missing append_audio method") + await asyncio.to_thread(append_fn, audio_b64) + self._audio_sent_in_utterance = True + + async def end_utterance(self) -> None: + if not self._client: + return + if not self._utterance_active or not self._audio_sent_in_utterance: + return + + commit_fn = getattr(self._client, "commit", None) + if not callable(commit_fn): + raise RuntimeError("DashScope ASR SDK missing commit method") + await asyncio.to_thread(commit_fn) + self._utterance_active = False + + async def wait_for_final_transcription(self, timeout_ms: int = DEFAULT_FINAL_TIMEOUT_MS) -> str: + if not self._audio_sent_in_utterance: + return "" + timeout_sec = max(0.05, float(timeout_ms) / 1000.0) + try: + text = await asyncio.wait_for(self._final_queue.get(), timeout=timeout_sec) + return str(text or "").strip() + except asyncio.TimeoutError: + logger.debug("DashScope ASR final timeout ({}ms), fallback to last interim", timeout_ms) + return str(self._last_interim_text or "").strip() + + def clear_utterance(self) -> None: + self._utterance_active = False + self._audio_sent_in_utterance = False + self._last_interim_text = "" + self._last_error = None + self._drain_queue(self._final_queue) + + async def receive_transcripts(self) -> AsyncIterator[ASRResult]: + while self._running: + try: + result = await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1) + yield result + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + async def _configure_session(self) -> None: + if not self._client: + raise RuntimeError("DashScope ASR client is not initialized") + + text_modality: Any = "text" + if MultiModality is not None and hasattr(MultiModality, "TEXT"): + text_modality = MultiModality.TEXT + + transcription_params: Optional[Any] = 
None + if TranscriptionParams is not None: + try: + lang = "zh" if self.language == "auto" else self.language + transcription_params = TranscriptionParams( + language=lang, + sample_rate=self.sample_rate, + input_audio_format="pcm", + ) + except Exception as exc: + logger.debug("DashScope ASR TranscriptionParams init failed: {}", exc) + transcription_params = None + + update_attempts = [ + { + "output_modalities": [text_modality], + "enable_turn_detection": False, + "enable_input_audio_transcription": True, + "transcription_params": transcription_params, + }, + { + "output_modalities": [text_modality], + "enable_turn_detection": False, + "enable_input_audio_transcription": True, + }, + { + "output_modalities": [text_modality], + }, + ] + + update_fn = getattr(self._client, "update_session", None) + if not callable(update_fn): + raise RuntimeError("DashScope ASR SDK missing update_session method") + + last_error: Optional[Exception] = None + for params in update_attempts: + if params.get("transcription_params") is None: + params = {k: v for k, v in params.items() if k != "transcription_params"} + try: + await asyncio.to_thread(update_fn, **params) + break + except TypeError as exc: + last_error = exc + continue + except Exception as exc: + last_error = exc + continue + else: + raise RuntimeError(f"DashScope ASR session.update failed: {last_error}") + + try: + await asyncio.wait_for(self._session_ready.wait(), timeout=6.0) + except asyncio.TimeoutError: + logger.debug("DashScope ASR session ready wait timeout; continuing") + + def _on_ws_open(self) -> None: + return None + + def _on_ws_close(self, code: int, msg: str) -> None: + self._last_error = f"DashScope ASR websocket closed: {code} {msg}" + logger.debug(self._last_error) + + def _on_ws_error(self, message: Any) -> None: + self._last_error = str(message) + logger.error("DashScope ASR error: {}", self._last_error) + + def _on_ws_event(self, message: Any) -> None: + payload = self._coerce_event(message) + 
event_type = str(payload.get("type") or "").strip() + if not event_type: + return + + if event_type in {"session.created", "session.updated"}: + self._session_ready.set() + return + if event_type == "error" or event_type.endswith(".failed"): + err_text = self._extract_text(payload, keys=("message", "error", "details")) + self._last_error = err_text or event_type + logger.error("DashScope ASR server event error: {}", self._last_error) + return + + if event_type == "conversation.item.input_audio_transcription.text": + stash_text = self._extract_text(payload, keys=("stash", "text", "transcript")) + self._emit_transcript(stash_text, is_final=False) + return + + if event_type == "conversation.item.input_audio_transcription.completed": + final_text = self._extract_text(payload, keys=("transcript", "text", "stash")) + self._emit_transcript(final_text, is_final=True) + return + + def _emit_transcript(self, text: str, *, is_final: bool) -> None: + normalized = str(text or "").strip() + if not normalized: + return + if not is_final and normalized == self._last_interim_text: + return + if not is_final: + self._last_interim_text = normalized + + if self._loop is None: + return + try: + asyncio.run_coroutine_threadsafe( + self._publish_transcript(normalized, is_final=is_final), + self._loop, + ) + except RuntimeError: + return + + async def _publish_transcript(self, text: str, *, is_final: bool) -> None: + await self._transcript_queue.put(ASRResult(text=text, is_final=is_final)) + if is_final: + await self._final_queue.put(text) + if self.on_transcript: + try: + await self.on_transcript(text, is_final) + except Exception as exc: + logger.warning("DashScope ASR transcript callback failed: {}", exc) + + @staticmethod + def _coerce_event(message: Any) -> Dict[str, Any]: + if isinstance(message, dict): + return message + if isinstance(message, str): + try: + parsed = json.loads(message) + if isinstance(parsed, dict): + return parsed + except json.JSONDecodeError: + return {"type": 
"raw", "text": message} + return {"type": "raw", "text": str(message)} + + def _extract_text(self, payload: Dict[str, Any], *, keys: tuple[str, ...]) -> str: + for key in keys: + value = payload.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + if isinstance(value, dict): + nested = self._extract_text(value, keys=keys) + if nested: + return nested + + for value in payload.values(): + if isinstance(value, dict): + nested = self._extract_text(value, keys=keys) + if nested: + return nested + return "" + + @staticmethod + def _drain_queue(queue: "asyncio.Queue[Any]") -> None: + while True: + try: + queue.get_nowait() + except asyncio.QueueEmpty: + break diff --git a/engine/providers/asr/openai_compatible.py b/engine/providers/asr/openai_compatible.py index 1a2083b..6d90e95 100644 --- a/engine/providers/asr/openai_compatible.py +++ b/engine/providers/asr/openai_compatible.py @@ -53,6 +53,7 @@ class OpenAICompatibleASRService(BaseASRService): model: str = "FunAudioLLM/SenseVoiceSmall", sample_rate: int = 16000, language: str = "auto", + enable_interim: bool = False, interim_interval_ms: int = 500, # How often to send interim results min_audio_for_interim_ms: int = 300, # Min audio before first interim on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None @@ -66,11 +67,13 @@ class OpenAICompatibleASRService(BaseASRService): model: ASR model name or alias sample_rate: Audio sample rate (16000 recommended) language: Language code (auto for automatic detection) + enable_interim: Whether to generate interim transcriptions in offline mode interim_interval_ms: How often to generate interim transcriptions min_audio_for_interim_ms: Minimum audio duration before first interim on_transcript: Callback for transcription results (text, is_final) """ super().__init__(sample_rate=sample_rate, language=language) + self.mode = "offline" if not AIOHTTP_AVAILABLE: raise RuntimeError("aiohttp is required for OpenAICompatibleASRService") @@ 
-79,6 +82,7 @@ class OpenAICompatibleASRService(BaseASRService): raw_api_url = api_url or os.getenv("ASR_API_URL") or self.API_URL self.api_url = self._resolve_transcriptions_endpoint(raw_api_url) self.model = self.MODELS.get(model.lower(), model) + self.enable_interim = bool(enable_interim) self.interim_interval_ms = interim_interval_ms self.min_audio_for_interim_ms = min_audio_for_interim_ms self.on_transcript = on_transcript @@ -180,6 +184,9 @@ class OpenAICompatibleASRService(BaseASRService): if not self._session: logger.warning("ASR session not connected") return None + + if not is_final and not self.enable_interim: + return None # Check minimum audio duration audio_duration_ms = len(self._audio_buffer) / (self.sample_rate * 2) * 1000 @@ -309,6 +316,9 @@ class OpenAICompatibleASRService(BaseASRService): This periodically transcribes buffered audio for real-time feedback to the user. """ + if not self.enable_interim: + return + if self._interim_task and not self._interim_task.done(): return diff --git a/engine/providers/asr/volcengine.py b/engine/providers/asr/volcengine.py new file mode 100644 index 0000000..1f7c18e --- /dev/null +++ b/engine/providers/asr/volcengine.py @@ -0,0 +1,666 @@ +"""Volcengine realtime ASR service. + +Supports both: +- Volcengine Edge Gateway realtime transcription websocket, and +- Volcengine BigASR Seed websocket at openspeech.bytedance.com/api/v3/sauc/bigmodel. 
+""" + +from __future__ import annotations + +import asyncio +import base64 +import gzip +import json +import os +import uuid +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Literal, Optional +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse + +import aiohttp +from loguru import logger + +from providers.common.base import ASRResult, BaseASRService, ServiceState + +VolcengineASRProtocol = Literal["gateway", "seed"] + + +class VolcengineRealtimeASRService(BaseASRService): + """Realtime streaming ASR backed by Volcengine websocket APIs.""" + + DEFAULT_WS_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel" + DEFAULT_GATEWAY_WS_URL = "wss://ai-gateway.vei.volces.com/v1/realtime" + DEFAULT_MODEL = "bigmodel" + DEFAULT_FINAL_TIMEOUT_MS = 1200 + DEFAULT_SEED_RESOURCE_ID = "volc.bigasr.sauc.duration" + _SEED_FRAME_MS = 100 + _SEED_PROTOCOL_VERSION = 0b0001 + _SEED_FULL_CLIENT_REQUEST = 0b0001 + _SEED_AUDIO_ONLY_REQUEST = 0b0010 + _SEED_FULL_SERVER_RESPONSE = 0b1001 + _SEED_SERVER_ACK = 0b1011 + _SEED_SERVER_ERROR_RESPONSE = 0b1111 + _SEED_NO_SEQUENCE = 0b0000 + _SEED_POS_SEQUENCE = 0b0001 + _SEED_NEG_WITH_SEQUENCE = 0b0011 + _SEED_NO_SERIALIZATION = 0b0000 + _SEED_JSON = 0b0001 + _SEED_NO_COMPRESSION = 0b0000 + _SEED_GZIP = 0b0001 + + def __init__( + self, + api_key: Optional[str] = None, + api_url: Optional[str] = None, + model: Optional[str] = None, + sample_rate: int = 16000, + language: str = "auto", + app_id: Optional[str] = None, + resource_id: Optional[str] = None, + uid: Optional[str] = None, + request_params: Optional[Dict[str, Any]] = None, + on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None, + ) -> None: + super().__init__(sample_rate=sample_rate, language=language) + self.mode = "streaming" + self.api_key = api_key or os.getenv("VOLCENGINE_ASR_API_KEY") or os.getenv("ASR_API_KEY") + self.model = str(model or os.getenv("VOLCENGINE_ASR_MODEL") or self.DEFAULT_MODEL).strip() + raw_api_url = api_url 
or os.getenv("VOLCENGINE_ASR_API_URL") or self.DEFAULT_WS_URL + self.protocol = self._detect_protocol(raw_api_url) + self.api_url = self._resolve_api_url(raw_api_url, self.model, self.protocol) + self.app_id = app_id or os.getenv("VOLCENGINE_ASR_APP_ID") or os.getenv("ASR_APP_ID") + self.resource_id = ( + resource_id + or os.getenv("VOLCENGINE_ASR_RESOURCE_ID") + or (self.DEFAULT_SEED_RESOURCE_ID if self.protocol == "seed" else None) + ) + self.uid = uid or os.getenv("VOLCENGINE_ASR_UID") + self.request_params = self._load_request_params(request_params) + self.on_transcript = on_transcript + + self._session: Optional[aiohttp.ClientSession] = None + self._ws: Optional[aiohttp.ClientWebSocketResponse] = None + self._reader_task: Optional[asyncio.Task[None]] = None + + self._running = False + self._session_ready = asyncio.Event() + self._transcript_queue: "asyncio.Queue[ASRResult]" = asyncio.Queue() + self._final_queue: "asyncio.Queue[str]" = asyncio.Queue() + + self._utterance_active = False + self._audio_sent_in_utterance = False + self._last_interim_text = "" + self._last_error: Optional[str] = None + + self._seed_audio_buffer = bytearray() + self._seed_sequence = 1 + self._seed_request_id: Optional[str] = None + self._seed_frame_bytes = max(2, int((self.sample_rate * self._SEED_FRAME_MS / 1000) * 2)) + + @classmethod + def _detect_protocol(cls, api_url: str) -> VolcengineASRProtocol: + parsed = urlparse(str(api_url or "").strip()) + host = parsed.netloc.lower() + path = parsed.path.lower() + if "openspeech.bytedance.com" in host and "/api/v3/sauc/bigmodel" in path: + return "seed" + return "gateway" + + @classmethod + def _resolve_api_url(cls, api_url: str, model: str, protocol: VolcengineASRProtocol) -> str: + raw = str(api_url or "").strip() + if not raw: + raw = cls.DEFAULT_WS_URL if protocol == "seed" else cls.DEFAULT_GATEWAY_WS_URL + if protocol != "gateway": + return raw + + parsed = urlparse(raw) + query = dict(parse_qsl(parsed.query, 
keep_blank_values=True)) + query.setdefault("model", model or cls.DEFAULT_MODEL) + return urlunparse(parsed._replace(query=urlencode(query))) + + @staticmethod + def _load_request_params(request_params: Optional[Dict[str, Any]]) -> Dict[str, Any]: + if isinstance(request_params, dict): + return dict(request_params) + raw = os.getenv("VOLCENGINE_ASR_REQUEST_PARAMS_JSON", "").strip() + if not raw: + return {} + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + logger.warning("Ignoring invalid VOLCENGINE_ASR_REQUEST_PARAMS_JSON") + return {} + if isinstance(parsed, dict): + return parsed + return {} + + async def connect(self) -> None: + if not self.api_key: + raise ValueError("Volcengine ASR API key not provided. Configure agent.asr.api_key in YAML.") + + timeout = aiohttp.ClientTimeout(total=None, sock_read=None, sock_connect=15) + self._session = aiohttp.ClientSession(timeout=timeout) + self._running = True + + if self.protocol == "gateway": + await self._connect_gateway() + logger.info( + "Volcengine gateway ASR connected: model={}, sample_rate={}, url={}", + self.model, + self.sample_rate, + self.api_url, + ) + else: + if not self.app_id: + raise ValueError("Volcengine ASR app_id not provided. 
Configure agent.asr.app_id in YAML.") + logger.info( + "Volcengine BigASR Seed ready: model={}, sample_rate={}, resource_id={}", + self.model, + self.sample_rate, + self.resource_id, + ) + + self.state = ServiceState.CONNECTED + + async def disconnect(self) -> None: + self._running = False + self._utterance_active = False + self._audio_sent_in_utterance = False + self._session_ready.clear() + self._seed_audio_buffer = bytearray() + self._drain_queue(self._final_queue) + self._drain_queue(self._transcript_queue) + + await self._close_ws() + + if self._session is not None: + await self._session.close() + self._session = None + + self.state = ServiceState.DISCONNECTED + logger.info("Volcengine ASR disconnected") + + async def begin_utterance(self) -> None: + self.clear_utterance() + if self.protocol == "seed": + await self._open_seed_stream() + self._utterance_active = True + + async def send_audio(self, audio: bytes) -> None: + if not audio: + return + + if self.protocol == "seed": + await self._send_seed_audio(audio) + return + + if not self._ws: + raise RuntimeError("Volcengine ASR websocket is not connected") + if not self._utterance_active: + self._utterance_active = True + + await self._ws.send_json( + { + "type": "input_audio_buffer.append", + "audio": base64.b64encode(audio).decode("ascii"), + } + ) + self._audio_sent_in_utterance = True + + async def end_utterance(self) -> None: + if not self._utterance_active: + return + + if self.protocol == "seed": + await self._end_seed_utterance() + return + + if not self._ws or not self._audio_sent_in_utterance: + return + await self._ws.send_json({"type": "input_audio_buffer.commit"}) + self._utterance_active = False + + async def wait_for_final_transcription(self, timeout_ms: int = DEFAULT_FINAL_TIMEOUT_MS) -> str: + if not self._audio_sent_in_utterance: + return "" + + timeout_sec = max(0.05, float(timeout_ms) / 1000.0) + try: + return str(await asyncio.wait_for(self._final_queue.get(), timeout=timeout_sec) or 
"").strip() + except asyncio.TimeoutError: + logger.debug("Volcengine ASR final timeout ({}ms), fallback to last interim", timeout_ms) + return str(self._last_interim_text or "").strip() + finally: + if self.protocol == "seed": + await self._close_ws() + + def clear_utterance(self) -> None: + self._utterance_active = False + self._audio_sent_in_utterance = False + self._last_interim_text = "" + self._last_error = None + self._seed_audio_buffer = bytearray() + self._seed_sequence = 1 + self._seed_request_id = None + self._drain_queue(self._final_queue) + + async def receive_transcripts(self) -> AsyncIterator[ASRResult]: + while self._running: + try: + yield await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1) + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + async def _connect_gateway(self) -> None: + assert self._session is not None + headers = {"Authorization": f"Bearer {self.api_key}"} + if self.resource_id: + headers["X-Api-Resource-Id"] = self.resource_id + + self._ws = await self._session.ws_connect(self.api_url, headers=headers, heartbeat=20) + self._reader_task = asyncio.create_task(self._reader_loop()) + await self._configure_gateway_session() + + async def _configure_gateway_session(self) -> None: + if not self._ws: + raise RuntimeError("Volcengine ASR websocket is not initialized") + + session_payload: Dict[str, Any] = { + "input_audio_format": "pcm", + "input_audio_codec": "raw", + "input_audio_sample_rate": self.sample_rate, + "input_audio_bits": 16, + "input_audio_channel": 1, + "result_type": 0, + "input_audio_transcription": { + "model": self.model, + }, + } + + await self._ws.send_json( + { + "type": "transcription_session.update", + "session": session_payload, + } + ) + + try: + await asyncio.wait_for(self._session_ready.wait(), timeout=8.0) + except asyncio.TimeoutError as exc: + raise RuntimeError("Volcengine ASR session update timeout") from exc + + async def _open_seed_stream(self) -> None: + if 
not self._session: + raise RuntimeError("Volcengine ASR session is not initialized") + + await self._close_ws() + self._seed_request_id = uuid.uuid4().hex + headers = self._build_seed_headers(self._seed_request_id) + self._ws = await self._session.ws_connect( + self.api_url, + headers=headers, + heartbeat=20, + max_msg_size=1_000_000_000, + ) + self._reader_task = asyncio.create_task(self._reader_loop()) + await self._ws.send_bytes(self._build_seed_start_request()) + + async def _send_seed_audio(self, audio: bytes) -> None: + if not self._utterance_active: + await self.begin_utterance() + if not self._ws: + raise RuntimeError("Volcengine BigASR websocket is not connected") + + self._seed_audio_buffer.extend(audio) + while len(self._seed_audio_buffer) >= self._seed_frame_bytes: + chunk = bytes(self._seed_audio_buffer[: self._seed_frame_bytes]) + del self._seed_audio_buffer[: self._seed_frame_bytes] + self._seed_sequence += 1 + await self._ws.send_bytes(self._build_seed_audio_request(chunk, sequence=self._seed_sequence)) + self._audio_sent_in_utterance = True + + async def _end_seed_utterance(self) -> None: + if not self._ws: + return + if not self._audio_sent_in_utterance and not self._seed_audio_buffer: + self._utterance_active = False + return + + final_chunk = bytes(self._seed_audio_buffer) + self._seed_audio_buffer = bytearray() + self._seed_sequence += 1 + await self._ws.send_bytes( + self._build_seed_audio_request(final_chunk, sequence=-self._seed_sequence, is_last=True) + ) + self._audio_sent_in_utterance = True + self._utterance_active = False + + async def _close_ws(self) -> None: + reader_task = self._reader_task + ws = self._ws + self._reader_task = None + self._ws = None + + if reader_task: + reader_task.cancel() + try: + await reader_task + except asyncio.CancelledError: + pass + + if ws is not None: + await ws.close() + + async def _reader_loop(self) -> None: + ws = self._ws + if ws is None: + return + + try: + async for msg in ws: + if msg.type == 
aiohttp.WSMsgType.TEXT: + if self.protocol == "gateway": + self._handle_gateway_event(msg.data) + else: + self._handle_seed_text(msg.data) + continue + if msg.type == aiohttp.WSMsgType.BINARY: + if self.protocol == "seed": + self._handle_seed_binary(msg.data) + continue + if msg.type == aiohttp.WSMsgType.ERROR: + self._last_error = str(ws.exception()) + logger.error("Volcengine ASR websocket error: {}", self._last_error) + break + if msg.type in {aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE}: + break + except asyncio.CancelledError: + raise + except Exception as exc: + self._last_error = str(exc) + logger.error("Volcengine ASR reader loop failed: {}", exc) + finally: + if self._ws is ws: + self._ws = None + + def _handle_gateway_event(self, message: str) -> None: + payload = self._coerce_event(message) + event_type = str(payload.get("type") or "").strip() + if not event_type: + return + + if event_type in {"transcription_session.created", "transcription_session.updated"}: + self._session_ready.set() + return + + if event_type == "error": + self._last_error = self._extract_text(payload, ("message", "error")) + logger.error("Volcengine ASR server error: {}", self._last_error or "unknown") + return + + if event_type.endswith(".failed"): + self._last_error = self._extract_text(payload, ("message", "error", "transcript")) + logger.error("Volcengine ASR failed event: {}", self._last_error or event_type) + return + + if event_type == "conversation.item.input_audio_transcription.result": + transcript = self._extract_text(payload, ("transcript", "result")) + self._emit_transcript_sync(transcript, is_final=False) + return + + if event_type == "conversation.item.input_audio_transcription.delta": + transcript = self._extract_text(payload, ("delta",)) + self._emit_transcript_sync(transcript, is_final=False) + return + + if event_type == "conversation.item.input_audio_transcription.completed": + transcript = self._extract_text(payload, ("transcript", "result")) + 
self._emit_transcript_sync(transcript, is_final=True) + + def _handle_seed_text(self, message: str) -> None: + payload = self._coerce_event(message) + if payload.get("type") == "error": + self._last_error = self._extract_text(payload, ("message", "error")) + logger.error("Volcengine BigASR error: {}", self._last_error or "unknown") + + def _handle_seed_binary(self, message: bytes) -> None: + payload = self._parse_seed_response(message) + if payload.get("code"): + self._last_error = self._extract_text(payload, ("payload_msg",)) + logger.error("Volcengine BigASR server error: {}", self._last_error or payload["code"]) + return + + body = payload.get("payload_msg") + if not isinstance(body, dict): + return + result = body.get("result") + if not isinstance(result, dict): + return + + text = str(result.get("text") or "").strip() + if not text: + return + + utterances = result.get("utterances") + if not isinstance(utterances, list) or not utterances: + return + first_utterance = utterances[0] if isinstance(utterances[0], dict) else {} + is_final = self._coerce_bool(first_utterance.get("definite")) is True + self._emit_transcript_sync(text, is_final=is_final) + + def _emit_transcript_sync(self, text: str, *, is_final: bool) -> None: + cleaned = str(text or "").strip() + if not cleaned: + return + + if not is_final: + self._last_interim_text = cleaned + else: + self._last_interim_text = "" + + result = ASRResult(text=cleaned, is_final=is_final) + try: + self._transcript_queue.put_nowait(result) + except asyncio.QueueFull: + logger.debug("Volcengine ASR transcript queue full; dropping transcript") + + if is_final: + try: + self._final_queue.put_nowait(cleaned) + except asyncio.QueueFull: + logger.debug("Volcengine ASR final queue full; dropping transcript") + + if self.on_transcript: + asyncio.create_task(self.on_transcript(cleaned, is_final)) + + def _build_seed_headers(self, request_id: str) -> Dict[str, str]: + if not self.app_id: + raise ValueError("Volcengine ASR app_id 
not provided. Configure agent.asr.app_id in YAML.") + if not self.api_key: + raise ValueError("Volcengine ASR api_key not provided. Configure agent.asr.api_key in YAML.") + + return { + "X-Api-App-Key": str(self.app_id), + "X-Api-Access-Key": str(self.api_key), + "X-Api-Resource-Id": str(self.resource_id or self.DEFAULT_SEED_RESOURCE_ID), + "X-Api-Request-Id": str(request_id), + } + + def _build_seed_start_payload(self) -> Dict[str, Any]: + user_payload: Dict[str, Any] = {"uid": str(self.uid or self._seed_request_id or self.app_id or uuid.uuid4().hex)} + audio_payload: Dict[str, Any] = { + "format": "pcm", + "rate": self.sample_rate, + "bits": 16, + "channels": 1, + "codec": "raw", + } + if self.language and self.language != "auto": + audio_payload["language"] = self.language + + request_payload: Dict[str, Any] = { + "model_name": self.model or self.DEFAULT_MODEL, + "enable_itn": False, + "enable_punc": True, + "enable_ddc": False, + "show_utterance": True, + "result_type": "single", + "vad_segment_duration": 3000, + "end_window_size": 500, + "force_to_speech_time": 1000, + } + + extra = dict(self.request_params) + user_payload.update(self._as_dict(extra.pop("user", None))) + audio_payload.update(self._as_dict(extra.pop("audio", None))) + request_payload.update(self._as_dict(extra.pop("request", None))) + request_payload.update(extra) + + return { + "user": user_payload, + "audio": audio_payload, + "request": request_payload, + } + + def _build_seed_start_request(self) -> bytes: + payload = gzip.compress(json.dumps(self._build_seed_start_payload()).encode("utf-8")) + frame = bytearray( + self._build_seed_header( + message_type=self._SEED_FULL_CLIENT_REQUEST, + message_type_specific_flags=self._SEED_POS_SEQUENCE, + ) + ) + frame.extend((1).to_bytes(4, "big", signed=True)) + frame.extend(len(payload).to_bytes(4, "big")) + frame.extend(payload) + return bytes(frame) + + def _build_seed_audio_request(self, chunk: bytes, *, sequence: int, is_last: bool = False) -> 
bytes: + payload = gzip.compress(chunk) + frame = bytearray( + self._build_seed_header( + message_type=self._SEED_AUDIO_ONLY_REQUEST, + message_type_specific_flags=self._SEED_NEG_WITH_SEQUENCE if is_last else self._SEED_POS_SEQUENCE, + ) + ) + frame.extend(int(sequence).to_bytes(4, "big", signed=True)) + frame.extend(len(payload).to_bytes(4, "big")) + frame.extend(payload) + return bytes(frame) + + @classmethod + def _build_seed_header( + cls, + *, + message_type: int, + message_type_specific_flags: int, + serial_method: int = _SEED_JSON, + compression_type: int = _SEED_GZIP, + reserved_data: int = 0x00, + ) -> bytes: + header = bytearray() + header.append((cls._SEED_PROTOCOL_VERSION << 4) | 0b0001) + header.append((message_type << 4) | message_type_specific_flags) + header.append((serial_method << 4) | compression_type) + header.append(reserved_data) + return bytes(header) + + @classmethod + def _parse_seed_response(cls, response: bytes) -> Dict[str, Any]: + header_size = response[0] & 0x0F + message_type = response[1] >> 4 + message_type_specific_flags = response[1] & 0x0F + serialization_method = response[2] >> 4 + compression_type = response[2] & 0x0F + payload = response[header_size * 4 :] + + result: Dict[str, Any] = {"is_last_package": False} + payload_message: Any = None + + if message_type_specific_flags & 0x01: + result["payload_sequence"] = int.from_bytes(payload[:4], "big", signed=True) + payload = payload[4:] + + if message_type_specific_flags & 0x02: + result["is_last_package"] = True + + if message_type == cls._SEED_FULL_SERVER_RESPONSE: + result["payload_size"] = int.from_bytes(payload[:4], "big", signed=True) + payload_message = payload[4:] + elif message_type == cls._SEED_SERVER_ACK: + result["seq"] = int.from_bytes(payload[:4], "big", signed=True) + if len(payload) >= 8: + result["payload_size"] = int.from_bytes(payload[4:8], "big", signed=False) + payload_message = payload[8:] + elif message_type == cls._SEED_SERVER_ERROR_RESPONSE: + 
result["code"] = int.from_bytes(payload[:4], "big", signed=False) + result["payload_size"] = int.from_bytes(payload[4:8], "big", signed=False) + payload_message = payload[8:] + + if payload_message is None: + return result + if compression_type == cls._SEED_GZIP: + payload_message = gzip.decompress(payload_message) + if serialization_method == cls._SEED_JSON: + payload_message = json.loads(payload_message.decode("utf-8")) + elif serialization_method != cls._SEED_NO_SERIALIZATION: + payload_message = payload_message.decode("utf-8") + + result["payload_msg"] = payload_message + return result + + @staticmethod + def _coerce_event(message: Any) -> Dict[str, Any]: + if isinstance(message, dict): + return message + if isinstance(message, str): + try: + loaded = json.loads(message) + if isinstance(loaded, dict): + return loaded + except json.JSONDecodeError: + return {"type": "raw", "message": message} + return {"type": "raw", "message": str(message)} + + @staticmethod + def _extract_text(payload: Dict[str, Any], keys: tuple[str, ...]) -> str: + for key in keys: + value = payload.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + if isinstance(value, dict): + for nested_key in ("message", "text", "transcript", "result", "delta"): + nested = value.get(nested_key) + if isinstance(nested, str) and nested.strip(): + return nested.strip() + return "" + + @staticmethod + def _coerce_bool(value: Any) -> Optional[bool]: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in {"1", "true", "yes", "on"}: + return True + if normalized in {"0", "false", "no", "off"}: + return False + return None + + @staticmethod + def _as_dict(value: Any) -> Dict[str, Any]: + if isinstance(value, dict): + return dict(value) + return {} + + @staticmethod + def _drain_queue(queue: "asyncio.Queue[Any]") -> None: + while True: + try: + 
queue.get_nowait() + except asyncio.QueueEmpty: + break diff --git a/engine/providers/factory/default.py b/engine/providers/factory/default.py index 4294d3c..de72af6 100644 --- a/engine/providers/factory/default.py +++ b/engine/providers/factory/default.py @@ -16,13 +16,18 @@ from runtime.ports import ( TTSServiceSpec, ) from providers.asr.buffered import BufferedASRService +from providers.asr.dashscope import DashScopeRealtimeASRService +from providers.asr.volcengine import VolcengineRealtimeASRService from providers.tts.dashscope import DashScopeTTSService from providers.llm.openai import MockLLMService, OpenAILLMService from providers.asr.openai_compatible import OpenAICompatibleASRService from providers.tts.openai_compatible import OpenAICompatibleTTSService from providers.tts.mock import MockTTSService +from providers.tts.volcengine import VolcengineTTSService _OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"} +_DASHSCOPE_PROVIDERS = {"dashscope"} +_VOLCENGINE_PROVIDERS = {"volcengine"} _SUPPORTED_LLM_PROVIDERS = {"openai", *_OPENAI_COMPATIBLE_PROVIDERS} @@ -31,8 +36,14 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory): _DEFAULT_DASHSCOPE_TTS_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" _DEFAULT_DASHSCOPE_TTS_MODEL = "qwen3-tts-flash-realtime" + _DEFAULT_DASHSCOPE_ASR_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + _DEFAULT_DASHSCOPE_ASR_MODEL = "qwen3-asr-flash-realtime" _DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL = "FunAudioLLM/CosyVoice2-0.5B" _DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall" + _DEFAULT_VOLCENGINE_TTS_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional" + _DEFAULT_VOLCENGINE_TTS_RESOURCE_ID = "seed-tts-2.0" + _DEFAULT_VOLCENGINE_ASR_REALTIME_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel" + _DEFAULT_VOLCENGINE_ASR_MODEL = "bigmodel" @staticmethod def _normalize_provider(provider: Any) -> str: @@ -77,6 +88,19 @@ 
class DefaultRealtimeServiceFactory(RealtimeServiceFactory): speed=spec.speed, ) + if provider in _VOLCENGINE_PROVIDERS and spec.api_key: + return VolcengineTTSService( + api_key=spec.api_key, + api_url=spec.api_url or self._DEFAULT_VOLCENGINE_TTS_URL, + voice=spec.voice, + model=spec.model, + app_id=spec.app_id, + resource_id=spec.resource_id or self._DEFAULT_VOLCENGINE_TTS_RESOURCE_ID, + uid=spec.uid, + sample_rate=spec.sample_rate, + speed=spec.speed, + ) + if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key: return OpenAICompatibleTTSService( api_key=spec.api_key, @@ -96,6 +120,30 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory): def create_asr_service(self, spec: ASRServiceSpec) -> ASRPort: provider = self._normalize_provider(spec.provider) + if provider in _DASHSCOPE_PROVIDERS and spec.api_key: + return DashScopeRealtimeASRService( + api_key=spec.api_key, + api_url=spec.api_url or self._DEFAULT_DASHSCOPE_ASR_REALTIME_URL, + model=spec.model or self._DEFAULT_DASHSCOPE_ASR_MODEL, + sample_rate=spec.sample_rate, + language=spec.language, + on_transcript=spec.on_transcript, + ) + + if provider in _VOLCENGINE_PROVIDERS and spec.api_key: + return VolcengineRealtimeASRService( + api_key=spec.api_key, + api_url=spec.api_url or self._DEFAULT_VOLCENGINE_ASR_REALTIME_URL, + model=spec.model or self._DEFAULT_VOLCENGINE_ASR_MODEL, + sample_rate=spec.sample_rate, + language=spec.language, + app_id=spec.app_id, + resource_id=spec.resource_id, + uid=spec.uid, + request_params=spec.request_params, + on_transcript=spec.on_transcript, + ) + if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key: return OpenAICompatibleASRService( api_key=spec.api_key, @@ -103,6 +151,7 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory): model=spec.model or self._DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL, sample_rate=spec.sample_rate, language=spec.language, + enable_interim=spec.enable_interim, interim_interval_ms=spec.interim_interval_ms, 
min_audio_for_interim_ms=spec.min_audio_for_interim_ms, on_transcript=spec.on_transcript, diff --git a/engine/providers/tts/__init__.py b/engine/providers/tts/__init__.py index 531ecfa..b2b237a 100644 --- a/engine/providers/tts/__init__.py +++ b/engine/providers/tts/__init__.py @@ -1 +1,5 @@ """TTS providers.""" + +from providers.tts.volcengine import VolcengineTTSService + +__all__ = ["VolcengineTTSService"] diff --git a/engine/providers/tts/volcengine.py b/engine/providers/tts/volcengine.py new file mode 100644 index 0000000..d7502a1 --- /dev/null +++ b/engine/providers/tts/volcengine.py @@ -0,0 +1,219 @@ +"""Volcengine TTS service. + +Uses Volcengine's unidirectional HTTP streaming TTS API and adapts streamed +base64 audio chunks into engine-native ``TTSChunk`` events. +""" + +from __future__ import annotations + +import asyncio +import base64 +import codecs +import json +import os +import uuid +from typing import Any, AsyncIterator, Optional + +import aiohttp +from loguru import logger + +from providers.common.base import BaseTTSService, ServiceState, TTSChunk + + +class VolcengineTTSService(BaseTTSService): + """Streaming TTS adapter for Volcengine's HTTP v3 API.""" + + DEFAULT_API_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional" + DEFAULT_RESOURCE_ID = "seed-tts-2.0" + + def __init__( + self, + api_key: Optional[str] = None, + api_url: Optional[str] = None, + voice: str = "zh_female_shuangkuaisisi_moon_bigtts", + model: Optional[str] = None, + app_id: Optional[str] = None, + resource_id: Optional[str] = None, + uid: Optional[str] = None, + sample_rate: int = 16000, + speed: float = 1.0, + ) -> None: + super().__init__(voice=voice, sample_rate=sample_rate, speed=speed) + self.api_key = api_key or os.getenv("VOLCENGINE_TTS_API_KEY") or os.getenv("TTS_API_KEY") + self.api_url = api_url or os.getenv("VOLCENGINE_TTS_API_URL") or self.DEFAULT_API_URL + self.model = str(model or os.getenv("VOLCENGINE_TTS_MODEL") or "").strip() or None + self.app_id 
= app_id or os.getenv("VOLCENGINE_TTS_APP_ID") or os.getenv("TTS_APP_ID") + self.resource_id = resource_id or os.getenv("VOLCENGINE_TTS_RESOURCE_ID") or self.DEFAULT_RESOURCE_ID + self.uid = uid or os.getenv("VOLCENGINE_TTS_UID") + + self._session: Optional[aiohttp.ClientSession] = None + self._cancel_event = asyncio.Event() + self._synthesis_lock = asyncio.Lock() + self._pending_audio: list[bytes] = [] + + async def connect(self) -> None: + if not self.api_key: + raise ValueError("Volcengine TTS API key not provided. Configure agent.tts.api_key in YAML.") + if not self.app_id: + raise ValueError("Volcengine TTS app_id not provided. Configure agent.tts.app_id in YAML.") + + timeout = aiohttp.ClientTimeout(total=None, sock_read=None, sock_connect=15) + self._session = aiohttp.ClientSession(timeout=timeout) + self.state = ServiceState.CONNECTED + logger.info( + "Volcengine TTS service ready: speaker={}, sample_rate={}, resource_id={}", + self.voice, + self.sample_rate, + self.resource_id, + ) + + async def disconnect(self) -> None: + self._cancel_event.set() + if self._session is not None: + await self._session.close() + self._session = None + self.state = ServiceState.DISCONNECTED + logger.info("Volcengine TTS service disconnected") + + async def synthesize(self, text: str) -> bytes: + audio = b"" + async for chunk in self.synthesize_stream(text): + audio += chunk.audio + return audio + + async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]: + if not self._session: + raise RuntimeError("Volcengine TTS service not connected") + if not text.strip(): + return + + async with self._synthesis_lock: + self._cancel_event.clear() + + headers = { + "Content-Type": "application/json", + "X-Api-App-Key": str(self.app_id), + "X-Api-Access-Key": str(self.api_key), + "X-Api-Resource-Id": str(self.resource_id), + "X-Api-Request-Id": str(uuid.uuid4()), + } + payload = { + "user": { + "uid": str(self.uid or self.app_id), + }, + "req_params": { + "text": text, + 
"speaker": self.voice, + "audio_params": { + "format": "pcm", + "sample_rate": self.sample_rate, + "speech_rate": self._speech_rate_percent(self.speed), + }, + }, + } + if self.model: + payload["req_params"]["model"] = self.model + + chunk_size = max(1, self.sample_rate * 2 // 10) + audio_buffer = b"" + pending_chunk: Optional[bytes] = None + + try: + async with self._session.post(self.api_url, headers=headers, json=payload) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError(f"Volcengine TTS error {response.status}: {error_text}") + + async for audio_bytes in self._iter_audio_bytes(response): + if self._cancel_event.is_set(): + logger.info("Volcengine TTS synthesis cancelled") + return + + audio_buffer += audio_bytes + while len(audio_buffer) >= chunk_size: + emitted = audio_buffer[:chunk_size] + audio_buffer = audio_buffer[chunk_size:] + if pending_chunk is not None: + yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False) + pending_chunk = emitted + + if self._cancel_event.is_set(): + return + + if pending_chunk is not None: + if audio_buffer: + yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False) + pending_chunk = None + else: + yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=True) + pending_chunk = None + + if audio_buffer: + yield TTSChunk(audio=audio_buffer, sample_rate=self.sample_rate, is_final=True) + + except asyncio.CancelledError: + logger.info("Volcengine TTS synthesis cancelled via asyncio") + raise + except Exception as exc: + logger.error("Volcengine TTS synthesis error: {}", exc) + raise + + async def cancel(self) -> None: + self._cancel_event.set() + + async def _iter_audio_bytes(self, response: aiohttp.ClientResponse) -> AsyncIterator[bytes]: + decoder = json.JSONDecoder() + utf8_decoder = codecs.getincrementaldecoder("utf-8")() + text_buffer = "" + self._pending_audio.clear() + + async for raw_chunk in 
response.content.iter_any(): + text_buffer += utf8_decoder.decode(raw_chunk) + text_buffer = self._yield_audio_payloads(decoder, text_buffer) + while self._pending_audio: + yield self._pending_audio.pop(0) + + text_buffer += utf8_decoder.decode(b"", final=True) + text_buffer = self._yield_audio_payloads(decoder, text_buffer) + while self._pending_audio: + yield self._pending_audio.pop(0) + + def _yield_audio_payloads(self, decoder: json.JSONDecoder, text_buffer: str) -> str: + while True: + stripped = text_buffer.lstrip() + if not stripped: + return "" + if len(stripped) != len(text_buffer): + text_buffer = stripped + + try: + payload, idx = decoder.raw_decode(text_buffer) + except json.JSONDecodeError: + return text_buffer + + text_buffer = text_buffer[idx:] + audio = self._extract_audio_bytes(payload) + if audio: + self._pending_audio.append(audio) + + def _extract_audio_bytes(self, payload: Any) -> bytes: + if not isinstance(payload, dict): + return b"" + + code = payload.get("code") + if code not in (None, 0, 20000000): + message = str(payload.get("message") or "unknown error") + raise RuntimeError(f"Volcengine TTS stream error {code}: {message}") + + encoded = payload.get("data") + if isinstance(encoded, str) and encoded.strip(): + try: + return base64.b64decode(encoded) + except Exception as exc: + logger.warning("Failed to decode Volcengine TTS audio chunk: {}", exc) + return b"" + + @staticmethod + def _speech_rate_percent(speed: float) -> int: + clamped = max(0.5, min(2.0, float(speed or 1.0))) + return int(round((clamped - 1.0) * 100)) diff --git a/engine/runtime/pipeline/duplex.py b/engine/runtime/pipeline/duplex.py index aacd0c7..dcf198f 100644 --- a/engine/runtime/pipeline/duplex.py +++ b/engine/runtime/pipeline/duplex.py @@ -30,11 +30,14 @@ from providers.factory.default import DefaultRealtimeServiceFactory from runtime.conversation import ConversationManager, ConversationState from runtime.events import get_event_bus from runtime.ports import ( + 
ASRMode, ASRPort, ASRServiceSpec, LLMPort, LLMServiceSpec, + OfflineASRPort, RealtimeServiceFactory, + StreamingASRPort, TTSPort, TTSServiceSpec, ) @@ -77,6 +80,7 @@ class DuplexPipeline: _ASR_DELTA_THROTTLE_MS = 500 _LLM_DELTA_THROTTLE_MS = 80 _ASR_CAPTURE_MAX_MS = 15000 + _ASR_STREAM_FINAL_TIMEOUT_MS = 800 _OPENER_PRE_ROLL_MS = 180 _DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = { "current_time": { @@ -317,6 +321,10 @@ class DuplexPipeline: self.llm_service = llm_service self.tts_service = tts_service self.asr_service = asr_service # Will be initialized in start() + self._asr_mode: ASRMode = self._resolve_asr_mode( + settings.asr_provider, + getattr(asr_service, "mode", None), + ) self._service_factory = service_factory or DefaultRealtimeServiceFactory() self._knowledge_searcher = knowledge_searcher self._tool_resource_resolver = tool_resource_resolver @@ -324,6 +332,7 @@ class DuplexPipeline: # Track last sent transcript to avoid duplicates self._last_sent_transcript = "" + self._latest_asr_interim_text = "" self._pending_transcript_delta: str = "" self._last_transcript_delta_emit_ms: float = 0.0 @@ -588,7 +597,9 @@ class DuplexPipeline: }, "asr": { "provider": asr_provider, + "mode": self._resolve_asr_mode(asr_provider, self._runtime_asr.get("mode")), "model": str(self._runtime_asr.get("model") or settings.asr_model or ""), + "enableInterim": self._asr_interim_enabled(), "interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms), "minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms), }, @@ -782,11 +793,44 @@ class DuplexPipeline: return False return None + @staticmethod + def _coerce_json_object(value: Any) -> Optional[Dict[str, Any]]: + if isinstance(value, dict): + return dict(value) + if isinstance(value, str): + raw = value.strip() + if not raw: + return None + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + logger.warning("Ignoring invalid JSON object config: 
{}", raw[:120]) + return None + if isinstance(parsed, dict): + return parsed + return None + @staticmethod def _is_dashscope_tts_provider(provider: Any) -> bool: normalized = str(provider or "").strip().lower() return normalized == "dashscope" + @staticmethod + def _resolve_asr_mode(provider: Any, raw_mode: Any = None) -> ASRMode: + normalized_mode = str(raw_mode or "").strip().lower() + if normalized_mode in {"offline", "streaming"}: + return normalized_mode # type: ignore[return-value] + normalized_provider = str(provider or "").strip().lower() + if normalized_provider in {"dashscope", "volcengine"}: + return "streaming" + return "offline" + + def _offline_asr(self) -> OfflineASRPort: + return self.asr_service # type: ignore[return-value] + + def _streaming_asr(self) -> StreamingASRPort: + return self.asr_service # type: ignore[return-value] + @staticmethod def _default_llm_base_url(provider: Any) -> Optional[str]: normalized = str(provider or "").strip().lower() @@ -839,6 +883,20 @@ class DuplexPipeline: return self._runtime_barge_in_min_duration_ms return self._barge_in_min_duration_ms + def _asr_interim_enabled(self) -> bool: + current_mode = self._asr_mode + if not self.asr_service: + current_mode = self._resolve_asr_mode( + self._runtime_asr.get("provider") or settings.asr_provider, + self._runtime_asr.get("mode"), + ) + if current_mode != "offline": + return True + enabled = self._coerce_bool(self._runtime_asr.get("enableInterim")) + if enabled is not None: + return enabled + return bool(settings.asr_enable_interim) + def _barge_in_silence_tolerance_frames(self) -> int: """Convert silence tolerance from ms to frame count using current chunk size.""" chunk_ms = max(1, settings.chunk_size_ms) @@ -922,6 +980,10 @@ class DuplexPipeline: tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url tts_voice = self._runtime_tts.get("voice") or settings.tts_voice tts_model = self._runtime_tts.get("model") or settings.tts_model + tts_app_id = 
self._runtime_tts.get("appId") or settings.tts_app_id + tts_resource_id = self._runtime_tts.get("resourceId") or settings.tts_resource_id + tts_cluster = self._runtime_tts.get("cluster") or settings.tts_cluster + tts_uid = self._runtime_tts.get("uid") or settings.tts_uid tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed) tts_mode = self._resolved_dashscope_tts_mode() runtime_mode = str(self._runtime_tts.get("mode") or "").strip() @@ -937,6 +999,10 @@ class DuplexPipeline: api_url=str(tts_api_url).strip() if tts_api_url else None, voice=str(tts_voice), model=str(tts_model).strip() if tts_model else None, + app_id=str(tts_app_id).strip() if tts_app_id else None, + resource_id=str(tts_resource_id).strip() if tts_resource_id else None, + cluster=str(tts_cluster).strip() if tts_cluster else None, + uid=str(tts_uid).strip() if tts_uid else None, sample_rate=settings.sample_rate, speed=tts_speed, mode=str(tts_mode), @@ -965,26 +1031,48 @@ class DuplexPipeline: asr_api_key = self._runtime_asr.get("apiKey") asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url asr_model = self._runtime_asr.get("model") or settings.asr_model + asr_app_id = self._runtime_asr.get("appId") or settings.asr_app_id + asr_resource_id = self._runtime_asr.get("resourceId") or settings.asr_resource_id + asr_cluster = self._runtime_asr.get("cluster") or settings.asr_cluster + asr_uid = self._runtime_asr.get("uid") or settings.asr_uid + asr_request_params = self._coerce_json_object(self._runtime_asr.get("requestParams")) + if asr_request_params is None: + asr_request_params = self._coerce_json_object(settings.asr_request_params_json) + asr_enable_interim = self._coerce_bool(self._runtime_asr.get("enableInterim")) + if asr_enable_interim is None: + asr_enable_interim = bool(settings.asr_enable_interim) asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms) asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") 
or settings.asr_min_audio_ms) + asr_mode = self._resolve_asr_mode(asr_provider, self._runtime_asr.get("mode")) self.asr_service = self._service_factory.create_asr_service( ASRServiceSpec( provider=asr_provider, sample_rate=settings.sample_rate, + mode=asr_mode, language="auto", api_key=str(asr_api_key).strip() if asr_api_key else None, api_url=str(asr_api_url).strip() if asr_api_url else None, model=str(asr_model).strip() if asr_model else None, + app_id=str(asr_app_id).strip() if asr_app_id else None, + resource_id=str(asr_resource_id).strip() if asr_resource_id else None, + cluster=str(asr_cluster).strip() if asr_cluster else None, + uid=str(asr_uid).strip() if asr_uid else None, + request_params=asr_request_params, + enable_interim=asr_enable_interim, interim_interval_ms=asr_interim_interval, min_audio_for_interim_ms=asr_min_audio_ms, on_transcript=self._on_transcript_callback, ) ) + self._asr_mode = self._resolve_asr_mode( + self._runtime_asr.get("provider") or settings.asr_provider, + getattr(self.asr_service, "mode", self._runtime_asr.get("mode")), + ) await self.asr_service.connect() - logger.info("DuplexPipeline services connected") + logger.info("DuplexPipeline services connected (asr_mode={})", self._asr_mode) if not self._outbound_task or self._outbound_task.done(): self._outbound_task = asyncio.create_task(self._outbound_loop()) @@ -1449,6 +1537,9 @@ class DuplexPipeline: text: Transcribed text is_final: Whether this is the final transcription """ + if not is_final and not self._asr_interim_enabled(): + return + # Avoid sending duplicate transcripts if text == self._last_sent_transcript and not is_final: return @@ -1457,6 +1548,7 @@ class DuplexPipeline: self._last_sent_transcript = text if is_final: + self._latest_asr_interim_text = "" self._pending_transcript_delta = "" self._last_transcript_delta_emit_ms = 0.0 await self._send_event( @@ -1472,6 +1564,7 @@ class DuplexPipeline: logger.debug(f"Sent transcript (final): {text[:50]}...") return + 
self._latest_asr_interim_text = text self._pending_transcript_delta = text should_emit = ( self._last_transcript_delta_emit_ms <= 0.0 @@ -1495,14 +1588,16 @@ class DuplexPipeline: await self.conversation.start_user_turn() self._audio_buffer = b"" self._last_sent_transcript = "" + self._latest_asr_interim_text = "" self.eou_detector.reset() self._asr_capture_active = False self._asr_capture_started_ms = 0.0 self._pending_speech_audio = b"" - # Clear ASR buffer. Interim starts only after ASR capture is activated. - if hasattr(self.asr_service, 'clear_buffer'): - self.asr_service.clear_buffer() + if self._asr_mode == "streaming": + self._streaming_asr().clear_utterance() + else: + self._offline_asr().clear_buffer() logger.debug("User speech started") @@ -1511,8 +1606,11 @@ class DuplexPipeline: if self._asr_capture_active: return - if hasattr(self.asr_service, 'start_interim_transcription'): - await self.asr_service.start_interim_transcription() + if self._asr_mode == "streaming": + await self._streaming_asr().begin_utterance() + else: + if self._asr_interim_enabled(): + await self._offline_asr().start_interim_transcription() # Prime ASR with a short pre-speech context window so the utterance # start isn't lost while waiting for VAD to transition to Speech. @@ -1545,24 +1643,22 @@ class DuplexPipeline: self._pending_speech_audio = b"" return - # Add a tiny trailing silence tail to stabilize final-token decoding. 
- if self._asr_final_tail_bytes > 0: - final_tail = b"\x00" * self._asr_final_tail_bytes - await self.asr_service.send_audio(final_tail) - - # Stop interim transcriptions - if hasattr(self.asr_service, 'stop_interim_transcription'): - await self.asr_service.stop_interim_transcription() - - # Get final transcription from ASR service user_text = "" - - if hasattr(self.asr_service, 'get_final_transcription'): - # SiliconFlow ASR - get final transcription - user_text = await self.asr_service.get_final_transcription() - elif hasattr(self.asr_service, 'get_and_clear_text'): - # Buffered ASR - get accumulated text - user_text = self.asr_service.get_and_clear_text() + if self._asr_mode == "streaming": + streaming_asr = self._streaming_asr() + await streaming_asr.end_utterance() + user_text = await streaming_asr.wait_for_final_transcription( + timeout_ms=self._ASR_STREAM_FINAL_TIMEOUT_MS + ) + if not user_text.strip(): + user_text = self._latest_asr_interim_text + else: + # Add a tiny trailing silence tail to stabilize final-token decoding. 
+ if self._asr_final_tail_bytes > 0: + final_tail = b"\x00" * self._asr_final_tail_bytes + await self.asr_service.send_audio(final_tail) + await self._offline_asr().stop_interim_transcription() + user_text = await self._offline_asr().get_final_transcription() # Skip if no meaningful text if not user_text or not user_text.strip(): @@ -1570,6 +1666,7 @@ class DuplexPipeline: # Reset for next utterance self._audio_buffer = b"" self._last_sent_transcript = "" + self._latest_asr_interim_text = "" self._asr_capture_active = False self._asr_capture_started_ms = 0.0 self._pending_speech_audio = b"" @@ -1594,6 +1691,7 @@ class DuplexPipeline: # Clear buffers self._audio_buffer = b"" self._last_sent_transcript = "" + self._latest_asr_interim_text = "" self._pending_transcript_delta = "" self._last_transcript_delta_emit_ms = 0.0 self._asr_capture_active = False diff --git a/engine/runtime/ports/__init__.py b/engine/runtime/ports/__init__.py index a7cbce3..26319b2 100644 --- a/engine/runtime/ports/__init__.py +++ b/engine/runtime/ports/__init__.py @@ -1,6 +1,12 @@ """Port interfaces for runtime integration boundaries.""" -from runtime.ports.asr import ASRBufferControl, ASRInterimControl, ASRPort, ASRServiceSpec +from runtime.ports.asr import ( + ASRMode, + ASRPort, + ASRServiceSpec, + OfflineASRPort, + StreamingASRPort, +) from runtime.ports.control_plane import ( AssistantRuntimeConfigProvider, ControlPlaneGateway, @@ -13,10 +19,11 @@ from runtime.ports.service_factory import RealtimeServiceFactory from runtime.ports.tts import TTSPort, TTSServiceSpec __all__ = [ + "ASRMode", "ASRPort", "ASRServiceSpec", - "ASRInterimControl", - "ASRBufferControl", + "OfflineASRPort", + "StreamingASRPort", "AssistantRuntimeConfigProvider", "ControlPlaneGateway", "ConversationHistoryStore", diff --git a/engine/runtime/ports/asr.py b/engine/runtime/ports/asr.py index 8621ed0..b1310b1 100644 --- a/engine/runtime/ports/asr.py +++ b/engine/runtime/ports/asr.py @@ -3,11 +3,12 @@ from __future__ 
import annotations from dataclasses import dataclass -from typing import AsyncIterator, Awaitable, Callable, Optional, Protocol +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Literal, Optional, Protocol from providers.common.base import ASRResult TranscriptCallback = Callable[[str, bool], Awaitable[None]] +ASRMode = Literal["offline", "streaming"] @dataclass(frozen=True) @@ -16,10 +17,17 @@ class ASRServiceSpec: provider: str sample_rate: int + mode: Optional[ASRMode] = None language: str = "auto" api_key: Optional[str] = None api_url: Optional[str] = None model: Optional[str] = None + app_id: Optional[str] = None + resource_id: Optional[str] = None + cluster: Optional[str] = None + uid: Optional[str] = None + request_params: Optional[Dict[str, Any]] = None + enable_interim: bool = False interim_interval_ms: int = 500 min_audio_for_interim_ms: int = 300 on_transcript: Optional[TranscriptCallback] = None @@ -28,6 +36,8 @@ class ASRServiceSpec: class ASRPort(Protocol): """Port for speech recognition providers.""" + mode: ASRMode + async def connect(self) -> None: """Establish connection to ASR provider.""" @@ -41,18 +51,16 @@ class ASRPort(Protocol): """Stream partial/final recognition results.""" -class ASRInterimControl(Protocol): - """Optional extension for explicit interim transcription control.""" +class OfflineASRPort(ASRPort, Protocol): + """Port for offline/buffered ASR providers.""" + + mode: Literal["offline"] async def start_interim_transcription(self) -> None: - """Start interim transcription loop if supported.""" + """Start interim transcription loop.""" async def stop_interim_transcription(self) -> None: - """Stop interim transcription loop if supported.""" - - -class ASRBufferControl(Protocol): - """Optional extension for explicit ASR buffer lifecycle control.""" + """Stop interim transcription loop.""" def clear_buffer(self) -> None: """Clear provider-side ASR buffer.""" @@ -62,3 +70,21 @@ class ASRBufferControl(Protocol): def 
get_and_clear_text(self) -> str: """Return buffered text and clear internal state.""" + + +class StreamingASRPort(ASRPort, Protocol): + """Port for streaming ASR providers.""" + + mode: Literal["streaming"] + + async def begin_utterance(self) -> None: + """Start a new utterance stream.""" + + async def end_utterance(self) -> None: + """Signal end of current utterance stream.""" + + async def wait_for_final_transcription(self, timeout_ms: int = 800) -> str: + """Wait for final transcript after utterance end.""" + + def clear_utterance(self) -> None: + """Reset utterance-local state.""" diff --git a/engine/runtime/ports/tts.py b/engine/runtime/ports/tts.py index 523dc3c..a98e17d 100644 --- a/engine/runtime/ports/tts.py +++ b/engine/runtime/ports/tts.py @@ -19,6 +19,10 @@ class TTSServiceSpec: api_key: Optional[str] = None api_url: Optional[str] = None model: Optional[str] = None + app_id: Optional[str] = None + resource_id: Optional[str] = None + cluster: Optional[str] = None + uid: Optional[str] = None mode: str = "commit" diff --git a/engine/tests/test_asr_factory_modes.py b/engine/tests/test_asr_factory_modes.py new file mode 100644 index 0000000..5cd78f8 --- /dev/null +++ b/engine/tests/test_asr_factory_modes.py @@ -0,0 +1,71 @@ +from providers.asr.buffered import BufferedASRService +from providers.asr.dashscope import DashScopeRealtimeASRService +from providers.asr.openai_compatible import OpenAICompatibleASRService +from providers.asr.volcengine import VolcengineRealtimeASRService +from providers.factory.default import DefaultRealtimeServiceFactory +from runtime.ports import ASRServiceSpec + + +def test_create_asr_service_dashscope_returns_streaming_provider(): + factory = DefaultRealtimeServiceFactory() + service = factory.create_asr_service( + ASRServiceSpec( + provider="dashscope", + mode="streaming", + sample_rate=16000, + api_key="test-key", + model="qwen3-asr-flash-realtime", + ) + ) + assert isinstance(service, DashScopeRealtimeASRService) + assert 
service.mode == "streaming" + + +def test_create_asr_service_openai_compatible_returns_offline_provider(): + factory = DefaultRealtimeServiceFactory() + service = factory.create_asr_service( + ASRServiceSpec( + provider="openai_compatible", + sample_rate=16000, + api_key="test-key", + model="FunAudioLLM/SenseVoiceSmall", + ) + ) + assert isinstance(service, OpenAICompatibleASRService) + assert service.mode == "offline" + assert service.enable_interim is False + + +def test_create_asr_service_volcengine_returns_streaming_provider(): + factory = DefaultRealtimeServiceFactory() + service = factory.create_asr_service( + ASRServiceSpec( + provider="volcengine", + mode="streaming", + sample_rate=16000, + api_key="test-key", + api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", + model="bigmodel", + app_id="app-1", + uid="caller-1", + request_params={"end_window_size": 800}, + ) + ) + assert isinstance(service, VolcengineRealtimeASRService) + assert service.mode == "streaming" + assert service.protocol == "seed" + assert service.app_id == "app-1" + assert service.uid == "caller-1" + assert service.request_params["end_window_size"] == 800 + + +def test_create_asr_service_fallback_buffered_for_unsupported_provider(): + factory = DefaultRealtimeServiceFactory() + service = factory.create_asr_service( + ASRServiceSpec( + provider="unknown_provider", + sample_rate=16000, + ) + ) + assert isinstance(service, BufferedASRService) + assert service.mode == "offline" diff --git a/engine/tests/test_backend_adapters.py b/engine/tests/test_backend_adapters.py index 70f569e..9cce105 100644 --- a/engine/tests/test_backend_adapters.py +++ b/engine/tests/test_backend_adapters.py @@ -227,6 +227,62 @@ async def test_with_backend_url_uses_backend_for_assistant_config(monkeypatch, t assert payload["assistant"]["systemPrompt"] == "backend prompt" +def test_translate_agent_schema_maps_volcengine_fields(): + payload = { + "agent": { + "tts": { + "provider": "volcengine", + "api_key": 
"tts-key", + "api_url": "https://openspeech.bytedance.com/api/v3/tts/unidirectional", + "app_id": "app-123", + "resource_id": "seed-tts-2.0", + "uid": "caller-1", + "voice": "zh_female_shuangkuaisisi_moon_bigtts", + "speed": 1.1, + }, + "asr": { + "provider": "volcengine", + "api_key": "asr-key", + "api_url": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", + "model": "bigmodel", + "app_id": "app-123", + "resource_id": "volc.bigasr.sauc.duration", + "uid": "caller-1", + "request_params": { + "end_window_size": 800, + "force_to_speech_time": 1000, + }, + }, + } + } + + translated = LocalYamlAssistantConfigAdapter._translate_agent_schema("assistant_demo", payload) + assert translated is not None + assert translated["services"]["tts"] == { + "provider": "volcengine", + "apiKey": "tts-key", + "baseUrl": "https://openspeech.bytedance.com/api/v3/tts/unidirectional", + "voice": "zh_female_shuangkuaisisi_moon_bigtts", + "appId": "app-123", + "resourceId": "seed-tts-2.0", + "uid": "caller-1", + "speed": 1.1, + } + assert translated["services"]["asr"] == { + "provider": "volcengine", + "model": "bigmodel", + "apiKey": "asr-key", + "baseUrl": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", + "appId": "app-123", + "resourceId": "volc.bigasr.sauc.duration", + "uid": "caller-1", + "requestParams": { + "end_window_size": 800, + "force_to_speech_time": 1000, + }, + } + + @pytest.mark.asyncio async def test_backend_mode_disabled_uses_local_assistant_config_even_with_url(monkeypatch, tmp_path): class _FailIfCalledClientSession: @@ -282,7 +338,7 @@ async def test_local_yaml_adapter_rejects_path_traversal_like_assistant_id(tmp_p @pytest.mark.asyncio -async def test_local_yaml_translates_agent_schema_to_runtime_services(tmp_path): +async def test_local_yaml_translates_agent_schema_with_asr_interim_flag(tmp_path): config_dir = tmp_path / "assistants" config_dir.mkdir(parents=True, exist_ok=True) (config_dir / "default.yaml").write_text( @@ -305,6 +361,7 @@ async def 
test_local_yaml_translates_agent_schema_to_runtime_services(tmp_path): " model: asr-model", " api_key: sk-asr", " api_url: https://asr.example.com/v1/audio/transcriptions", + " enable_interim: false", " duplex:", " system_prompt: You are test assistant", ] @@ -321,4 +378,5 @@ async def test_local_yaml_translates_agent_schema_to_runtime_services(tmp_path): assert services.get("llm", {}).get("apiKey") == "sk-llm" assert services.get("tts", {}).get("apiKey") == "sk-tts" assert services.get("asr", {}).get("apiKey") == "sk-asr" + assert services.get("asr", {}).get("enableInterim") is False assert assistant.get("systemPrompt") == "You are test assistant" diff --git a/engine/tests/test_dashscope_asr_provider.py b/engine/tests/test_dashscope_asr_provider.py new file mode 100644 index 0000000..123530a --- /dev/null +++ b/engine/tests/test_dashscope_asr_provider.py @@ -0,0 +1,67 @@ +import asyncio + +import pytest + +from providers.asr.dashscope import DashScopeRealtimeASRService + + +@pytest.mark.asyncio +async def test_dashscope_asr_interim_event_emits_interim_transcript(): + received = [] + + async def _on_transcript(text: str, is_final: bool) -> None: + received.append((text, is_final)) + + service = DashScopeRealtimeASRService(api_key="test-key", on_transcript=_on_transcript) + service._loop = asyncio.get_running_loop() + service._running = True + + service._on_ws_event( + { + "type": "conversation.item.input_audio_transcription.text", + "stash": "你好世界", + } + ) + await asyncio.sleep(0.05) + + result = service._transcript_queue.get_nowait() + assert result.text == "你好世界" + assert result.is_final is False + assert received == [("你好世界", False)] + + +@pytest.mark.asyncio +async def test_dashscope_asr_final_event_emits_final_transcript_and_final_queue(): + received = [] + + async def _on_transcript(text: str, is_final: bool) -> None: + received.append((text, is_final)) + + service = DashScopeRealtimeASRService(api_key="test-key", on_transcript=_on_transcript) + 
service._loop = asyncio.get_running_loop() + service._running = True + service._audio_sent_in_utterance = True + + service._on_ws_event( + { + "type": "conversation.item.input_audio_transcription.completed", + "transcript": "最终识别结果", + } + ) + await asyncio.sleep(0.05) + + result = service._transcript_queue.get_nowait() + assert result.text == "最终识别结果" + assert result.is_final is True + assert service._final_queue.get_nowait() == "最终识别结果" + assert received == [("最终识别结果", True)] + + +@pytest.mark.asyncio +async def test_dashscope_wait_for_final_falls_back_to_latest_interim_on_timeout(): + service = DashScopeRealtimeASRService(api_key="test-key") + service._audio_sent_in_utterance = True + service._last_interim_text = "部分结果" + + text = await service.wait_for_final_transcription(timeout_ms=10) + assert text == "部分结果" diff --git a/engine/tests/test_duplex_asr_modes.py b/engine/tests/test_duplex_asr_modes.py new file mode 100644 index 0000000..3e4b1cf --- /dev/null +++ b/engine/tests/test_duplex_asr_modes.py @@ -0,0 +1,260 @@ +import asyncio +from typing import Any, Dict, List + +import pytest + +from runtime.pipeline.duplex import DuplexPipeline + + +class _DummySileroVAD: + def __init__(self, *args, **kwargs): + pass + + def process_audio(self, _pcm: bytes) -> float: + return 0.0 + + +class _DummyVADProcessor: + def __init__(self, *args, **kwargs): + pass + + def process(self, _speech_prob: float): + return "Silence", 0.0 + + +class _DummyEouDetector: + def __init__(self, *args, **kwargs): + self.is_speaking = True + + def process(self, _vad_status: str, force_eligible: bool = False) -> bool: + _ = force_eligible + return False + + def reset(self) -> None: + self.is_speaking = False + + +class _FakeTransport: + async def send_event(self, _event: Dict[str, Any]) -> None: + return None + + async def send_audio(self, _audio: bytes) -> None: + return None + + +class _FakeStreamingASR: + mode = "streaming" + + def __init__(self): + self.begin_calls = 0 + self.end_calls = 0 
+ self.wait_calls = 0 + self.sent_audio: List[bytes] = [] + self.wait_text = "" + + async def connect(self) -> None: + return None + + async def disconnect(self) -> None: + return None + + async def send_audio(self, audio: bytes) -> None: + self.sent_audio.append(audio) + + async def receive_transcripts(self): + if False: + yield None + + async def begin_utterance(self) -> None: + self.begin_calls += 1 + + async def end_utterance(self) -> None: + self.end_calls += 1 + + async def wait_for_final_transcription(self, timeout_ms: int = 800) -> str: + _ = timeout_ms + self.wait_calls += 1 + return self.wait_text + + def clear_utterance(self) -> None: + return None + + +class _FakeOfflineASR: + mode = "offline" + + def __init__(self): + self.start_interim_calls = 0 + self.stop_interim_calls = 0 + self.sent_audio: List[bytes] = [] + self.final_text = "offline final" + + async def connect(self) -> None: + return None + + async def disconnect(self) -> None: + return None + + async def send_audio(self, audio: bytes) -> None: + self.sent_audio.append(audio) + + async def receive_transcripts(self): + if False: + yield None + + async def start_interim_transcription(self) -> None: + self.start_interim_calls += 1 + + async def stop_interim_transcription(self) -> None: + self.stop_interim_calls += 1 + + async def get_final_transcription(self) -> str: + return self.final_text + + def clear_buffer(self) -> None: + return None + + def get_and_clear_text(self) -> str: + return self.final_text + + +def _build_pipeline(monkeypatch, asr_service): + monkeypatch.setattr("runtime.pipeline.duplex.SileroVAD", _DummySileroVAD) + monkeypatch.setattr("runtime.pipeline.duplex.VADProcessor", _DummyVADProcessor) + monkeypatch.setattr("runtime.pipeline.duplex.EouDetector", _DummyEouDetector) + return DuplexPipeline( + transport=_FakeTransport(), + session_id="asr_mode_test", + asr_service=asr_service, + ) + + +@pytest.mark.asyncio +async def test_start_asr_capture_uses_streaming_begin(monkeypatch): 
+ asr = _FakeStreamingASR() + pipeline = _build_pipeline(monkeypatch, asr) + pipeline._asr_mode = "streaming" + pipeline._pending_speech_audio = b"\x00" * 320 + pipeline._pre_speech_buffer = b"\x00" * 640 + + await pipeline._start_asr_capture() + + assert asr.begin_calls == 1 + assert asr.sent_audio + assert pipeline._asr_capture_active is True + + +@pytest.mark.asyncio +async def test_start_asr_capture_uses_offline_interim_control_when_enabled(monkeypatch): + asr = _FakeOfflineASR() + pipeline = _build_pipeline(monkeypatch, asr) + pipeline._asr_mode = "offline" + pipeline._runtime_asr["enableInterim"] = True + pipeline._pending_speech_audio = b"\x00" * 320 + pipeline._pre_speech_buffer = b"\x00" * 640 + + await pipeline._start_asr_capture() + + assert asr.start_interim_calls == 1 + assert asr.sent_audio + assert pipeline._asr_capture_active is True + + +@pytest.mark.asyncio +async def test_start_asr_capture_skips_offline_interim_control_when_disabled(monkeypatch): + asr = _FakeOfflineASR() + pipeline = _build_pipeline(monkeypatch, asr) + pipeline._asr_mode = "offline" + pipeline._runtime_asr["enableInterim"] = False + pipeline._pending_speech_audio = b"\x00" * 320 + pipeline._pre_speech_buffer = b"\x00" * 640 + + await pipeline._start_asr_capture() + + assert asr.start_interim_calls == 0 + assert asr.sent_audio + assert pipeline._asr_capture_active is True + + +@pytest.mark.asyncio +async def test_offline_interim_callback_ignored_when_disabled(monkeypatch): + asr = _FakeOfflineASR() + pipeline = _build_pipeline(monkeypatch, asr) + pipeline._asr_mode = "offline" + pipeline._runtime_asr["enableInterim"] = False + + captured_events = [] + captured_deltas = [] + + async def _capture_event(event: Dict[str, Any], priority: int = 20): + _ = priority + captured_events.append(event) + + async def _capture_delta(text: str): + captured_deltas.append(text) + + monkeypatch.setattr(pipeline, "_send_event", _capture_event) + monkeypatch.setattr(pipeline, 
"_emit_transcript_delta", _capture_delta) + + await pipeline._on_transcript_callback("ignored interim", is_final=False) + + assert captured_events == [] + assert captured_deltas == [] + assert pipeline._latest_asr_interim_text == "" + + +@pytest.mark.asyncio +async def test_offline_final_callback_emits_when_interim_disabled(monkeypatch): + asr = _FakeOfflineASR() + pipeline = _build_pipeline(monkeypatch, asr) + pipeline._asr_mode = "offline" + pipeline._runtime_asr["enableInterim"] = False + + captured_events = [] + + async def _capture_event(event: Dict[str, Any], priority: int = 20): + _ = priority + captured_events.append(event) + + monkeypatch.setattr(pipeline, "_send_event", _capture_event) + + await pipeline._on_transcript_callback("final only", is_final=True) + + assert any(event.get("type") == "transcript.final" for event in captured_events) + + +@pytest.mark.asyncio +async def test_streaming_eou_falls_back_to_latest_interim(monkeypatch): + asr = _FakeStreamingASR() + asr.wait_text = "" + pipeline = _build_pipeline(monkeypatch, asr) + pipeline._asr_mode = "streaming" + pipeline._asr_capture_active = True + pipeline._latest_asr_interim_text = "fallback interim text" + await pipeline.conversation.start_user_turn() + + captured_events = [] + captured_turns = [] + + async def _capture_event(event: Dict[str, Any], priority: int = 20): + _ = priority + captured_events.append(event) + + async def _noop_stop_current_speech() -> None: + return None + + async def _capture_turn(user_text: str, *args, **kwargs) -> None: + _ = (args, kwargs) + captured_turns.append(user_text) + + monkeypatch.setattr(pipeline, "_send_event", _capture_event) + monkeypatch.setattr(pipeline, "_stop_current_speech", _noop_stop_current_speech) + monkeypatch.setattr(pipeline, "_handle_turn", _capture_turn) + + await pipeline._on_end_of_utterance() + await asyncio.sleep(0.05) + + assert asr.end_calls == 1 + assert asr.wait_calls == 1 + assert captured_turns == ["fallback interim text"] + assert 
any(event.get("type") == "transcript.final" for event in captured_events) diff --git a/engine/tests/test_tool_call_flow.py b/engine/tests/test_tool_call_flow.py index d820643..717f96a 100644 --- a/engine/tests/test_tool_call_flow.py +++ b/engine/tests/test_tool_call_flow.py @@ -52,9 +52,33 @@ class _FakeTTS: class _FakeASR: + mode = "offline" + async def connect(self) -> None: return None + async def disconnect(self) -> None: + return None + + async def send_audio(self, _audio: bytes) -> None: + return None + + async def receive_transcripts(self): + if False: + yield None + + def clear_buffer(self) -> None: + return None + + async def start_interim_transcription(self) -> None: + return None + + async def stop_interim_transcription(self) -> None: + return None + + async def get_final_transcription(self) -> str: + return "" + class _FakeLLM: def __init__(self, rounds: List[List[LLMStreamEvent]]): diff --git a/engine/tests/test_tts_factory_modes.py b/engine/tests/test_tts_factory_modes.py new file mode 100644 index 0000000..987fc10 --- /dev/null +++ b/engine/tests/test_tts_factory_modes.py @@ -0,0 +1,45 @@ +from providers.factory.default import DefaultRealtimeServiceFactory +from providers.tts.mock import MockTTSService +from providers.tts.openai_compatible import OpenAICompatibleTTSService +from providers.tts.volcengine import VolcengineTTSService +from runtime.ports import TTSServiceSpec + + +def test_create_tts_service_volcengine_returns_native_provider(): + factory = DefaultRealtimeServiceFactory() + service = factory.create_tts_service( + TTSServiceSpec( + provider="volcengine", + api_key="test-key", + app_id="app-1", + resource_id="seed-tts-2.0", + voice="zh_female_shuangkuaisisi_moon_bigtts", + sample_rate=16000, + ) + ) + assert isinstance(service, VolcengineTTSService) + + +def test_create_tts_service_openai_compatible_returns_provider(): + factory = DefaultRealtimeServiceFactory() + service = factory.create_tts_service( + TTSServiceSpec( + 
provider="openai_compatible", + api_key="test-key", + voice="anna", + sample_rate=16000, + ) + ) + assert isinstance(service, OpenAICompatibleTTSService) + + +def test_create_tts_service_fallbacks_to_mock_without_key(): + factory = DefaultRealtimeServiceFactory() + service = factory.create_tts_service( + TTSServiceSpec( + provider="volcengine", + voice="anna", + sample_rate=16000, + ) + ) + assert isinstance(service, MockTTSService) diff --git a/engine/tests/test_volcengine_asr_provider.py b/engine/tests/test_volcengine_asr_provider.py new file mode 100644 index 0000000..c5756c0 --- /dev/null +++ b/engine/tests/test_volcengine_asr_provider.py @@ -0,0 +1,86 @@ +import gzip +import json + +from providers.asr.volcengine import VolcengineRealtimeASRService + + +def test_volcengine_seed_protocol_defaults_and_headers(): + service = VolcengineRealtimeASRService( + api_key="access-token", + api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", + app_id="app-1", + uid="caller-1", + ) + + assert service.protocol == "seed" + assert service.resource_id == "volc.bigasr.sauc.duration" + + headers = service._build_seed_headers("req-1") + assert headers == { + "X-Api-App-Key": "app-1", + "X-Api-Access-Key": "access-token", + "X-Api-Resource-Id": "volc.bigasr.sauc.duration", + "X-Api-Request-Id": "req-1", + } + + +def test_volcengine_seed_start_payload_merges_request_params(): + service = VolcengineRealtimeASRService( + api_key="access-token", + api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", + app_id="app-1", + uid="caller-1", + language="zh-CN", + request_params={ + "request": { + "end_window_size": 800, + "force_to_speech_time": 1000, + "context": "{\"hotwords\":[{\"word\":\"doubao\"}]}", + }, + "audio": {"codec": "raw"}, + }, + ) + + payload = service._build_seed_start_payload() + assert payload["user"] == {"uid": "caller-1"} + assert payload["audio"] == { + "format": "pcm", + "rate": 16000, + "bits": 16, + "channels": 1, + "codec": "raw", + "language": 
"zh-CN", + } + assert payload["request"]["model_name"] == "bigmodel" + assert payload["request"]["end_window_size"] == 800 + assert payload["request"]["force_to_speech_time"] == 1000 + assert payload["request"]["context"] == "{\"hotwords\":[{\"word\":\"doubao\"}]}" + + +def test_volcengine_seed_start_request_encodes_gzip_json_payload(): + service = VolcengineRealtimeASRService( + api_key="access-token", + api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", + app_id="app-1", + uid="caller-1", + ) + + frame = service._build_seed_start_request() + assert frame[0] == 0x11 + assert frame[1] == 0x11 + + payload_length = int.from_bytes(frame[8:12], "big") + payload = json.loads(gzip.decompress(frame[12 : 12 + payload_length]).decode("utf-8")) + assert payload["user"]["uid"] == "caller-1" + assert payload["request"]["model_name"] == "bigmodel" + + +def test_volcengine_gateway_protocol_keeps_model_query(): + service = VolcengineRealtimeASRService( + api_key="access-token", + api_url="wss://ai-gateway.vei.volces.com/v1/realtime", + model="bigmodel", + ) + + assert service.protocol == "gateway" + assert service.api_url == "wss://ai-gateway.vei.volces.com/v1/realtime?model=bigmodel" diff --git a/web/pages/Assistants.tsx b/web/pages/Assistants.tsx index 0826745..9545ab4 100644 --- a/web/pages/Assistants.tsx +++ b/web/pages/Assistants.tsx @@ -259,6 +259,7 @@ export const AssistantsPage: React.FC = () => { speed: 1, hotwords: [], tools: [], + asrInterimEnabled: false, botCannotBeInterrupted: false, interruptionSensitivity: 180, configMode: 'platform', @@ -1358,6 +1359,41 @@ export const AssistantsPage: React.FC = () => {

+
+
+ +
+ + +
+
+

+ 仅影响离线 ASR 模式(OpenAI Compatible / buffered)。默认关闭。 +

+
+