This commit is contained in:
Xin Wang
2026-03-09 05:41:13 +08:00
37 changed files with 2497 additions and 50 deletions

View File

@@ -127,6 +127,7 @@ class Assistant(Base):
speed: Mapped[float] = mapped_column(Float, default=1.0)
hotwords: Mapped[dict] = mapped_column(JSON, default=list)
tools: Mapped[dict] = mapped_column(JSON, default=list)
asr_interim_enabled: Mapped[bool] = mapped_column(default=False)
bot_cannot_be_interrupted: Mapped[bool] = mapped_column(default=False)
interruption_sensitivity: Mapped[int] = mapped_column(Integer, default=500)
config_mode: Mapped[str] = mapped_column(String(32), default="platform")

View File

@@ -126,6 +126,9 @@ def _ensure_assistant_schema(db: Session) -> None:
if "manual_opener_tool_calls" not in columns:
db.execute(text("ALTER TABLE assistants ADD COLUMN manual_opener_tool_calls JSON"))
altered = True
if "asr_interim_enabled" not in columns:
db.execute(text("ALTER TABLE assistants ADD COLUMN asr_interim_enabled BOOLEAN DEFAULT 0"))
altered = True
if altered:
db.commit()
@@ -317,18 +320,27 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[s
else:
warnings.append(f"LLM model not found: {assistant.llm_model_id}")
asr_runtime: Dict[str, Any] = {
"enableInterim": bool(assistant.asr_interim_enabled),
}
if assistant.asr_model_id:
asr = db.query(ASRModel).filter(ASRModel.id == assistant.asr_model_id).first()
if asr:
asr_provider = "openai_compatible" if _is_openai_compatible_vendor(asr.vendor) else "buffered"
metadata["services"]["asr"] = {
if _is_dashscope_vendor(asr.vendor):
asr_provider = "dashscope"
elif _is_openai_compatible_vendor(asr.vendor):
asr_provider = "openai_compatible"
else:
asr_provider = "buffered"
asr_runtime.update({
"provider": asr_provider,
"model": asr.model_name or asr.name,
"apiKey": asr.api_key if asr_provider == "openai_compatible" else None,
"baseUrl": asr.base_url if asr_provider == "openai_compatible" else None,
}
"apiKey": asr.api_key if asr_provider in {"openai_compatible", "dashscope"} else None,
"baseUrl": asr.base_url if asr_provider in {"openai_compatible", "dashscope"} else None,
})
else:
warnings.append(f"ASR model not found: {assistant.asr_model_id}")
metadata["services"]["asr"] = asr_runtime
if not assistant.voice_output_enabled:
metadata["services"]["tts"] = {"enabled": False}
@@ -432,6 +444,7 @@ def assistant_to_dict(assistant: Assistant) -> dict:
"speed": assistant.speed,
"hotwords": assistant.hotwords or [],
"tools": _normalize_assistant_tool_ids(assistant.tools),
"asrInterimEnabled": bool(assistant.asr_interim_enabled),
"botCannotBeInterrupted": bool(assistant.bot_cannot_be_interrupted),
"interruptionSensitivity": assistant.interruption_sensitivity,
"configMode": assistant.config_mode,
@@ -452,6 +465,7 @@ def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None:
"firstTurnMode": "first_turn_mode",
"manualOpenerToolCalls": "manual_opener_tool_calls",
"interruptionSensitivity": "interruption_sensitivity",
"asrInterimEnabled": "asr_interim_enabled",
"botCannotBeInterrupted": "bot_cannot_be_interrupted",
"configMode": "config_mode",
"voiceOutputEnabled": "voice_output_enabled",
@@ -646,6 +660,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
speed=data.speed,
hotwords=data.hotwords,
tools=_normalize_assistant_tool_ids(data.tools),
asr_interim_enabled=data.asrInterimEnabled,
bot_cannot_be_interrupted=data.botCannotBeInterrupted,
interruption_sensitivity=data.interruptionSensitivity,
config_mode=data.configMode,

View File

@@ -291,6 +291,7 @@ class AssistantBase(BaseModel):
speed: float = 1.0
hotwords: List[str] = []
tools: List[str] = []
asrInterimEnabled: bool = False
botCannotBeInterrupted: bool = False
interruptionSensitivity: int = 500
configMode: str = "platform"
@@ -322,6 +323,7 @@ class AssistantUpdate(BaseModel):
speed: Optional[float] = None
hotwords: Optional[List[str]] = None
tools: Optional[List[str]] = None
asrInterimEnabled: Optional[bool] = None
botCannotBeInterrupted: Optional[bool] = None
interruptionSensitivity: Optional[int] = None
configMode: Optional[str] = None

View File

@@ -27,6 +27,7 @@ class TestAssistantAPI:
assert data["voiceOutputEnabled"] is True
assert data["firstTurnMode"] == "bot_first"
assert data["generatedOpenerEnabled"] is False
assert data["asrInterimEnabled"] is False
assert data["botCannotBeInterrupted"] is False
assert "id" in data
assert data["callCount"] == 0
@@ -37,6 +38,7 @@ class TestAssistantAPI:
response = client.post("/api/assistants", json=data)
assert response.status_code == 200
assert response.json()["name"] == "Minimal Assistant"
assert response.json()["asrInterimEnabled"] is False
def test_get_assistant_by_id(self, client, sample_assistant_data):
"""Test getting a specific assistant by ID"""
@@ -68,6 +70,7 @@ class TestAssistantAPI:
"prompt": "You are an updated assistant.",
"speed": 1.5,
"voiceOutputEnabled": False,
"asrInterimEnabled": True,
"manualOpenerToolCalls": [
{"toolName": "text_msg_prompt", "arguments": {"msg": "请选择服务类型"}}
],
@@ -79,6 +82,7 @@ class TestAssistantAPI:
assert data["prompt"] == "You are an updated assistant."
assert data["speed"] == 1.5
assert data["voiceOutputEnabled"] is False
assert data["asrInterimEnabled"] is True
assert data["manualOpenerToolCalls"] == [
{"toolName": "text_msg_prompt", "arguments": {"msg": "请选择服务类型"}}
]
@@ -213,6 +217,7 @@ class TestAssistantAPI:
"prompt": "runtime prompt",
"opener": "runtime opener",
"manualOpenerToolCalls": [{"toolName": "text_msg_prompt", "arguments": {"msg": "欢迎"}}],
"asrInterimEnabled": True,
"speed": 1.1,
})
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
@@ -232,6 +237,7 @@ class TestAssistantAPI:
assert metadata["services"]["llm"]["model"] == sample_llm_model_data["model_name"]
assert metadata["services"]["asr"]["model"] == sample_asr_model_data["model_name"]
assert metadata["services"]["asr"]["baseUrl"] == sample_asr_model_data["base_url"]
assert metadata["services"]["asr"]["enableInterim"] is True
expected_tts_voice = f"{sample_voice_data['model']}:{sample_voice_data['voice_key']}"
assert metadata["services"]["tts"]["voice"] == expected_tts_voice
assert metadata["services"]["tts"]["baseUrl"] == sample_voice_data["base_url"]
@@ -309,6 +315,7 @@ class TestAssistantAPI:
assert runtime_resp.status_code == 200
metadata = runtime_resp.json()["sessionStartMetadata"]
assert metadata["output"]["mode"] == "text"
assert metadata["services"]["asr"]["enableInterim"] is False
assert metadata["services"]["tts"]["enabled"] is False
def test_runtime_config_dashscope_voice_provider(self, client, sample_assistant_data):
@@ -343,6 +350,48 @@ class TestAssistantAPI:
assert tts["apiKey"] == "dashscope-key"
assert tts["baseUrl"] == "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
def test_runtime_config_dashscope_asr_provider(self, client, sample_assistant_data):
"""DashScope ASR models should map to dashscope asr provider in runtime metadata."""
asr_resp = client.post("/api/asr", json={
"name": "DashScope Realtime ASR",
"vendor": "DashScope",
"language": "zh",
"base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
"api_key": "dashscope-asr-key",
"model_name": "qwen3-asr-flash-realtime",
"hotwords": [],
"enable_punctuation": True,
"enable_normalization": True,
"enabled": True,
})
assert asr_resp.status_code == 200
asr_payload = asr_resp.json()
sample_assistant_data.update({
"asrModelId": asr_payload["id"],
})
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
assert assistant_resp.status_code == 200
assistant_id = assistant_resp.json()["id"]
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
assert runtime_resp.status_code == 200
metadata = runtime_resp.json()["sessionStartMetadata"]
asr = metadata["services"]["asr"]
assert asr["provider"] == "dashscope"
assert asr["baseUrl"] == "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
assert asr["enableInterim"] is False
def test_runtime_config_defaults_asr_interim_disabled_without_asr_model(self, client, sample_assistant_data):
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
assert assistant_resp.status_code == 200
assistant_id = assistant_resp.json()["id"]
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
assert runtime_resp.status_code == 200
metadata = runtime_resp.json()["sessionStartMetadata"]
assert metadata["services"]["asr"]["enableInterim"] is False
def test_assistant_interrupt_and_generated_opener_flags(self, client, sample_assistant_data):
sample_assistant_data.update({
"firstTurnMode": "user_first",

View File

@@ -12,6 +12,24 @@
| **热词** | 提高业务词汇、品牌词、专有名词识别率 |
| **标点与规范化** | 自动补全标点、规范数字和日期等 |
## 模式
- `offline`:引擎本地缓冲音频后触发识别(适用于 OpenAI-compatible / SiliconFlow)
- `streaming`:音频分片实时发送到服务端,服务端持续返回转写事件(适用于 DashScope Realtime ASR、Volcengine BigASR)
## 配置项
| 配置项 | 说明 |
|---|---|
| ASR 引擎 | 选择语音识别服务提供商 |
| 模型 | 识别模型名称 |
| `enable_interim` | 是否开启离线 ASR 中间结果(默认 `false`,仅离线模式生效) |
| `app_id` / `resource_id` | Volcengine 等厂商的应用标识与资源标识 |
| `request_params` | 厂商原生请求参数透传,例如 `end_window_size`、`force_to_speech_time`、`context` |
| 语言 | 中文/英文/多语言 |
| 热词 | 提升特定词汇识别准确率 |
| 标点与规范化 | 是否自动补全标点、文本规范化 |
## 选择建议
- 客服、外呼等业务场景建议维护热词表,并按业务线持续更新
@@ -29,3 +47,7 @@
- [声音资源](voices.md) - 完整语音输入输出链路中的 TTS 侧配置
- [快速开始](../quickstart/index.md) - 以任务路径接入第一个 ASR 资源
- 客服场景建议开启热词并维护业务词表
- 多语言场景建议按会话入口显式指定语言
- 对延迟敏感场景优先选择流式识别模型
- 当前支持提供商:`openai_compatible`、`siliconflow`、`dashscope`、`volcengine`、`buffered`(回退)

View File

@@ -230,6 +230,14 @@ class LocalYamlAssistantConfigAdapter(NullBackendAdapter):
tts_runtime["baseUrl"] = cls._as_str(tts.get("api_url"))
if cls._as_str(tts.get("voice")):
tts_runtime["voice"] = cls._as_str(tts.get("voice"))
if cls._as_str(tts.get("app_id")):
tts_runtime["appId"] = cls._as_str(tts.get("app_id"))
if cls._as_str(tts.get("resource_id")):
tts_runtime["resourceId"] = cls._as_str(tts.get("resource_id"))
if cls._as_str(tts.get("cluster")):
tts_runtime["cluster"] = cls._as_str(tts.get("cluster"))
if cls._as_str(tts.get("uid")):
tts_runtime["uid"] = cls._as_str(tts.get("uid"))
if tts.get("speed") is not None:
tts_runtime["speed"] = tts.get("speed")
dashscope_mode = cls._as_str(tts.get("dashscope_mode")) or cls._as_str(tts.get("mode"))
@@ -249,6 +257,18 @@ class LocalYamlAssistantConfigAdapter(NullBackendAdapter):
asr_runtime["apiKey"] = cls._as_str(asr.get("api_key"))
if cls._as_str(asr.get("api_url")):
asr_runtime["baseUrl"] = cls._as_str(asr.get("api_url"))
if cls._as_str(asr.get("app_id")):
asr_runtime["appId"] = cls._as_str(asr.get("app_id"))
if cls._as_str(asr.get("resource_id")):
asr_runtime["resourceId"] = cls._as_str(asr.get("resource_id"))
if cls._as_str(asr.get("cluster")):
asr_runtime["cluster"] = cls._as_str(asr.get("cluster"))
if cls._as_str(asr.get("uid")):
asr_runtime["uid"] = cls._as_str(asr.get("uid"))
if isinstance(asr.get("request_params"), dict):
asr_runtime["requestParams"] = dict(asr.get("request_params") or {})
if asr.get("enable_interim") is not None:
asr_runtime["enableInterim"] = asr.get("enable_interim")
if asr.get("interim_interval_ms") is not None:
asr_runtime["interimIntervalMs"] = asr.get("interim_interval_ms")
if asr.get("min_audio_ms") is not None:

View File

@@ -71,11 +71,15 @@ class Settings(BaseSettings):
# TTS Configuration
tts_provider: str = Field(
default="openai_compatible",
description="TTS provider (openai_compatible, siliconflow, dashscope)"
description="TTS provider (openai_compatible, siliconflow, dashscope, volcengine)"
)
tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL")
tts_model: Optional[str] = Field(default=None, description="TTS model name")
tts_voice: str = Field(default="anna", description="TTS voice name")
tts_app_id: Optional[str] = Field(default=None, description="Provider-specific TTS app ID")
tts_resource_id: Optional[str] = Field(default=None, description="Provider-specific TTS resource ID")
tts_cluster: Optional[str] = Field(default=None, description="Provider-specific TTS cluster")
tts_uid: Optional[str] = Field(default=None, description="Provider-specific TTS user ID")
tts_mode: str = Field(
default="commit",
description="DashScope-only TTS mode (commit, server_commit). Ignored for non-dashscope providers."
@@ -85,10 +89,19 @@ class Settings(BaseSettings):
# ASR Configuration
asr_provider: str = Field(
default="openai_compatible",
description="ASR provider (openai_compatible, buffered, siliconflow)"
description="ASR provider (openai_compatible, buffered, siliconflow, dashscope, volcengine)"
)
asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL")
asr_model: Optional[str] = Field(default=None, description="ASR model name")
asr_app_id: Optional[str] = Field(default=None, description="Provider-specific ASR app ID")
asr_resource_id: Optional[str] = Field(default=None, description="Provider-specific ASR resource ID")
asr_cluster: Optional[str] = Field(default=None, description="Provider-specific ASR cluster")
asr_uid: Optional[str] = Field(default=None, description="Provider-specific ASR user ID")
asr_request_params_json: Optional[str] = Field(
default=None,
description="Provider-specific ASR request params as JSON string"
)
asr_enable_interim: bool = Field(default=False, description="Enable interim transcripts for offline ASR")
asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms")
asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result")
asr_start_min_speech_ms: int = Field(

View File

@@ -0,0 +1,47 @@
# Agent behavior configuration for DashScope realtime ASR/TTS.
# This file only controls agent-side behavior (VAD/LLM/TTS/ASR providers).
# Infra/server/network settings should stay in .env.
agent:
vad:
type: silero
model_path: data/vad/silero_vad.onnx
threshold: 0.5
min_speech_duration_ms: 100
eou_threshold_ms: 800
llm:
# provider: openai | openai_compatible | siliconflow
provider: openai_compatible
model: deepseek-v3
temperature: 0.7
api_key: your_llm_api_key
api_url: https://api.qnaigc.com/v1
tts:
provider: dashscope
api_key: your_tts_api_key
api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
model: qwen3-tts-flash-realtime
voice: Cherry
dashscope_mode: commit
speed: 1.0
asr:
provider: dashscope
api_key: your_asr_api_key
api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
model: qwen3-asr-flash-realtime
interim_interval_ms: 500
min_audio_ms: 300
start_min_speech_ms: 160
pre_speech_ms: 240
final_tail_ms: 120
duplex:
enabled: true
system_prompt: You are a helpful, friendly voice assistant. Keep your responses concise and conversational.
barge_in:
min_duration_ms: 200
silence_tolerance_ms: 60

View File

@@ -21,12 +21,17 @@ agent:
api_url: https://api.qnaigc.com/v1
tts:
# provider: openai_compatible | siliconflow | dashscope
# provider: openai_compatible | siliconflow | dashscope | volcengine
# dashscope defaults (if omitted):
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
# model: qwen3-tts-flash-realtime
# dashscope_mode: commit (engine splits) | server_commit (dashscope splits)
# note: dashscope_mode/mode is ONLY used when provider=dashscope.
# volcengine defaults (if omitted):
# api_url: https://openspeech.bytedance.com/api/v3/tts/unidirectional
# resource_id: seed-tts-2.0
# app_id: your volcengine app key
# api_key: your volcengine access key
provider: openai_compatible
api_key: your_tts_api_key
api_url: https://api.siliconflow.cn/v1/audio/speech
@@ -35,11 +40,26 @@ agent:
speed: 1.0
asr:
# provider: buffered | openai_compatible | siliconflow
# provider: buffered | openai_compatible | siliconflow | dashscope | volcengine
# dashscope defaults (if omitted):
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
# model: qwen3-asr-flash-realtime
# note: dashscope uses streaming ASR mode (chunk-by-chunk).
# volcengine defaults (if omitted):
# api_url: wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
# model: bigmodel
# resource_id: volc.bigasr.sauc.duration
# app_id: your volcengine app key
# api_key: your volcengine access key
# request_params:
# end_window_size: 800
# force_to_speech_time: 1000
# note: volcengine uses streaming ASR mode (chunk-by-chunk).
provider: openai_compatible
api_key: you_asr_api_key
api_url: https://api.siliconflow.cn/v1/audio/transcriptions
model: FunAudioLLM/SenseVoiceSmall
enable_interim: false
interim_interval_ms: 500
min_audio_ms: 300
start_min_speech_ms: 160

View File

@@ -18,12 +18,17 @@ agent:
api_url: https://api.qnaigc.com/v1
tts:
# provider: openai_compatible | siliconflow | dashscope
# provider: openai_compatible | siliconflow | dashscope | volcengine
# dashscope defaults (if omitted):
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
# model: qwen3-tts-flash-realtime
# dashscope_mode: commit (engine splits) | server_commit (dashscope splits)
# note: dashscope_mode/mode is ONLY used when provider=dashscope.
# volcengine defaults (if omitted):
# api_url: https://openspeech.bytedance.com/api/v3/tts/unidirectional
# resource_id: seed-tts-2.0
# app_id: your volcengine app key
# api_key: your volcengine access key
provider: openai_compatible
api_key: your_tts_api_key
api_url: https://api.siliconflow.cn/v1/audio/speech
@@ -32,11 +37,26 @@ agent:
speed: 1.0
asr:
# provider: buffered | openai_compatible | siliconflow
# provider: buffered | openai_compatible | siliconflow | dashscope | volcengine
# dashscope defaults (if omitted):
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
# model: qwen3-asr-flash-realtime
# note: dashscope uses streaming ASR mode (chunk-by-chunk).
# volcengine defaults (if omitted):
# api_url: wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
# model: bigmodel
# resource_id: volc.bigasr.sauc.duration
# app_id: your volcengine app key
# api_key: your volcengine access key
# request_params:
# end_window_size: 800
# force_to_speech_time: 1000
# note: volcengine uses streaming ASR mode (chunk-by-chunk).
provider: openai_compatible
api_key: your_asr_api_key
api_url: https://api.siliconflow.cn/v1/audio/transcriptions
model: FunAudioLLM/SenseVoiceSmall
enable_interim: false
interim_interval_ms: 500
min_audio_ms: 300
start_min_speech_ms: 160

View File

@@ -0,0 +1,68 @@
# Agent behavior configuration (safe to edit per profile)
# This file only controls agent-side behavior (VAD/LLM/TTS/ASR providers).
# Infra/server/network settings should stay in .env.
agent:
vad:
type: silero
model_path: data/vad/silero_vad.onnx
threshold: 0.5
min_speech_duration_ms: 100
eou_threshold_ms: 800
llm:
# provider: openai | openai_compatible | siliconflow
provider: openai_compatible
model: deepseek-v3
temperature: 0.7
# Required: no fallback. You can still reference env explicitly.
api_key: your_llm_api_key
# Optional for OpenAI-compatible endpoints:
api_url: https://api.qnaigc.com/v1
tts:
    # provider: edge | openai_compatible | siliconflow | dashscope | volcengine
# dashscope defaults (if omitted):
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
# model: qwen3-tts-flash-realtime
# dashscope_mode: commit (engine splits) | server_commit (dashscope splits)
# note: dashscope_mode/mode is ONLY used when provider=dashscope.
# volcengine defaults (if omitted):
provider: volcengine
api_url: https://openspeech.bytedance.com/api/v3/tts/unidirectional
resource_id: seed-tts-2.0
app_id: your_tts_app_id
api_key: your_tts_api_key
speed: 1.1
voice: zh_female_vv_uranus_bigtts
asr:
asr:
provider: volcengine
api_url: wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
app_id: your_asr_app_id
api_key: your_asr_api_key
resource_id: volc.bigasr.sauc.duration
uid: caller-1
model: bigmodel
request_params:
end_window_size: 800
force_to_speech_time: 1000
enable_punc: true
enable_itn: false
enable_ddc: false
show_utterance: true
result_type: single
interim_interval_ms: 500
min_audio_ms: 300
start_min_speech_ms: 160
pre_speech_ms: 240
final_tail_ms: 120
duplex:
enabled: true
system_prompt: 你是一个人工智能助手你用简答语句回答避免使用标点符号和emoji。
barge_in:
min_duration_ms: 200
silence_tolerance_ms: 60

View File

@@ -20,7 +20,7 @@ This document defines the draft port set used to keep core runtime extensible.
- `runtime/ports/asr.py`
- `ASRServiceSpec`
- `ASRPort`
- optional extensions: `ASRInterimControl`, `ASRBufferControl`
- explicit mode ports: `OfflineASRPort`, `StreamingASRPort`
- `runtime/ports/service_factory.py`
- `RealtimeServiceFactory`
@@ -36,10 +36,10 @@ This document defines the draft port set used to keep core runtime extensible.
- supported providers: `openai`, `openai_compatible`, `openai-compatible`, `siliconflow`
- fallback: `MockLLMService`
- TTS:
- supported providers: `dashscope`, `openai_compatible`, `openai-compatible`, `siliconflow`
- supported providers: `dashscope`, `volcengine`, `openai_compatible`, `openai-compatible`, `siliconflow`
- fallback: `MockTTSService`
- ASR:
- supported providers: `openai_compatible`, `openai-compatible`, `siliconflow`
- supported providers: `openai_compatible`, `openai-compatible`, `siliconflow`, `dashscope`, `volcengine`
- fallback: `BufferedASRService`
## Notes

View File

@@ -1 +1,15 @@
"""ASR providers."""
from providers.asr.buffered import BufferedASRService, MockASRService
from providers.asr.dashscope import DashScopeRealtimeASRService
from providers.asr.openai_compatible import OpenAICompatibleASRService, SiliconFlowASRService
from providers.asr.volcengine import VolcengineRealtimeASRService
__all__ = [
"BufferedASRService",
"MockASRService",
"DashScopeRealtimeASRService",
"OpenAICompatibleASRService",
"SiliconFlowASRService",
"VolcengineRealtimeASRService",
]

View File

@@ -34,6 +34,7 @@ class BufferedASRService(BaseASRService):
language: str = "en"
):
super().__init__(sample_rate=sample_rate, language=language)
self.mode = "offline"
self._audio_buffer: bytes = b""
self._current_text: str = ""
@@ -87,6 +88,23 @@ class BufferedASRService(BaseASRService):
self._audio_buffer = b""
return text
    async def get_final_transcription(self) -> str:
        """Offline compatibility method used by DuplexPipeline.

        Returns the accumulated transcript and clears it (delegates to
        get_and_clear_text).
        """
        return self.get_and_clear_text()
    def clear_buffer(self) -> None:
        """Offline compatibility method used by DuplexPipeline.

        Resets both the raw audio buffer and the accumulated transcript text.
        """
        self._audio_buffer = b""
        self._current_text = ""
    async def start_interim_transcription(self) -> None:
        """No-op for plain buffered ASR (interim results are not produced)."""
        return None
    async def stop_interim_transcription(self) -> None:
        """No-op for plain buffered ASR."""
        return None
    def get_audio_buffer(self) -> bytes:
        """Get accumulated audio buffer (raw PCM bytes, not cleared)."""
        return self._audio_buffer
@@ -103,6 +121,7 @@ class MockASRService(BaseASRService):
def __init__(self, sample_rate: int = 16000, language: str = "en"):
super().__init__(sample_rate=sample_rate, language=language)
self.mode = "offline"
self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()
self._mock_texts = [
"Hello, how are you?",
@@ -145,3 +164,18 @@ class MockASRService(BaseASRService):
continue
except asyncio.CancelledError:
break
    def clear_buffer(self) -> None:
        """No-op: the mock keeps no audio buffer."""
        return None
    async def get_final_transcription(self) -> str:
        """No-op: the mock never yields a buffered final transcript."""
        return ""
    def get_and_clear_text(self) -> str:
        """No-op: the mock keeps no accumulated text."""
        return ""
    async def start_interim_transcription(self) -> None:
        """No-op: interface parity with BufferedASRService."""
        return None
    async def stop_interim_transcription(self) -> None:
        """No-op: interface parity with BufferedASRService."""
        return None

View File

@@ -0,0 +1,388 @@
"""DashScope realtime streaming ASR service.
Uses Qwen-ASR-Realtime via DashScope Python SDK.
"""
from __future__ import annotations
import asyncio
import base64
import json
import os
import sys
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Optional
from loguru import logger
from providers.common.base import ASRResult, BaseASRService, ServiceState
try:
import dashscope
from dashscope.audio.qwen_omni import MultiModality, OmniRealtimeCallback, OmniRealtimeConversation
# Some SDK builds keep TranscriptionParams under qwen_omni.omni_realtime.
try:
from dashscope.audio.qwen_omni import TranscriptionParams
except ImportError:
from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams
DASHSCOPE_SDK_AVAILABLE = True
DASHSCOPE_IMPORT_ERROR = ""
except Exception as exc:
DASHSCOPE_IMPORT_ERROR = f"{type(exc).__name__}: {exc}"
dashscope = None # type: ignore[assignment]
MultiModality = None # type: ignore[assignment]
OmniRealtimeConversation = None # type: ignore[assignment]
TranscriptionParams = None # type: ignore[assignment]
DASHSCOPE_SDK_AVAILABLE = False
class OmniRealtimeCallback: # type: ignore[no-redef]
"""Fallback callback base when DashScope SDK is unavailable."""
pass
class _DashScopeASRCallback(OmniRealtimeCallback):
    """Bridge DashScope SDK callbacks into asyncio loop-safe handlers.

    The SDK invokes these hooks from its own worker thread, so every
    handler is marshalled onto the owner's event loop via
    ``call_soon_threadsafe`` rather than run inline.
    """
    def __init__(self, owner: "DashScopeRealtimeASRService", loop: asyncio.AbstractEventLoop):
        super().__init__()
        self._owner = owner
        self._loop = loop
    def _schedule(self, fn: Callable[[], None]) -> None:
        # The loop may already be closed during shutdown; drop the event then.
        try:
            self._loop.call_soon_threadsafe(fn)
        except RuntimeError:
            return
    def on_open(self) -> None:
        self._schedule(self._owner._on_ws_open)
    def on_close(self, code: int, msg: str) -> None:
        # Capture args in a lambda so the call runs on the loop thread.
        self._schedule(lambda: self._owner._on_ws_close(code, msg))
    def on_event(self, message: Any) -> None:
        self._schedule(lambda: self._owner._on_ws_event(message))
    def on_error(self, message: Any) -> None:
        self._schedule(lambda: self._owner._on_ws_error(message))
class DashScopeRealtimeASRService(BaseASRService):
"""Realtime streaming ASR implementation for DashScope Qwen-ASR-Realtime."""
DEFAULT_WS_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
DEFAULT_MODEL = "qwen3-asr-flash-realtime"
DEFAULT_FINAL_TIMEOUT_MS = 800
    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        model: Optional[str] = None,
        sample_rate: int = 16000,
        language: str = "auto",
        on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None,
    ) -> None:
        """Create the service; no network I/O happens until connect().

        Args:
            api_key: DashScope key; falls back to the DASHSCOPE_API_KEY then
                ASR_API_KEY environment variables.
            api_url: Realtime websocket URL; falls back to DASHSCOPE_ASR_API_URL,
                then DEFAULT_WS_URL.
            model: ASR model name; falls back to DASHSCOPE_ASR_MODEL, then
                DEFAULT_MODEL.
            sample_rate: Input PCM sample rate in Hz.
            language: Language hint ("auto" is mapped to "zh" during session
                setup — see _configure_session).
            on_transcript: Optional async callback invoked as
                on_transcript(text, is_final) for every published transcript.
        """
        super().__init__(sample_rate=sample_rate, language=language)
        self.mode = "streaming"  # chunk-by-chunk streaming provider
        self.api_key = (
            api_key
            or os.getenv("DASHSCOPE_API_KEY")
            or os.getenv("ASR_API_KEY")
        )
        self.api_url = api_url or os.getenv("DASHSCOPE_ASR_API_URL") or self.DEFAULT_WS_URL
        self.model = model or os.getenv("DASHSCOPE_ASR_MODEL") or self.DEFAULT_MODEL
        self.on_transcript = on_transcript
        # SDK client and loop/callback plumbing (populated in connect()).
        self._client: Optional[Any] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._callback: Optional[_DashScopeASRCallback] = None
        self._running = False
        self._session_ready = asyncio.Event()
        # Fan-out queues: all transcripts vs. finals only.
        self._transcript_queue: "asyncio.Queue[ASRResult]" = asyncio.Queue()
        self._final_queue: "asyncio.Queue[str]" = asyncio.Queue()
        # Per-utterance bookkeeping.
        self._utterance_active = False
        self._audio_sent_in_utterance = False
        self._last_interim_text = ""
        self._last_error: Optional[str] = None
    async def connect(self) -> None:
        """Open the realtime websocket session and configure transcription.

        Raises:
            RuntimeError: if the dashscope SDK could not be imported.
            ValueError: if no API key was resolved.
        """
        if not DASHSCOPE_SDK_AVAILABLE:
            # Name the exact interpreter in the error: the SDK may be
            # installed in a different Python environment than the one running.
            py_exec = sys.executable
            hint = f"`{py_exec} -m pip install dashscope>=1.25.6`"
            detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
            raise RuntimeError(
                f"dashscope SDK unavailable in interpreter {py_exec}; install with {hint}{detail}"
            )
        if not self.api_key:
            raise ValueError("DashScope ASR API key not provided. Configure agent.asr.api_key in YAML.")
        self._loop = asyncio.get_running_loop()
        self._callback = _DashScopeASRCallback(owner=self, loop=self._loop)
        if dashscope is not None:
            dashscope.api_key = self.api_key  # SDK reads the key from module state
        self._client = OmniRealtimeConversation(  # type: ignore[misc]
            model=self.model,
            url=self.api_url,
            callback=self._callback,
        )
        # The SDK connect call is blocking; run it off the event loop.
        await asyncio.to_thread(self._client.connect)
        await self._configure_session()
        self._running = True
        self.state = ServiceState.CONNECTED
        logger.info(
            "DashScope realtime ASR connected: model={}, sample_rate={}, language={}",
            self.model,
            self.sample_rate,
            self.language,
        )
    async def disconnect(self) -> None:
        """Tear down the session and reset all per-connection state."""
        self._running = False
        self._utterance_active = False
        self._audio_sent_in_utterance = False
        # Drop any unconsumed transcripts so a later reconnect starts clean.
        self._drain_queue(self._final_queue)
        self._drain_queue(self._transcript_queue)
        self._session_ready.clear()
        if self._client is not None:
            close_fn = getattr(self._client, "close", None)
            if callable(close_fn):
                # Blocking SDK close; keep the event loop responsive.
                await asyncio.to_thread(close_fn)
            self._client = None
        self.state = ServiceState.DISCONNECTED
        logger.info("DashScope realtime ASR disconnected")
    async def begin_utterance(self) -> None:
        """Discard any stale utterance state and mark a new utterance active."""
        self.clear_utterance()
        self._utterance_active = True
    async def send_audio(self, audio: bytes) -> None:
        """Stream one chunk of PCM audio to the server (base64 over websocket).

        Empty chunks are ignored. Marks the utterance as having audio so
        end_utterance/wait_for_final_transcription know a commit is valid.

        Raises:
            RuntimeError: if called before connect(), or the SDK client does
                not expose an append_audio method.
        """
        if not self._client:
            raise RuntimeError("DashScope ASR service not connected")
        if not audio:
            return
        if not self._utterance_active:
            # Allow graceful fallback if caller sends before begin_utterance.
            self._utterance_active = True
        audio_b64 = base64.b64encode(audio).decode("ascii")
        append_fn = getattr(self._client, "append_audio", None)
        if not callable(append_fn):
            raise RuntimeError("DashScope ASR SDK missing append_audio method")
        # append_audio is a blocking SDK call; run it off the event loop.
        await asyncio.to_thread(append_fn, audio_b64)
        self._audio_sent_in_utterance = True
    async def end_utterance(self) -> None:
        """Commit the current utterance so the server can emit a final result.

        Silently returns when not connected, when no utterance is active, or
        when no audio was ever sent for this utterance.
        """
        if not self._client:
            return
        if not self._utterance_active or not self._audio_sent_in_utterance:
            return
        commit_fn = getattr(self._client, "commit", None)
        if not callable(commit_fn):
            raise RuntimeError("DashScope ASR SDK missing commit method")
        await asyncio.to_thread(commit_fn)
        self._utterance_active = False
    async def wait_for_final_transcription(self, timeout_ms: int = DEFAULT_FINAL_TIMEOUT_MS) -> str:
        """Wait for the server's final transcript of the current utterance.

        Returns the final text, or — if it does not arrive within
        ``timeout_ms`` — falls back to the last interim text. Returns an
        empty string when no audio was sent for this utterance.
        """
        if not self._audio_sent_in_utterance:
            return ""
        timeout_sec = max(0.05, float(timeout_ms) / 1000.0)  # clamp to >= 50ms
        try:
            text = await asyncio.wait_for(self._final_queue.get(), timeout=timeout_sec)
            return str(text or "").strip()
        except asyncio.TimeoutError:
            logger.debug("DashScope ASR final timeout ({}ms), fallback to last interim", timeout_ms)
            return str(self._last_interim_text or "").strip()
    def clear_utterance(self) -> None:
        """Reset per-utterance state, including any queued final transcripts."""
        self._utterance_active = False
        self._audio_sent_in_utterance = False
        self._last_interim_text = ""
        self._last_error = None
        self._drain_queue(self._final_queue)
    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """Yield interim and final ASRResults while the service is running.

        Polls the queue with a short timeout so the loop promptly observes
        ``self._running`` flipping to False on disconnect.
        """
        while self._running:
            try:
                result = await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
                yield result
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break
    async def _configure_session(self) -> None:
        """Send session.update, degrading gracefully across SDK versions.

        Tries progressively smaller parameter sets, because older SDK builds
        reject unknown keyword arguments (TypeError). After a successful
        update, waits briefly for the server's session-ready event; a timeout
        there is treated as non-fatal.

        Raises:
            RuntimeError: if no client is initialized, the SDK lacks
                update_session, or every update attempt fails.
        """
        if not self._client:
            raise RuntimeError("DashScope ASR client is not initialized")
        # Prefer the SDK enum when available; fall back to the raw string.
        text_modality: Any = "text"
        if MultiModality is not None and hasattr(MultiModality, "TEXT"):
            text_modality = MultiModality.TEXT
        transcription_params: Optional[Any] = None
        if TranscriptionParams is not None:
            try:
                # "auto" is mapped to "zh" here; DashScope expects a concrete
                # language code — TODO confirm against vendor docs.
                lang = "zh" if self.language == "auto" else self.language
                transcription_params = TranscriptionParams(
                    language=lang,
                    sample_rate=self.sample_rate,
                    input_audio_format="pcm",
                )
            except Exception as exc:
                logger.debug("DashScope ASR TranscriptionParams init failed: {}", exc)
                transcription_params = None
        # Most-specific attempt first; each later attempt drops parameters.
        update_attempts = [
            {
                "output_modalities": [text_modality],
                "enable_turn_detection": False,
                "enable_input_audio_transcription": True,
                "transcription_params": transcription_params,
            },
            {
                "output_modalities": [text_modality],
                "enable_turn_detection": False,
                "enable_input_audio_transcription": True,
            },
            {
                "output_modalities": [text_modality],
            },
        ]
        update_fn = getattr(self._client, "update_session", None)
        if not callable(update_fn):
            raise RuntimeError("DashScope ASR SDK missing update_session method")
        last_error: Optional[Exception] = None
        for params in update_attempts:
            if params.get("transcription_params") is None:
                # Never pass an explicit None; drop the key instead.
                params = {k: v for k, v in params.items() if k != "transcription_params"}
            try:
                await asyncio.to_thread(update_fn, **params)
                break
            except TypeError as exc:
                # Unknown kwarg on this SDK version — try a smaller set.
                last_error = exc
                continue
            except Exception as exc:
                last_error = exc
                continue
        else:
            raise RuntimeError(f"DashScope ASR session.update failed: {last_error}")
        try:
            await asyncio.wait_for(self._session_ready.wait(), timeout=6.0)
        except asyncio.TimeoutError:
            # Non-fatal: proceed even if session.created/updated never arrived.
            logger.debug("DashScope ASR session ready wait timeout; continuing")
def _on_ws_open(self) -> None:
return None
def _on_ws_close(self, code: int, msg: str) -> None:
self._last_error = f"DashScope ASR websocket closed: {code} {msg}"
logger.debug(self._last_error)
def _on_ws_error(self, message: Any) -> None:
    """Capture the websocket error text and log it at error level."""
    text = str(message)
    self._last_error = text
    logger.error("DashScope ASR error: {}", text)
def _on_ws_event(self, message: Any) -> None:
payload = self._coerce_event(message)
event_type = str(payload.get("type") or "").strip()
if not event_type:
return
if event_type in {"session.created", "session.updated"}:
self._session_ready.set()
return
if event_type == "error" or event_type.endswith(".failed"):
err_text = self._extract_text(payload, keys=("message", "error", "details"))
self._last_error = err_text or event_type
logger.error("DashScope ASR server event error: {}", self._last_error)
return
if event_type == "conversation.item.input_audio_transcription.text":
stash_text = self._extract_text(payload, keys=("stash", "text", "transcript"))
self._emit_transcript(stash_text, is_final=False)
return
if event_type == "conversation.item.input_audio_transcription.completed":
final_text = self._extract_text(payload, keys=("transcript", "text", "stash"))
self._emit_transcript(final_text, is_final=True)
return
def _emit_transcript(self, text: str, *, is_final: bool) -> None:
normalized = str(text or "").strip()
if not normalized:
return
if not is_final and normalized == self._last_interim_text:
return
if not is_final:
self._last_interim_text = normalized
if self._loop is None:
return
try:
asyncio.run_coroutine_threadsafe(
self._publish_transcript(normalized, is_final=is_final),
self._loop,
)
except RuntimeError:
return
async def _publish_transcript(self, text: str, *, is_final: bool) -> None:
    """Fan a transcript out to the queues and the optional user callback."""
    result = ASRResult(text=text, is_final=is_final)
    await self._transcript_queue.put(result)
    if is_final:
        # Finals also feed the queue that final-transcription waiters drain.
        await self._final_queue.put(text)
    callback = self.on_transcript
    if callback:
        try:
            await callback(text, is_final)
        except Exception as exc:
            logger.warning("DashScope ASR transcript callback failed: {}", exc)
@staticmethod
def _coerce_event(message: Any) -> Dict[str, Any]:
if isinstance(message, dict):
return message
if isinstance(message, str):
try:
parsed = json.loads(message)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
return {"type": "raw", "text": message}
return {"type": "raw", "text": str(message)}
def _extract_text(self, payload: Dict[str, Any], *, keys: tuple[str, ...]) -> str:
for key in keys:
value = payload.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
if isinstance(value, dict):
nested = self._extract_text(value, keys=keys)
if nested:
return nested
for value in payload.values():
if isinstance(value, dict):
nested = self._extract_text(value, keys=keys)
if nested:
return nested
return ""
@staticmethod
def _drain_queue(queue: "asyncio.Queue[Any]") -> None:
while True:
try:
queue.get_nowait()
except asyncio.QueueEmpty:
break

View File

@@ -53,6 +53,7 @@ class OpenAICompatibleASRService(BaseASRService):
model: str = "FunAudioLLM/SenseVoiceSmall",
sample_rate: int = 16000,
language: str = "auto",
enable_interim: bool = False,
interim_interval_ms: int = 500, # How often to send interim results
min_audio_for_interim_ms: int = 300, # Min audio before first interim
on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None
@@ -66,11 +67,13 @@ class OpenAICompatibleASRService(BaseASRService):
model: ASR model name or alias
sample_rate: Audio sample rate (16000 recommended)
language: Language code (auto for automatic detection)
enable_interim: Whether to generate interim transcriptions in offline mode
interim_interval_ms: How often to generate interim transcriptions
min_audio_for_interim_ms: Minimum audio duration before first interim
on_transcript: Callback for transcription results (text, is_final)
"""
super().__init__(sample_rate=sample_rate, language=language)
self.mode = "offline"
if not AIOHTTP_AVAILABLE:
raise RuntimeError("aiohttp is required for OpenAICompatibleASRService")
@@ -79,6 +82,7 @@ class OpenAICompatibleASRService(BaseASRService):
raw_api_url = api_url or os.getenv("ASR_API_URL") or self.API_URL
self.api_url = self._resolve_transcriptions_endpoint(raw_api_url)
self.model = self.MODELS.get(model.lower(), model)
self.enable_interim = bool(enable_interim)
self.interim_interval_ms = interim_interval_ms
self.min_audio_for_interim_ms = min_audio_for_interim_ms
self.on_transcript = on_transcript
@@ -181,6 +185,9 @@ class OpenAICompatibleASRService(BaseASRService):
logger.warning("ASR session not connected")
return None
if not is_final and not self.enable_interim:
return None
# Check minimum audio duration
audio_duration_ms = len(self._audio_buffer) / (self.sample_rate * 2) * 1000
@@ -309,6 +316,9 @@ class OpenAICompatibleASRService(BaseASRService):
This periodically transcribes buffered audio for
real-time feedback to the user.
"""
if not self.enable_interim:
return
if self._interim_task and not self._interim_task.done():
return

View File

@@ -0,0 +1,666 @@
"""Volcengine realtime ASR service.
Supports both:
- Volcengine Edge Gateway realtime transcription websocket, and
- Volcengine BigASR Seed websocket at openspeech.bytedance.com/api/v3/sauc/bigmodel.
"""
from __future__ import annotations
import asyncio
import base64
import gzip
import json
import os
import uuid
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Literal, Optional
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
import aiohttp
from loguru import logger
from providers.common.base import ASRResult, BaseASRService, ServiceState
VolcengineASRProtocol = Literal["gateway", "seed"]
class VolcengineRealtimeASRService(BaseASRService):
    """Realtime streaming ASR backed by Volcengine websocket APIs.

    Two wire protocols are supported, selected from the configured URL:
    the Edge Gateway realtime API ("gateway", JSON events) and BigASR Seed
    ("seed", a binary framed protocol with one websocket per utterance).
    """

    DEFAULT_WS_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel"
    DEFAULT_GATEWAY_WS_URL = "wss://ai-gateway.vei.volces.com/v1/realtime"
    DEFAULT_MODEL = "bigmodel"
    DEFAULT_FINAL_TIMEOUT_MS = 1200
    DEFAULT_SEED_RESOURCE_ID = "volc.bigasr.sauc.duration"
    # Seed protocol constants: audio is shipped in 100 ms frames; the bit
    # values below fill the 4-byte binary header (message type, flags,
    # serialization and compression nibbles) — see the BigASR protocol docs.
    _SEED_FRAME_MS = 100
    _SEED_PROTOCOL_VERSION = 0b0001
    _SEED_FULL_CLIENT_REQUEST = 0b0001
    _SEED_AUDIO_ONLY_REQUEST = 0b0010
    _SEED_FULL_SERVER_RESPONSE = 0b1001
    _SEED_SERVER_ACK = 0b1011
    _SEED_SERVER_ERROR_RESPONSE = 0b1111
    _SEED_NO_SEQUENCE = 0b0000
    _SEED_POS_SEQUENCE = 0b0001
    _SEED_NEG_WITH_SEQUENCE = 0b0011
    _SEED_NO_SERIALIZATION = 0b0000
    _SEED_JSON = 0b0001
    _SEED_NO_COMPRESSION = 0b0000
    _SEED_GZIP = 0b0001

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        model: Optional[str] = None,
        sample_rate: int = 16000,
        language: str = "auto",
        app_id: Optional[str] = None,
        resource_id: Optional[str] = None,
        uid: Optional[str] = None,
        request_params: Optional[Dict[str, Any]] = None,
        on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None,
    ) -> None:
        """Build the service; unset arguments fall back to environment variables."""
        super().__init__(sample_rate=sample_rate, language=language)
        self.mode = "streaming"
        self.api_key = api_key or os.getenv("VOLCENGINE_ASR_API_KEY") or os.getenv("ASR_API_KEY")
        self.model = str(model or os.getenv("VOLCENGINE_ASR_MODEL") or self.DEFAULT_MODEL).strip()
        raw_api_url = api_url or os.getenv("VOLCENGINE_ASR_API_URL") or self.DEFAULT_WS_URL
        # Protocol is inferred from the URL so a single config key selects the backend.
        self.protocol = self._detect_protocol(raw_api_url)
        self.api_url = self._resolve_api_url(raw_api_url, self.model, self.protocol)
        self.app_id = app_id or os.getenv("VOLCENGINE_ASR_APP_ID") or os.getenv("ASR_APP_ID")
        self.resource_id = (
            resource_id
            or os.getenv("VOLCENGINE_ASR_RESOURCE_ID")
            or (self.DEFAULT_SEED_RESOURCE_ID if self.protocol == "seed" else None)
        )
        self.uid = uid or os.getenv("VOLCENGINE_ASR_UID")
        self.request_params = self._load_request_params(request_params)
        self.on_transcript = on_transcript
        self._session: Optional[aiohttp.ClientSession] = None
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        self._reader_task: Optional[asyncio.Task[None]] = None
        self._running = False
        self._session_ready = asyncio.Event()
        self._transcript_queue: "asyncio.Queue[ASRResult]" = asyncio.Queue()
        self._final_queue: "asyncio.Queue[str]" = asyncio.Queue()
        self._utterance_active = False
        self._audio_sent_in_utterance = False
        self._last_interim_text = ""
        self._last_error: Optional[str] = None
        self._seed_audio_buffer = bytearray()
        self._seed_sequence = 1
        self._seed_request_id: Optional[str] = None
        # 16-bit mono PCM: bytes per seed frame = samples-per-frame * 2.
        self._seed_frame_bytes = max(2, int((self.sample_rate * self._SEED_FRAME_MS / 1000) * 2))

    @classmethod
    def _detect_protocol(cls, api_url: str) -> VolcengineASRProtocol:
        """Infer which wire protocol a URL targets (Seed vs. Edge Gateway)."""
        parsed = urlparse(str(api_url or "").strip())
        host = parsed.netloc.lower()
        path = parsed.path.lower()
        if "openspeech.bytedance.com" in host and "/api/v3/sauc/bigmodel" in path:
            return "seed"
        return "gateway"

    @classmethod
    def _resolve_api_url(cls, api_url: str, model: str, protocol: VolcengineASRProtocol) -> str:
        """Fill in a default URL and, for the gateway, ensure a ``model`` query param."""
        raw = str(api_url or "").strip()
        if not raw:
            raw = cls.DEFAULT_WS_URL if protocol == "seed" else cls.DEFAULT_GATEWAY_WS_URL
        if protocol != "gateway":
            return raw
        parsed = urlparse(raw)
        query = dict(parse_qsl(parsed.query, keep_blank_values=True))
        # Keep any explicit model in the URL; only add one if missing.
        query.setdefault("model", model or cls.DEFAULT_MODEL)
        return urlunparse(parsed._replace(query=urlencode(query)))

    @staticmethod
    def _load_request_params(request_params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Copy explicit request params, else parse them from the environment JSON."""
        if isinstance(request_params, dict):
            return dict(request_params)
        raw = os.getenv("VOLCENGINE_ASR_REQUEST_PARAMS_JSON", "").strip()
        if not raw:
            return {}
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("Ignoring invalid VOLCENGINE_ASR_REQUEST_PARAMS_JSON")
            return {}
        if isinstance(parsed, dict):
            return parsed
        return {}

    async def connect(self) -> None:
        """Open the HTTP session and, for the gateway protocol, the websocket.

        Seed connections are opened lazily per utterance (see
        ``begin_utterance``); here only credentials are validated.

        Raises:
            ValueError: when the API key (or, for seed, the app_id) is missing.
        """
        if not self.api_key:
            raise ValueError("Volcengine ASR API key not provided. Configure agent.asr.api_key in YAML.")
        # No total/read timeout: the websocket stream is long-lived.
        timeout = aiohttp.ClientTimeout(total=None, sock_read=None, sock_connect=15)
        self._session = aiohttp.ClientSession(timeout=timeout)
        self._running = True
        if self.protocol == "gateway":
            await self._connect_gateway()
            logger.info(
                "Volcengine gateway ASR connected: model={}, sample_rate={}, url={}",
                self.model,
                self.sample_rate,
                self.api_url,
            )
        else:
            if not self.app_id:
                raise ValueError("Volcengine ASR app_id not provided. Configure agent.asr.app_id in YAML.")
            logger.info(
                "Volcengine BigASR Seed ready: model={}, sample_rate={}, resource_id={}",
                self.model,
                self.sample_rate,
                self.resource_id,
            )
        self.state = ServiceState.CONNECTED

    async def disconnect(self) -> None:
        """Stop processing, drop buffered state, and close websocket + session."""
        self._running = False
        self._utterance_active = False
        self._audio_sent_in_utterance = False
        self._session_ready.clear()
        self._seed_audio_buffer = bytearray()
        self._drain_queue(self._final_queue)
        self._drain_queue(self._transcript_queue)
        await self._close_ws()
        if self._session is not None:
            await self._session.close()
            self._session = None
        self.state = ServiceState.DISCONNECTED
        logger.info("Volcengine ASR disconnected")

    async def begin_utterance(self) -> None:
        """Reset per-utterance state; seed mode opens a fresh websocket stream."""
        self.clear_utterance()
        if self.protocol == "seed":
            await self._open_seed_stream()
        self._utterance_active = True

    async def send_audio(self, audio: bytes) -> None:
        """Forward raw PCM audio to the active protocol stream."""
        if not audio:
            return
        if self.protocol == "seed":
            await self._send_seed_audio(audio)
            return
        if not self._ws:
            raise RuntimeError("Volcengine ASR websocket is not connected")
        if not self._utterance_active:
            self._utterance_active = True
        # Gateway expects base64 audio appended to the server-side input buffer.
        await self._ws.send_json(
            {
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(audio).decode("ascii"),
            }
        )
        self._audio_sent_in_utterance = True

    async def end_utterance(self) -> None:
        """Signal end of speech: commit (gateway) or flush + last frame (seed)."""
        if not self._utterance_active:
            return
        if self.protocol == "seed":
            await self._end_seed_utterance()
            return
        if not self._ws or not self._audio_sent_in_utterance:
            return
        await self._ws.send_json({"type": "input_audio_buffer.commit"})
        self._utterance_active = False

    async def wait_for_final_transcription(self, timeout_ms: int = DEFAULT_FINAL_TIMEOUT_MS) -> str:
        """Wait for a final transcript, falling back to the last interim on timeout."""
        if not self._audio_sent_in_utterance:
            return ""
        timeout_sec = max(0.05, float(timeout_ms) / 1000.0)
        try:
            return str(await asyncio.wait_for(self._final_queue.get(), timeout=timeout_sec) or "").strip()
        except asyncio.TimeoutError:
            logger.debug("Volcengine ASR final timeout ({}ms), fallback to last interim", timeout_ms)
            return str(self._last_interim_text or "").strip()
        finally:
            # Seed streams are one-shot per utterance; always tear them down.
            if self.protocol == "seed":
                await self._close_ws()

    def clear_utterance(self) -> None:
        """Drop all per-utterance state and pending final transcripts."""
        self._utterance_active = False
        self._audio_sent_in_utterance = False
        self._last_interim_text = ""
        self._last_error = None
        self._seed_audio_buffer = bytearray()
        self._seed_sequence = 1
        self._seed_request_id = None
        self._drain_queue(self._final_queue)

    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """Yield transcripts as they arrive; polls so shutdown stays responsive."""
        while self._running:
            try:
                yield await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break

    async def _connect_gateway(self) -> None:
        """Open the gateway websocket, start the reader, and configure the session."""
        assert self._session is not None
        headers = {"Authorization": f"Bearer {self.api_key}"}
        if self.resource_id:
            headers["X-Api-Resource-Id"] = self.resource_id
        self._ws = await self._session.ws_connect(self.api_url, headers=headers, heartbeat=20)
        self._reader_task = asyncio.create_task(self._reader_loop())
        await self._configure_gateway_session()

    async def _configure_gateway_session(self) -> None:
        """Push transcription session settings and wait for the server ack.

        Raises:
            RuntimeError: if the websocket is missing or the ack does not
                arrive within 8 seconds.
        """
        if not self._ws:
            raise RuntimeError("Volcengine ASR websocket is not initialized")
        session_payload: Dict[str, Any] = {
            "input_audio_format": "pcm",
            "input_audio_codec": "raw",
            "input_audio_sample_rate": self.sample_rate,
            "input_audio_bits": 16,
            "input_audio_channel": 1,
            "result_type": 0,
            "input_audio_transcription": {
                "model": self.model,
            },
        }
        await self._ws.send_json(
            {
                "type": "transcription_session.update",
                "session": session_payload,
            }
        )
        try:
            # _session_ready is set by _handle_gateway_event on the ack event.
            await asyncio.wait_for(self._session_ready.wait(), timeout=8.0)
        except asyncio.TimeoutError as exc:
            raise RuntimeError("Volcengine ASR session update timeout") from exc

    async def _open_seed_stream(self) -> None:
        """Open a fresh seed websocket for one utterance and send the start frame."""
        if not self._session:
            raise RuntimeError("Volcengine ASR session is not initialized")
        # Seed is one stream per utterance; drop any leftover socket first.
        await self._close_ws()
        self._seed_request_id = uuid.uuid4().hex
        headers = self._build_seed_headers(self._seed_request_id)
        self._ws = await self._session.ws_connect(
            self.api_url,
            headers=headers,
            heartbeat=20,
            max_msg_size=1_000_000_000,
        )
        self._reader_task = asyncio.create_task(self._reader_loop())
        await self._ws.send_bytes(self._build_seed_start_request())

    async def _send_seed_audio(self, audio: bytes) -> None:
        """Buffer PCM and ship fixed-size seed frames with increasing sequence numbers."""
        if not self._utterance_active:
            await self.begin_utterance()
        if not self._ws:
            raise RuntimeError("Volcengine BigASR websocket is not connected")
        self._seed_audio_buffer.extend(audio)
        while len(self._seed_audio_buffer) >= self._seed_frame_bytes:
            chunk = bytes(self._seed_audio_buffer[: self._seed_frame_bytes])
            del self._seed_audio_buffer[: self._seed_frame_bytes]
            self._seed_sequence += 1
            await self._ws.send_bytes(self._build_seed_audio_request(chunk, sequence=self._seed_sequence))
            self._audio_sent_in_utterance = True

    async def _end_seed_utterance(self) -> None:
        """Flush leftover audio as the last frame (negative sequence marks EOS)."""
        if not self._ws:
            return
        if not self._audio_sent_in_utterance and not self._seed_audio_buffer:
            self._utterance_active = False
            return
        final_chunk = bytes(self._seed_audio_buffer)
        self._seed_audio_buffer = bytearray()
        self._seed_sequence += 1
        await self._ws.send_bytes(
            self._build_seed_audio_request(final_chunk, sequence=-self._seed_sequence, is_last=True)
        )
        self._audio_sent_in_utterance = True
        self._utterance_active = False

    async def _close_ws(self) -> None:
        """Cancel the reader task and close the websocket; safe to call repeatedly."""
        reader_task = self._reader_task
        ws = self._ws
        self._reader_task = None
        self._ws = None
        if reader_task:
            reader_task.cancel()
            try:
                await reader_task
            except asyncio.CancelledError:
                pass
        if ws is not None:
            await ws.close()

    async def _reader_loop(self) -> None:
        """Consume websocket messages and route them to the protocol handlers."""
        ws = self._ws
        if ws is None:
            return
        try:
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    if self.protocol == "gateway":
                        self._handle_gateway_event(msg.data)
                    else:
                        self._handle_seed_text(msg.data)
                    continue
                if msg.type == aiohttp.WSMsgType.BINARY:
                    if self.protocol == "seed":
                        self._handle_seed_binary(msg.data)
                    continue
                if msg.type == aiohttp.WSMsgType.ERROR:
                    self._last_error = str(ws.exception())
                    logger.error("Volcengine ASR websocket error: {}", self._last_error)
                    break
                if msg.type in {aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE}:
                    break
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            self._last_error = str(exc)
            logger.error("Volcengine ASR reader loop failed: {}", exc)
        finally:
            # Only clear if another coroutine has not already replaced the socket.
            if self._ws is ws:
                self._ws = None

    def _handle_gateway_event(self, message: str) -> None:
        """Translate gateway JSON events into state changes and transcript emits."""
        payload = self._coerce_event(message)
        event_type = str(payload.get("type") or "").strip()
        if not event_type:
            return
        if event_type in {"transcription_session.created", "transcription_session.updated"}:
            self._session_ready.set()
            return
        if event_type == "error":
            self._last_error = self._extract_text(payload, ("message", "error"))
            logger.error("Volcengine ASR server error: {}", self._last_error or "unknown")
            return
        if event_type.endswith(".failed"):
            self._last_error = self._extract_text(payload, ("message", "error", "transcript"))
            logger.error("Volcengine ASR failed event: {}", self._last_error or event_type)
            return
        if event_type == "conversation.item.input_audio_transcription.result":
            transcript = self._extract_text(payload, ("transcript", "result"))
            self._emit_transcript_sync(transcript, is_final=False)
            return
        if event_type == "conversation.item.input_audio_transcription.delta":
            transcript = self._extract_text(payload, ("delta",))
            self._emit_transcript_sync(transcript, is_final=False)
            return
        if event_type == "conversation.item.input_audio_transcription.completed":
            transcript = self._extract_text(payload, ("transcript", "result"))
            self._emit_transcript_sync(transcript, is_final=True)

    def _handle_seed_text(self, message: str) -> None:
        """Seed sends errors as text frames; all transcript data arrives as binary."""
        payload = self._coerce_event(message)
        if payload.get("type") == "error":
            self._last_error = self._extract_text(payload, ("message", "error"))
            logger.error("Volcengine BigASR error: {}", self._last_error or "unknown")

    def _handle_seed_binary(self, message: bytes) -> None:
        """Decode a seed binary frame and emit its transcript, if any."""
        payload = self._parse_seed_response(message)
        if payload.get("code"):
            self._last_error = self._extract_text(payload, ("payload_msg",))
            logger.error("Volcengine BigASR server error: {}", self._last_error or payload["code"])
            return
        body = payload.get("payload_msg")
        if not isinstance(body, dict):
            return
        result = body.get("result")
        if not isinstance(result, dict):
            return
        text = str(result.get("text") or "").strip()
        if not text:
            return
        utterances = result.get("utterances")
        if not isinstance(utterances, list) or not utterances:
            return
        # "definite" on the first utterance appears to mark finality —
        # TODO confirm against the BigASR response schema.
        first_utterance = utterances[0] if isinstance(utterances[0], dict) else {}
        is_final = self._coerce_bool(first_utterance.get("definite")) is True
        self._emit_transcript_sync(text, is_final=is_final)

    def _emit_transcript_sync(self, text: str, *, is_final: bool) -> None:
        """Publish a transcript to the queues and fire the async user callback."""
        cleaned = str(text or "").strip()
        if not cleaned:
            return
        if not is_final:
            self._last_interim_text = cleaned
        else:
            self._last_interim_text = ""
        result = ASRResult(text=cleaned, is_final=is_final)
        try:
            self._transcript_queue.put_nowait(result)
        except asyncio.QueueFull:
            logger.debug("Volcengine ASR transcript queue full; dropping transcript")
        if is_final:
            try:
                self._final_queue.put_nowait(cleaned)
            except asyncio.QueueFull:
                logger.debug("Volcengine ASR final queue full; dropping transcript")
        if self.on_transcript:
            # NOTE(review): the task reference is not retained; a callback
            # exception would surface as "task exception was never retrieved".
            asyncio.create_task(self.on_transcript(cleaned, is_final))

    def _build_seed_headers(self, request_id: str) -> Dict[str, str]:
        """Assemble the authentication headers for a seed websocket handshake.

        Raises:
            ValueError: when app_id or api_key is missing.
        """
        if not self.app_id:
            raise ValueError("Volcengine ASR app_id not provided. Configure agent.asr.app_id in YAML.")
        if not self.api_key:
            raise ValueError("Volcengine ASR api_key not provided. Configure agent.asr.api_key in YAML.")
        return {
            "X-Api-App-Key": str(self.app_id),
            "X-Api-Access-Key": str(self.api_key),
            "X-Api-Resource-Id": str(self.resource_id or self.DEFAULT_SEED_RESOURCE_ID),
            "X-Api-Request-Id": str(request_id),
        }

    def _build_seed_start_payload(self) -> Dict[str, Any]:
        """Compose the JSON body of the seed start frame (user/audio/request)."""
        user_payload: Dict[str, Any] = {"uid": str(self.uid or self._seed_request_id or self.app_id or uuid.uuid4().hex)}
        audio_payload: Dict[str, Any] = {
            "format": "pcm",
            "rate": self.sample_rate,
            "bits": 16,
            "channels": 1,
            "codec": "raw",
        }
        if self.language and self.language != "auto":
            audio_payload["language"] = self.language
        request_payload: Dict[str, Any] = {
            "model_name": self.model or self.DEFAULT_MODEL,
            "enable_itn": False,
            "enable_punc": True,
            "enable_ddc": False,
            "show_utterance": True,
            "result_type": "single",
            "vad_segment_duration": 3000,
            "end_window_size": 500,
            "force_to_speech_time": 1000,
        }
        # User-supplied request_params override the defaults section by
        # section; leftover top-level keys merge into the request payload.
        extra = dict(self.request_params)
        user_payload.update(self._as_dict(extra.pop("user", None)))
        audio_payload.update(self._as_dict(extra.pop("audio", None)))
        request_payload.update(self._as_dict(extra.pop("request", None)))
        request_payload.update(extra)
        return {
            "user": user_payload,
            "audio": audio_payload,
            "request": request_payload,
        }

    def _build_seed_start_request(self) -> bytes:
        """Serialize the start frame: header + sequence(1) + size + gzip JSON payload."""
        payload = gzip.compress(json.dumps(self._build_seed_start_payload()).encode("utf-8"))
        frame = bytearray(
            self._build_seed_header(
                message_type=self._SEED_FULL_CLIENT_REQUEST,
                message_type_specific_flags=self._SEED_POS_SEQUENCE,
            )
        )
        frame.extend((1).to_bytes(4, "big", signed=True))
        frame.extend(len(payload).to_bytes(4, "big"))
        frame.extend(payload)
        return bytes(frame)

    def _build_seed_audio_request(self, chunk: bytes, *, sequence: int, is_last: bool = False) -> bytes:
        """Serialize an audio frame; the last frame uses the negative-sequence flag."""
        payload = gzip.compress(chunk)
        frame = bytearray(
            self._build_seed_header(
                message_type=self._SEED_AUDIO_ONLY_REQUEST,
                message_type_specific_flags=self._SEED_NEG_WITH_SEQUENCE if is_last else self._SEED_POS_SEQUENCE,
            )
        )
        frame.extend(int(sequence).to_bytes(4, "big", signed=True))
        frame.extend(len(payload).to_bytes(4, "big"))
        frame.extend(payload)
        return bytes(frame)

    @classmethod
    def _build_seed_header(
        cls,
        *,
        message_type: int,
        message_type_specific_flags: int,
        serial_method: int = _SEED_JSON,
        compression_type: int = _SEED_GZIP,
        reserved_data: int = 0x00,
    ) -> bytes:
        """Pack the 4-byte seed header: version|size, type|flags, serialization|compression, reserved."""
        header = bytearray()
        # Low nibble 0b0001 is the header size in 4-byte units.
        header.append((cls._SEED_PROTOCOL_VERSION << 4) | 0b0001)
        header.append((message_type << 4) | message_type_specific_flags)
        header.append((serial_method << 4) | compression_type)
        header.append(reserved_data)
        return bytes(header)

    @classmethod
    def _parse_seed_response(cls, response: bytes) -> Dict[str, Any]:
        """Unpack a seed binary frame into header fields plus a decoded payload.

        The returned dict may contain ``payload_sequence``, ``is_last_package``,
        ``payload_size``, ``seq``, ``code`` (error frames), and ``payload_msg``
        (gunzipped / JSON-decoded body when the frame declares those encodings).
        """
        header_size = response[0] & 0x0F
        message_type = response[1] >> 4
        message_type_specific_flags = response[1] & 0x0F
        serialization_method = response[2] >> 4
        compression_type = response[2] & 0x0F
        # header_size is expressed in 4-byte words.
        payload = response[header_size * 4 :]
        result: Dict[str, Any] = {"is_last_package": False}
        payload_message: Any = None
        if message_type_specific_flags & 0x01:
            # Flag bit 0: a 4-byte signed sequence number precedes the payload.
            result["payload_sequence"] = int.from_bytes(payload[:4], "big", signed=True)
            payload = payload[4:]
        if message_type_specific_flags & 0x02:
            result["is_last_package"] = True
        if message_type == cls._SEED_FULL_SERVER_RESPONSE:
            result["payload_size"] = int.from_bytes(payload[:4], "big", signed=True)
            payload_message = payload[4:]
        elif message_type == cls._SEED_SERVER_ACK:
            result["seq"] = int.from_bytes(payload[:4], "big", signed=True)
            if len(payload) >= 8:
                result["payload_size"] = int.from_bytes(payload[4:8], "big", signed=False)
                payload_message = payload[8:]
        elif message_type == cls._SEED_SERVER_ERROR_RESPONSE:
            result["code"] = int.from_bytes(payload[:4], "big", signed=False)
            result["payload_size"] = int.from_bytes(payload[4:8], "big", signed=False)
            payload_message = payload[8:]
        if payload_message is None:
            return result
        if compression_type == cls._SEED_GZIP:
            payload_message = gzip.decompress(payload_message)
        if serialization_method == cls._SEED_JSON:
            payload_message = json.loads(payload_message.decode("utf-8"))
        elif serialization_method != cls._SEED_NO_SERIALIZATION:
            payload_message = payload_message.decode("utf-8")
        result["payload_msg"] = payload_message
        return result

    @staticmethod
    def _coerce_event(message: Any) -> Dict[str, Any]:
        """Best-effort conversion of a websocket message into an event dict."""
        if isinstance(message, dict):
            return message
        if isinstance(message, str):
            try:
                loaded = json.loads(message)
                if isinstance(loaded, dict):
                    return loaded
            except json.JSONDecodeError:
                return {"type": "raw", "message": message}
        return {"type": "raw", "message": str(message)}

    @staticmethod
    def _extract_text(payload: Dict[str, Any], keys: tuple[str, ...]) -> str:
        """Return the first non-empty string under ``keys`` (one nesting level deep)."""
        for key in keys:
            value = payload.get(key)
            if isinstance(value, str) and value.strip():
                return value.strip()
            if isinstance(value, dict):
                for nested_key in ("message", "text", "transcript", "result", "delta"):
                    nested = value.get(nested_key)
                    if isinstance(nested, str) and nested.strip():
                        return nested.strip()
        return ""

    @staticmethod
    def _coerce_bool(value: Any) -> Optional[bool]:
        """Interpret bools, numbers, and common truthy/falsy strings; None if unknown."""
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return bool(value)
        if isinstance(value, str):
            normalized = value.strip().lower()
            if normalized in {"1", "true", "yes", "on"}:
                return True
            if normalized in {"0", "false", "no", "off"}:
                return False
        return None

    @staticmethod
    def _as_dict(value: Any) -> Dict[str, Any]:
        """Return a shallow copy when ``value`` is a dict, else an empty dict."""
        if isinstance(value, dict):
            return dict(value)
        return {}

    @staticmethod
    def _drain_queue(queue: "asyncio.Queue[Any]") -> None:
        """Remove all currently queued items without awaiting."""
        while True:
            try:
                queue.get_nowait()
            except asyncio.QueueEmpty:
                break

View File

@@ -16,13 +16,18 @@ from runtime.ports import (
TTSServiceSpec,
)
from providers.asr.buffered import BufferedASRService
from providers.asr.dashscope import DashScopeRealtimeASRService
from providers.asr.volcengine import VolcengineRealtimeASRService
from providers.tts.dashscope import DashScopeTTSService
from providers.llm.openai import MockLLMService, OpenAILLMService
from providers.asr.openai_compatible import OpenAICompatibleASRService
from providers.tts.openai_compatible import OpenAICompatibleTTSService
from providers.tts.mock import MockTTSService
from providers.tts.volcengine import VolcengineTTSService
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
_DASHSCOPE_PROVIDERS = {"dashscope"}
_VOLCENGINE_PROVIDERS = {"volcengine"}
_SUPPORTED_LLM_PROVIDERS = {"openai", *_OPENAI_COMPATIBLE_PROVIDERS}
@@ -31,8 +36,14 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
_DEFAULT_DASHSCOPE_TTS_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
_DEFAULT_DASHSCOPE_TTS_MODEL = "qwen3-tts-flash-realtime"
_DEFAULT_DASHSCOPE_ASR_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
_DEFAULT_DASHSCOPE_ASR_MODEL = "qwen3-asr-flash-realtime"
_DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
_DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
_DEFAULT_VOLCENGINE_TTS_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional"
_DEFAULT_VOLCENGINE_TTS_RESOURCE_ID = "seed-tts-2.0"
_DEFAULT_VOLCENGINE_ASR_REALTIME_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel"
_DEFAULT_VOLCENGINE_ASR_MODEL = "bigmodel"
@staticmethod
def _normalize_provider(provider: Any) -> str:
@@ -77,6 +88,19 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
speed=spec.speed,
)
if provider in _VOLCENGINE_PROVIDERS and spec.api_key:
return VolcengineTTSService(
api_key=spec.api_key,
api_url=spec.api_url or self._DEFAULT_VOLCENGINE_TTS_URL,
voice=spec.voice,
model=spec.model,
app_id=spec.app_id,
resource_id=spec.resource_id or self._DEFAULT_VOLCENGINE_TTS_RESOURCE_ID,
uid=spec.uid,
sample_rate=spec.sample_rate,
speed=spec.speed,
)
if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key:
return OpenAICompatibleTTSService(
api_key=spec.api_key,
@@ -96,6 +120,30 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
def create_asr_service(self, spec: ASRServiceSpec) -> ASRPort:
provider = self._normalize_provider(spec.provider)
if provider in _DASHSCOPE_PROVIDERS and spec.api_key:
return DashScopeRealtimeASRService(
api_key=spec.api_key,
api_url=spec.api_url or self._DEFAULT_DASHSCOPE_ASR_REALTIME_URL,
model=spec.model or self._DEFAULT_DASHSCOPE_ASR_MODEL,
sample_rate=spec.sample_rate,
language=spec.language,
on_transcript=spec.on_transcript,
)
if provider in _VOLCENGINE_PROVIDERS and spec.api_key:
return VolcengineRealtimeASRService(
api_key=spec.api_key,
api_url=spec.api_url or self._DEFAULT_VOLCENGINE_ASR_REALTIME_URL,
model=spec.model or self._DEFAULT_VOLCENGINE_ASR_MODEL,
sample_rate=spec.sample_rate,
language=spec.language,
app_id=spec.app_id,
resource_id=spec.resource_id,
uid=spec.uid,
request_params=spec.request_params,
on_transcript=spec.on_transcript,
)
if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key:
return OpenAICompatibleASRService(
api_key=spec.api_key,
@@ -103,6 +151,7 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
model=spec.model or self._DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL,
sample_rate=spec.sample_rate,
language=spec.language,
enable_interim=spec.enable_interim,
interim_interval_ms=spec.interim_interval_ms,
min_audio_for_interim_ms=spec.min_audio_for_interim_ms,
on_transcript=spec.on_transcript,

View File

@@ -1 +1,5 @@
"""TTS providers."""
from providers.tts.volcengine import VolcengineTTSService
__all__ = ["VolcengineTTSService"]

View File

@@ -0,0 +1,219 @@
"""Volcengine TTS service.
Uses Volcengine's unidirectional HTTP streaming TTS API and adapts streamed
base64 audio chunks into engine-native ``TTSChunk`` events.
"""
from __future__ import annotations
import asyncio
import base64
import codecs
import json
import os
import uuid
from typing import Any, AsyncIterator, Optional
import aiohttp
from loguru import logger
from providers.common.base import BaseTTSService, ServiceState, TTSChunk
class VolcengineTTSService(BaseTTSService):
"""Streaming TTS adapter for Volcengine's HTTP v3 API."""
DEFAULT_API_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional"
DEFAULT_RESOURCE_ID = "seed-tts-2.0"
def __init__(
    self,
    api_key: Optional[str] = None,
    api_url: Optional[str] = None,
    voice: str = "zh_female_shuangkuaisisi_moon_bigtts",
    model: Optional[str] = None,
    app_id: Optional[str] = None,
    resource_id: Optional[str] = None,
    uid: Optional[str] = None,
    sample_rate: int = 16000,
    speed: float = 1.0,
) -> None:
    """Configure credentials and endpoint; unset arguments fall back to env vars."""
    super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
    env = os.getenv
    # Explicit arguments win; provider-specific env vars, then generic ones.
    self.api_key = api_key or env("VOLCENGINE_TTS_API_KEY") or env("TTS_API_KEY")
    self.api_url = api_url or env("VOLCENGINE_TTS_API_URL") or self.DEFAULT_API_URL
    self.app_id = app_id or env("VOLCENGINE_TTS_APP_ID") or env("TTS_APP_ID")
    self.resource_id = resource_id or env("VOLCENGINE_TTS_RESOURCE_ID") or self.DEFAULT_RESOURCE_ID
    self.uid = uid or env("VOLCENGINE_TTS_UID")
    # An empty/blank model means "let the API pick"; normalize to None.
    raw_model = str(model or env("VOLCENGINE_TTS_MODEL") or "").strip()
    self.model = raw_model or None
    self._session: Optional[aiohttp.ClientSession] = None
    self._cancel_event = asyncio.Event()
    self._synthesis_lock = asyncio.Lock()
    self._pending_audio: list[bytes] = []
async def connect(self) -> None:
    """Validate credentials and open the shared HTTP session.

    Raises:
        ValueError: when the API key or app_id is missing.
    """
    if not self.api_key:
        raise ValueError("Volcengine TTS API key not provided. Configure agent.tts.api_key in YAML.")
    if not self.app_id:
        raise ValueError("Volcengine TTS app_id not provided. Configure agent.tts.app_id in YAML.")
    # No total/read timeout: synthesis streams can be long-lived.
    self._session = aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=None, sock_read=None, sock_connect=15)
    )
    self.state = ServiceState.CONNECTED
    logger.info(
        "Volcengine TTS service ready: speaker={}, sample_rate={}, resource_id={}",
        self.voice,
        self.sample_rate,
        self.resource_id,
    )
async def disconnect(self) -> None:
    """Cancel any in-flight synthesis and tear down the HTTP session."""
    self._cancel_event.set()
    session = self._session
    if session is not None:
        await session.close()
        self._session = None
    self.state = ServiceState.DISCONNECTED
    logger.info("Volcengine TTS service disconnected")
async def synthesize(self, text: str) -> bytes:
    """Synthesize ``text`` and return the concatenated PCM audio as one buffer."""
    parts: list[bytes] = []
    async for chunk in self.synthesize_stream(text):
        parts.append(chunk.audio)
    return b"".join(parts)
async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
    """Stream synthesized PCM for ``text`` as ``TTSChunk`` events.

    A one-chunk lag (``pending_chunk``) is kept so the last emitted chunk
    can be flagged ``is_final=True`` without knowing the stream length in
    advance. Cancellation is checked between network reads via
    ``self._cancel_event``.

    Raises:
        RuntimeError: when not connected or the API returns a non-200 status.
    """
    if not self._session:
        raise RuntimeError("Volcengine TTS service not connected")
    if not text.strip():
        return
    # Serialize synthesis calls; each run starts with a fresh cancel flag.
    async with self._synthesis_lock:
        self._cancel_event.clear()
        headers = {
            "Content-Type": "application/json",
            "X-Api-App-Key": str(self.app_id),
            "X-Api-Access-Key": str(self.api_key),
            "X-Api-Resource-Id": str(self.resource_id),
            "X-Api-Request-Id": str(uuid.uuid4()),
        }
        payload = {
            "user": {
                "uid": str(self.uid or self.app_id),
            },
            "req_params": {
                "text": text,
                "speaker": self.voice,
                "audio_params": {
                    "format": "pcm",
                    "sample_rate": self.sample_rate,
                    "speech_rate": self._speech_rate_percent(self.speed),
                },
            },
        }
        if self.model:
            payload["req_params"]["model"] = self.model
        # ~100 ms of 16-bit mono PCM per emitted chunk.
        chunk_size = max(1, self.sample_rate * 2 // 10)
        audio_buffer = b""
        pending_chunk: Optional[bytes] = None
        try:
            async with self._session.post(self.api_url, headers=headers, json=payload) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(f"Volcengine TTS error {response.status}: {error_text}")
                async for audio_bytes in self._iter_audio_bytes(response):
                    if self._cancel_event.is_set():
                        logger.info("Volcengine TTS synthesis cancelled")
                        return
                    audio_buffer += audio_bytes
                    while len(audio_buffer) >= chunk_size:
                        emitted = audio_buffer[:chunk_size]
                        audio_buffer = audio_buffer[chunk_size:]
                        # Emit the previously held chunk; hold the newest one
                        # so the final chunk can later be marked is_final.
                        if pending_chunk is not None:
                            yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False)
                        pending_chunk = emitted
            if self._cancel_event.is_set():
                return
            # Flush: whichever of pending_chunk / audio_buffer comes last is final.
            if pending_chunk is not None:
                if audio_buffer:
                    yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False)
                    pending_chunk = None
                else:
                    yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=True)
                    pending_chunk = None
            if audio_buffer:
                yield TTSChunk(audio=audio_buffer, sample_rate=self.sample_rate, is_final=True)
        except asyncio.CancelledError:
            logger.info("Volcengine TTS synthesis cancelled via asyncio")
            raise
        except Exception as exc:
            logger.error("Volcengine TTS synthesis error: {}", exc)
            raise
async def cancel(self) -> None:
    """Request cancellation of any in-flight synthesis.

    Sets the cancel event polled by ``synthesize_stream``; the stream stops
    emitting at its next check rather than being interrupted immediately.
    """
    self._cancel_event.set()
async def _iter_audio_bytes(self, response: aiohttp.ClientResponse) -> AsyncIterator[bytes]:
    """Yield decoded audio byte chunks from a streaming JSON response body.

    The body appears to be a concatenation of JSON objects (no delimiter),
    so bytes are decoded incrementally as UTF-8 — multi-byte characters may
    straddle network chunks — and parsed via ``_yield_audio_payloads``,
    which pushes decoded audio onto ``self._pending_audio``.
    """
    decoder = json.JSONDecoder()
    # Incremental decoder handles UTF-8 sequences split across chunks.
    utf8_decoder = codecs.getincrementaldecoder("utf-8")()
    text_buffer = ""
    # NOTE(review): _pending_audio is shared instance state; assumes only one
    # iteration runs at a time (synthesize_stream holds _synthesis_lock).
    self._pending_audio.clear()
    async for raw_chunk in response.content.iter_any():
        text_buffer += utf8_decoder.decode(raw_chunk)
        # Consume complete JSON objects; keep any partial tail for next round.
        text_buffer = self._yield_audio_payloads(decoder, text_buffer)
        while self._pending_audio:
            yield self._pending_audio.pop(0)
    # Flush any bytes still held by the incremental decoder, then drain.
    text_buffer += utf8_decoder.decode(b"", final=True)
    text_buffer = self._yield_audio_payloads(decoder, text_buffer)
    while self._pending_audio:
        yield self._pending_audio.pop(0)
def _yield_audio_payloads(self, decoder: json.JSONDecoder, text_buffer: str) -> str:
while True:
stripped = text_buffer.lstrip()
if not stripped:
return ""
if len(stripped) != len(text_buffer):
text_buffer = stripped
try:
payload, idx = decoder.raw_decode(text_buffer)
except json.JSONDecodeError:
return text_buffer
text_buffer = text_buffer[idx:]
audio = self._extract_audio_bytes(payload)
if audio:
self._pending_audio.append(audio)
def _extract_audio_bytes(self, payload: Any) -> bytes:
if not isinstance(payload, dict):
return b""
code = payload.get("code")
if code not in (None, 0, 20000000):
message = str(payload.get("message") or "unknown error")
raise RuntimeError(f"Volcengine TTS stream error {code}: {message}")
encoded = payload.get("data")
if isinstance(encoded, str) and encoded.strip():
try:
return base64.b64decode(encoded)
except Exception as exc:
logger.warning("Failed to decode Volcengine TTS audio chunk: {}", exc)
return b""
@staticmethod
def _speech_rate_percent(speed: float) -> int:
clamped = max(0.5, min(2.0, float(speed or 1.0)))
return int(round((clamped - 1.0) * 100))

View File

@@ -30,11 +30,14 @@ from providers.factory.default import DefaultRealtimeServiceFactory
from runtime.conversation import ConversationManager, ConversationState
from runtime.events import get_event_bus
from runtime.ports import (
ASRMode,
ASRPort,
ASRServiceSpec,
LLMPort,
LLMServiceSpec,
OfflineASRPort,
RealtimeServiceFactory,
StreamingASRPort,
TTSPort,
TTSServiceSpec,
)
@@ -77,6 +80,7 @@ class DuplexPipeline:
_ASR_DELTA_THROTTLE_MS = 500
_LLM_DELTA_THROTTLE_MS = 80
_ASR_CAPTURE_MAX_MS = 15000
_ASR_STREAM_FINAL_TIMEOUT_MS = 800
_OPENER_PRE_ROLL_MS = 180
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
"current_time": {
@@ -317,6 +321,10 @@ class DuplexPipeline:
self.llm_service = llm_service
self.tts_service = tts_service
self.asr_service = asr_service # Will be initialized in start()
self._asr_mode: ASRMode = self._resolve_asr_mode(
settings.asr_provider,
getattr(asr_service, "mode", None),
)
self._service_factory = service_factory or DefaultRealtimeServiceFactory()
self._knowledge_searcher = knowledge_searcher
self._tool_resource_resolver = tool_resource_resolver
@@ -324,6 +332,7 @@ class DuplexPipeline:
# Track last sent transcript to avoid duplicates
self._last_sent_transcript = ""
self._latest_asr_interim_text = ""
self._pending_transcript_delta: str = ""
self._last_transcript_delta_emit_ms: float = 0.0
@@ -588,7 +597,9 @@ class DuplexPipeline:
},
"asr": {
"provider": asr_provider,
"mode": self._resolve_asr_mode(asr_provider, self._runtime_asr.get("mode")),
"model": str(self._runtime_asr.get("model") or settings.asr_model or ""),
"enableInterim": self._asr_interim_enabled(),
"interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
"minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
},
@@ -782,11 +793,44 @@ class DuplexPipeline:
return False
return None
@staticmethod
def _coerce_json_object(value: Any) -> Optional[Dict[str, Any]]:
if isinstance(value, dict):
return dict(value)
if isinstance(value, str):
raw = value.strip()
if not raw:
return None
try:
parsed = json.loads(raw)
except json.JSONDecodeError:
logger.warning("Ignoring invalid JSON object config: {}", raw[:120])
return None
if isinstance(parsed, dict):
return parsed
return None
@staticmethod
def _is_dashscope_tts_provider(provider: Any) -> bool:
normalized = str(provider or "").strip().lower()
return normalized == "dashscope"
@staticmethod
def _resolve_asr_mode(provider: Any, raw_mode: Any = None) -> ASRMode:
normalized_mode = str(raw_mode or "").strip().lower()
if normalized_mode in {"offline", "streaming"}:
return normalized_mode # type: ignore[return-value]
normalized_provider = str(provider or "").strip().lower()
if normalized_provider in {"dashscope", "volcengine"}:
return "streaming"
return "offline"
def _offline_asr(self) -> OfflineASRPort:
return self.asr_service # type: ignore[return-value]
def _streaming_asr(self) -> StreamingASRPort:
return self.asr_service # type: ignore[return-value]
@staticmethod
def _default_llm_base_url(provider: Any) -> Optional[str]:
normalized = str(provider or "").strip().lower()
@@ -839,6 +883,20 @@ class DuplexPipeline:
return self._runtime_barge_in_min_duration_ms
return self._barge_in_min_duration_ms
def _asr_interim_enabled(self) -> bool:
current_mode = self._asr_mode
if not self.asr_service:
current_mode = self._resolve_asr_mode(
self._runtime_asr.get("provider") or settings.asr_provider,
self._runtime_asr.get("mode"),
)
if current_mode != "offline":
return True
enabled = self._coerce_bool(self._runtime_asr.get("enableInterim"))
if enabled is not None:
return enabled
return bool(settings.asr_enable_interim)
def _barge_in_silence_tolerance_frames(self) -> int:
"""Convert silence tolerance from ms to frame count using current chunk size."""
chunk_ms = max(1, settings.chunk_size_ms)
@@ -922,6 +980,10 @@ class DuplexPipeline:
tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url
tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
tts_model = self._runtime_tts.get("model") or settings.tts_model
tts_app_id = self._runtime_tts.get("appId") or settings.tts_app_id
tts_resource_id = self._runtime_tts.get("resourceId") or settings.tts_resource_id
tts_cluster = self._runtime_tts.get("cluster") or settings.tts_cluster
tts_uid = self._runtime_tts.get("uid") or settings.tts_uid
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
tts_mode = self._resolved_dashscope_tts_mode()
runtime_mode = str(self._runtime_tts.get("mode") or "").strip()
@@ -937,6 +999,10 @@ class DuplexPipeline:
api_url=str(tts_api_url).strip() if tts_api_url else None,
voice=str(tts_voice),
model=str(tts_model).strip() if tts_model else None,
app_id=str(tts_app_id).strip() if tts_app_id else None,
resource_id=str(tts_resource_id).strip() if tts_resource_id else None,
cluster=str(tts_cluster).strip() if tts_cluster else None,
uid=str(tts_uid).strip() if tts_uid else None,
sample_rate=settings.sample_rate,
speed=tts_speed,
mode=str(tts_mode),
@@ -965,26 +1031,48 @@ class DuplexPipeline:
asr_api_key = self._runtime_asr.get("apiKey")
asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url
asr_model = self._runtime_asr.get("model") or settings.asr_model
asr_app_id = self._runtime_asr.get("appId") or settings.asr_app_id
asr_resource_id = self._runtime_asr.get("resourceId") or settings.asr_resource_id
asr_cluster = self._runtime_asr.get("cluster") or settings.asr_cluster
asr_uid = self._runtime_asr.get("uid") or settings.asr_uid
asr_request_params = self._coerce_json_object(self._runtime_asr.get("requestParams"))
if asr_request_params is None:
asr_request_params = self._coerce_json_object(settings.asr_request_params_json)
asr_enable_interim = self._coerce_bool(self._runtime_asr.get("enableInterim"))
if asr_enable_interim is None:
asr_enable_interim = bool(settings.asr_enable_interim)
asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
asr_mode = self._resolve_asr_mode(asr_provider, self._runtime_asr.get("mode"))
self.asr_service = self._service_factory.create_asr_service(
ASRServiceSpec(
provider=asr_provider,
sample_rate=settings.sample_rate,
mode=asr_mode,
language="auto",
api_key=str(asr_api_key).strip() if asr_api_key else None,
api_url=str(asr_api_url).strip() if asr_api_url else None,
model=str(asr_model).strip() if asr_model else None,
app_id=str(asr_app_id).strip() if asr_app_id else None,
resource_id=str(asr_resource_id).strip() if asr_resource_id else None,
cluster=str(asr_cluster).strip() if asr_cluster else None,
uid=str(asr_uid).strip() if asr_uid else None,
request_params=asr_request_params,
enable_interim=asr_enable_interim,
interim_interval_ms=asr_interim_interval,
min_audio_for_interim_ms=asr_min_audio_ms,
on_transcript=self._on_transcript_callback,
)
)
self._asr_mode = self._resolve_asr_mode(
self._runtime_asr.get("provider") or settings.asr_provider,
getattr(self.asr_service, "mode", self._runtime_asr.get("mode")),
)
await self.asr_service.connect()
logger.info("DuplexPipeline services connected")
logger.info("DuplexPipeline services connected (asr_mode={})", self._asr_mode)
if not self._outbound_task or self._outbound_task.done():
self._outbound_task = asyncio.create_task(self._outbound_loop())
@@ -1449,6 +1537,9 @@ class DuplexPipeline:
text: Transcribed text
is_final: Whether this is the final transcription
"""
if not is_final and not self._asr_interim_enabled():
return
# Avoid sending duplicate transcripts
if text == self._last_sent_transcript and not is_final:
return
@@ -1457,6 +1548,7 @@ class DuplexPipeline:
self._last_sent_transcript = text
if is_final:
self._latest_asr_interim_text = ""
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = 0.0
await self._send_event(
@@ -1472,6 +1564,7 @@ class DuplexPipeline:
logger.debug(f"Sent transcript (final): {text[:50]}...")
return
self._latest_asr_interim_text = text
self._pending_transcript_delta = text
should_emit = (
self._last_transcript_delta_emit_ms <= 0.0
@@ -1495,14 +1588,16 @@ class DuplexPipeline:
await self.conversation.start_user_turn()
self._audio_buffer = b""
self._last_sent_transcript = ""
self._latest_asr_interim_text = ""
self.eou_detector.reset()
self._asr_capture_active = False
self._asr_capture_started_ms = 0.0
self._pending_speech_audio = b""
# Clear ASR buffer. Interim starts only after ASR capture is activated.
if hasattr(self.asr_service, 'clear_buffer'):
self.asr_service.clear_buffer()
if self._asr_mode == "streaming":
self._streaming_asr().clear_utterance()
else:
self._offline_asr().clear_buffer()
logger.debug("User speech started")
@@ -1511,8 +1606,11 @@ class DuplexPipeline:
if self._asr_capture_active:
return
if hasattr(self.asr_service, 'start_interim_transcription'):
await self.asr_service.start_interim_transcription()
if self._asr_mode == "streaming":
await self._streaming_asr().begin_utterance()
else:
if self._asr_interim_enabled():
await self._offline_asr().start_interim_transcription()
# Prime ASR with a short pre-speech context window so the utterance
# start isn't lost while waiting for VAD to transition to Speech.
@@ -1545,24 +1643,22 @@ class DuplexPipeline:
self._pending_speech_audio = b""
return
# Add a tiny trailing silence tail to stabilize final-token decoding.
if self._asr_final_tail_bytes > 0:
final_tail = b"\x00" * self._asr_final_tail_bytes
await self.asr_service.send_audio(final_tail)
# Stop interim transcriptions
if hasattr(self.asr_service, 'stop_interim_transcription'):
await self.asr_service.stop_interim_transcription()
# Get final transcription from ASR service
user_text = ""
if hasattr(self.asr_service, 'get_final_transcription'):
# SiliconFlow ASR - get final transcription
user_text = await self.asr_service.get_final_transcription()
elif hasattr(self.asr_service, 'get_and_clear_text'):
# Buffered ASR - get accumulated text
user_text = self.asr_service.get_and_clear_text()
if self._asr_mode == "streaming":
streaming_asr = self._streaming_asr()
await streaming_asr.end_utterance()
user_text = await streaming_asr.wait_for_final_transcription(
timeout_ms=self._ASR_STREAM_FINAL_TIMEOUT_MS
)
if not user_text.strip():
user_text = self._latest_asr_interim_text
else:
# Add a tiny trailing silence tail to stabilize final-token decoding.
if self._asr_final_tail_bytes > 0:
final_tail = b"\x00" * self._asr_final_tail_bytes
await self.asr_service.send_audio(final_tail)
await self._offline_asr().stop_interim_transcription()
user_text = await self._offline_asr().get_final_transcription()
# Skip if no meaningful text
if not user_text or not user_text.strip():
@@ -1570,6 +1666,7 @@ class DuplexPipeline:
# Reset for next utterance
self._audio_buffer = b""
self._last_sent_transcript = ""
self._latest_asr_interim_text = ""
self._asr_capture_active = False
self._asr_capture_started_ms = 0.0
self._pending_speech_audio = b""
@@ -1594,6 +1691,7 @@ class DuplexPipeline:
# Clear buffers
self._audio_buffer = b""
self._last_sent_transcript = ""
self._latest_asr_interim_text = ""
self._pending_transcript_delta = ""
self._last_transcript_delta_emit_ms = 0.0
self._asr_capture_active = False

View File

@@ -1,6 +1,12 @@
"""Port interfaces for runtime integration boundaries."""
from runtime.ports.asr import ASRBufferControl, ASRInterimControl, ASRPort, ASRServiceSpec
from runtime.ports.asr import (
ASRMode,
ASRPort,
ASRServiceSpec,
OfflineASRPort,
StreamingASRPort,
)
from runtime.ports.control_plane import (
AssistantRuntimeConfigProvider,
ControlPlaneGateway,
@@ -13,10 +19,11 @@ from runtime.ports.service_factory import RealtimeServiceFactory
from runtime.ports.tts import TTSPort, TTSServiceSpec
__all__ = [
"ASRMode",
"ASRPort",
"ASRServiceSpec",
"ASRInterimControl",
"ASRBufferControl",
"OfflineASRPort",
"StreamingASRPort",
"AssistantRuntimeConfigProvider",
"ControlPlaneGateway",
"ConversationHistoryStore",

View File

@@ -3,11 +3,12 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import AsyncIterator, Awaitable, Callable, Optional, Protocol
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Literal, Optional, Protocol
from providers.common.base import ASRResult
TranscriptCallback = Callable[[str, bool], Awaitable[None]]
ASRMode = Literal["offline", "streaming"]
@dataclass(frozen=True)
@@ -16,10 +17,17 @@ class ASRServiceSpec:
provider: str
sample_rate: int
mode: Optional[ASRMode] = None
language: str = "auto"
api_key: Optional[str] = None
api_url: Optional[str] = None
model: Optional[str] = None
app_id: Optional[str] = None
resource_id: Optional[str] = None
cluster: Optional[str] = None
uid: Optional[str] = None
request_params: Optional[Dict[str, Any]] = None
enable_interim: bool = False
interim_interval_ms: int = 500
min_audio_for_interim_ms: int = 300
on_transcript: Optional[TranscriptCallback] = None
@@ -28,6 +36,8 @@ class ASRServiceSpec:
class ASRPort(Protocol):
"""Port for speech recognition providers."""
mode: ASRMode
async def connect(self) -> None:
"""Establish connection to ASR provider."""
@@ -41,18 +51,16 @@ class ASRPort(Protocol):
"""Stream partial/final recognition results."""
class ASRInterimControl(Protocol):
"""Optional extension for explicit interim transcription control."""
class OfflineASRPort(ASRPort, Protocol):
"""Port for offline/buffered ASR providers."""
mode: Literal["offline"]
async def start_interim_transcription(self) -> None:
"""Start interim transcription loop if supported."""
"""Start interim transcription loop."""
async def stop_interim_transcription(self) -> None:
"""Stop interim transcription loop if supported."""
class ASRBufferControl(Protocol):
"""Optional extension for explicit ASR buffer lifecycle control."""
"""Stop interim transcription loop."""
def clear_buffer(self) -> None:
"""Clear provider-side ASR buffer."""
@@ -62,3 +70,21 @@ class ASRBufferControl(Protocol):
def get_and_clear_text(self) -> str:
"""Return buffered text and clear internal state."""
class StreamingASRPort(ASRPort, Protocol):
"""Port for streaming ASR providers."""
mode: Literal["streaming"]
async def begin_utterance(self) -> None:
"""Start a new utterance stream."""
async def end_utterance(self) -> None:
"""Signal end of current utterance stream."""
async def wait_for_final_transcription(self, timeout_ms: int = 800) -> str:
"""Wait for final transcript after utterance end."""
def clear_utterance(self) -> None:
"""Reset utterance-local state."""

View File

@@ -19,6 +19,10 @@ class TTSServiceSpec:
api_key: Optional[str] = None
api_url: Optional[str] = None
model: Optional[str] = None
app_id: Optional[str] = None
resource_id: Optional[str] = None
cluster: Optional[str] = None
uid: Optional[str] = None
mode: str = "commit"

View File

@@ -0,0 +1,71 @@
from providers.asr.buffered import BufferedASRService
from providers.asr.dashscope import DashScopeRealtimeASRService
from providers.asr.openai_compatible import OpenAICompatibleASRService
from providers.asr.volcengine import VolcengineRealtimeASRService
from providers.factory.default import DefaultRealtimeServiceFactory
from runtime.ports import ASRServiceSpec
def test_create_asr_service_dashscope_returns_streaming_provider():
factory = DefaultRealtimeServiceFactory()
service = factory.create_asr_service(
ASRServiceSpec(
provider="dashscope",
mode="streaming",
sample_rate=16000,
api_key="test-key",
model="qwen3-asr-flash-realtime",
)
)
assert isinstance(service, DashScopeRealtimeASRService)
assert service.mode == "streaming"
def test_create_asr_service_openai_compatible_returns_offline_provider():
factory = DefaultRealtimeServiceFactory()
service = factory.create_asr_service(
ASRServiceSpec(
provider="openai_compatible",
sample_rate=16000,
api_key="test-key",
model="FunAudioLLM/SenseVoiceSmall",
)
)
assert isinstance(service, OpenAICompatibleASRService)
assert service.mode == "offline"
assert service.enable_interim is False
def test_create_asr_service_volcengine_returns_streaming_provider():
factory = DefaultRealtimeServiceFactory()
service = factory.create_asr_service(
ASRServiceSpec(
provider="volcengine",
mode="streaming",
sample_rate=16000,
api_key="test-key",
api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
model="bigmodel",
app_id="app-1",
uid="caller-1",
request_params={"end_window_size": 800},
)
)
assert isinstance(service, VolcengineRealtimeASRService)
assert service.mode == "streaming"
assert service.protocol == "seed"
assert service.app_id == "app-1"
assert service.uid == "caller-1"
assert service.request_params["end_window_size"] == 800
def test_create_asr_service_fallback_buffered_for_unsupported_provider():
factory = DefaultRealtimeServiceFactory()
service = factory.create_asr_service(
ASRServiceSpec(
provider="unknown_provider",
sample_rate=16000,
)
)
assert isinstance(service, BufferedASRService)
assert service.mode == "offline"

View File

@@ -227,6 +227,62 @@ async def test_with_backend_url_uses_backend_for_assistant_config(monkeypatch, t
assert payload["assistant"]["systemPrompt"] == "backend prompt"
def test_translate_agent_schema_maps_volcengine_fields():
payload = {
"agent": {
"tts": {
"provider": "volcengine",
"api_key": "tts-key",
"api_url": "https://openspeech.bytedance.com/api/v3/tts/unidirectional",
"app_id": "app-123",
"resource_id": "seed-tts-2.0",
"uid": "caller-1",
"voice": "zh_female_shuangkuaisisi_moon_bigtts",
"speed": 1.1,
},
"asr": {
"provider": "volcengine",
"api_key": "asr-key",
"api_url": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
"model": "bigmodel",
"app_id": "app-123",
"resource_id": "volc.bigasr.sauc.duration",
"uid": "caller-1",
"request_params": {
"end_window_size": 800,
"force_to_speech_time": 1000,
},
},
}
}
translated = LocalYamlAssistantConfigAdapter._translate_agent_schema("assistant_demo", payload)
assert translated is not None
assert translated["services"]["tts"] == {
"provider": "volcengine",
"apiKey": "tts-key",
"baseUrl": "https://openspeech.bytedance.com/api/v3/tts/unidirectional",
"voice": "zh_female_shuangkuaisisi_moon_bigtts",
"appId": "app-123",
"resourceId": "seed-tts-2.0",
"uid": "caller-1",
"speed": 1.1,
}
assert translated["services"]["asr"] == {
"provider": "volcengine",
"model": "bigmodel",
"apiKey": "asr-key",
"baseUrl": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
"appId": "app-123",
"resourceId": "volc.bigasr.sauc.duration",
"uid": "caller-1",
"requestParams": {
"end_window_size": 800,
"force_to_speech_time": 1000,
},
}
@pytest.mark.asyncio
async def test_backend_mode_disabled_uses_local_assistant_config_even_with_url(monkeypatch, tmp_path):
class _FailIfCalledClientSession:
@@ -282,7 +338,7 @@ async def test_local_yaml_adapter_rejects_path_traversal_like_assistant_id(tmp_p
@pytest.mark.asyncio
async def test_local_yaml_translates_agent_schema_to_runtime_services(tmp_path):
async def test_local_yaml_translates_agent_schema_with_asr_interim_flag(tmp_path):
config_dir = tmp_path / "assistants"
config_dir.mkdir(parents=True, exist_ok=True)
(config_dir / "default.yaml").write_text(
@@ -305,6 +361,7 @@ async def test_local_yaml_translates_agent_schema_to_runtime_services(tmp_path):
" model: asr-model",
" api_key: sk-asr",
" api_url: https://asr.example.com/v1/audio/transcriptions",
" enable_interim: false",
" duplex:",
" system_prompt: You are test assistant",
]
@@ -321,4 +378,5 @@ async def test_local_yaml_translates_agent_schema_to_runtime_services(tmp_path):
assert services.get("llm", {}).get("apiKey") == "sk-llm"
assert services.get("tts", {}).get("apiKey") == "sk-tts"
assert services.get("asr", {}).get("apiKey") == "sk-asr"
assert services.get("asr", {}).get("enableInterim") is False
assert assistant.get("systemPrompt") == "You are test assistant"

View File

@@ -0,0 +1,67 @@
import asyncio
import pytest
from providers.asr.dashscope import DashScopeRealtimeASRService
@pytest.mark.asyncio
async def test_dashscope_asr_interim_event_emits_interim_transcript():
received = []
async def _on_transcript(text: str, is_final: bool) -> None:
received.append((text, is_final))
service = DashScopeRealtimeASRService(api_key="test-key", on_transcript=_on_transcript)
service._loop = asyncio.get_running_loop()
service._running = True
service._on_ws_event(
{
"type": "conversation.item.input_audio_transcription.text",
"stash": "你好世界",
}
)
await asyncio.sleep(0.05)
result = service._transcript_queue.get_nowait()
assert result.text == "你好世界"
assert result.is_final is False
assert received == [("你好世界", False)]
@pytest.mark.asyncio
async def test_dashscope_asr_final_event_emits_final_transcript_and_final_queue():
received = []
async def _on_transcript(text: str, is_final: bool) -> None:
received.append((text, is_final))
service = DashScopeRealtimeASRService(api_key="test-key", on_transcript=_on_transcript)
service._loop = asyncio.get_running_loop()
service._running = True
service._audio_sent_in_utterance = True
service._on_ws_event(
{
"type": "conversation.item.input_audio_transcription.completed",
"transcript": "最终识别结果",
}
)
await asyncio.sleep(0.05)
result = service._transcript_queue.get_nowait()
assert result.text == "最终识别结果"
assert result.is_final is True
assert service._final_queue.get_nowait() == "最终识别结果"
assert received == [("最终识别结果", True)]
@pytest.mark.asyncio
async def test_dashscope_wait_for_final_falls_back_to_latest_interim_on_timeout():
service = DashScopeRealtimeASRService(api_key="test-key")
service._audio_sent_in_utterance = True
service._last_interim_text = "部分结果"
text = await service.wait_for_final_transcription(timeout_ms=10)
assert text == "部分结果"

View File

@@ -0,0 +1,260 @@
import asyncio
from typing import Any, Dict, List
import pytest
from runtime.pipeline.duplex import DuplexPipeline
class _DummySileroVAD:
def __init__(self, *args, **kwargs):
pass
def process_audio(self, _pcm: bytes) -> float:
return 0.0
class _DummyVADProcessor:
def __init__(self, *args, **kwargs):
pass
def process(self, _speech_prob: float):
return "Silence", 0.0
class _DummyEouDetector:
def __init__(self, *args, **kwargs):
self.is_speaking = True
def process(self, _vad_status: str, force_eligible: bool = False) -> bool:
_ = force_eligible
return False
def reset(self) -> None:
self.is_speaking = False
class _FakeTransport:
async def send_event(self, _event: Dict[str, Any]) -> None:
return None
async def send_audio(self, _audio: bytes) -> None:
return None
class _FakeStreamingASR:
mode = "streaming"
def __init__(self):
self.begin_calls = 0
self.end_calls = 0
self.wait_calls = 0
self.sent_audio: List[bytes] = []
self.wait_text = ""
async def connect(self) -> None:
return None
async def disconnect(self) -> None:
return None
async def send_audio(self, audio: bytes) -> None:
self.sent_audio.append(audio)
async def receive_transcripts(self):
if False:
yield None
async def begin_utterance(self) -> None:
self.begin_calls += 1
async def end_utterance(self) -> None:
self.end_calls += 1
async def wait_for_final_transcription(self, timeout_ms: int = 800) -> str:
_ = timeout_ms
self.wait_calls += 1
return self.wait_text
def clear_utterance(self) -> None:
return None
class _FakeOfflineASR:
mode = "offline"
def __init__(self):
self.start_interim_calls = 0
self.stop_interim_calls = 0
self.sent_audio: List[bytes] = []
self.final_text = "offline final"
async def connect(self) -> None:
return None
async def disconnect(self) -> None:
return None
async def send_audio(self, audio: bytes) -> None:
self.sent_audio.append(audio)
async def receive_transcripts(self):
if False:
yield None
async def start_interim_transcription(self) -> None:
self.start_interim_calls += 1
async def stop_interim_transcription(self) -> None:
self.stop_interim_calls += 1
async def get_final_transcription(self) -> str:
return self.final_text
def clear_buffer(self) -> None:
return None
def get_and_clear_text(self) -> str:
return self.final_text
def _build_pipeline(monkeypatch, asr_service):
monkeypatch.setattr("runtime.pipeline.duplex.SileroVAD", _DummySileroVAD)
monkeypatch.setattr("runtime.pipeline.duplex.VADProcessor", _DummyVADProcessor)
monkeypatch.setattr("runtime.pipeline.duplex.EouDetector", _DummyEouDetector)
return DuplexPipeline(
transport=_FakeTransport(),
session_id="asr_mode_test",
asr_service=asr_service,
)
@pytest.mark.asyncio
async def test_start_asr_capture_uses_streaming_begin(monkeypatch):
asr = _FakeStreamingASR()
pipeline = _build_pipeline(monkeypatch, asr)
pipeline._asr_mode = "streaming"
pipeline._pending_speech_audio = b"\x00" * 320
pipeline._pre_speech_buffer = b"\x00" * 640
await pipeline._start_asr_capture()
assert asr.begin_calls == 1
assert asr.sent_audio
assert pipeline._asr_capture_active is True
@pytest.mark.asyncio
async def test_start_asr_capture_uses_offline_interim_control_when_enabled(monkeypatch):
asr = _FakeOfflineASR()
pipeline = _build_pipeline(monkeypatch, asr)
pipeline._asr_mode = "offline"
pipeline._runtime_asr["enableInterim"] = True
pipeline._pending_speech_audio = b"\x00" * 320
pipeline._pre_speech_buffer = b"\x00" * 640
await pipeline._start_asr_capture()
assert asr.start_interim_calls == 1
assert asr.sent_audio
assert pipeline._asr_capture_active is True
@pytest.mark.asyncio
async def test_start_asr_capture_skips_offline_interim_control_when_disabled(monkeypatch):
asr = _FakeOfflineASR()
pipeline = _build_pipeline(monkeypatch, asr)
pipeline._asr_mode = "offline"
pipeline._runtime_asr["enableInterim"] = False
pipeline._pending_speech_audio = b"\x00" * 320
pipeline._pre_speech_buffer = b"\x00" * 640
await pipeline._start_asr_capture()
assert asr.start_interim_calls == 0
assert asr.sent_audio
assert pipeline._asr_capture_active is True
@pytest.mark.asyncio
async def test_offline_interim_callback_ignored_when_disabled(monkeypatch):
asr = _FakeOfflineASR()
pipeline = _build_pipeline(monkeypatch, asr)
pipeline._asr_mode = "offline"
pipeline._runtime_asr["enableInterim"] = False
captured_events = []
captured_deltas = []
async def _capture_event(event: Dict[str, Any], priority: int = 20):
_ = priority
captured_events.append(event)
async def _capture_delta(text: str):
captured_deltas.append(text)
monkeypatch.setattr(pipeline, "_send_event", _capture_event)
monkeypatch.setattr(pipeline, "_emit_transcript_delta", _capture_delta)
await pipeline._on_transcript_callback("ignored interim", is_final=False)
assert captured_events == []
assert captured_deltas == []
assert pipeline._latest_asr_interim_text == ""
@pytest.mark.asyncio
async def test_offline_final_callback_emits_when_interim_disabled(monkeypatch):
asr = _FakeOfflineASR()
pipeline = _build_pipeline(monkeypatch, asr)
pipeline._asr_mode = "offline"
pipeline._runtime_asr["enableInterim"] = False
captured_events = []
async def _capture_event(event: Dict[str, Any], priority: int = 20):
_ = priority
captured_events.append(event)
monkeypatch.setattr(pipeline, "_send_event", _capture_event)
await pipeline._on_transcript_callback("final only", is_final=True)
assert any(event.get("type") == "transcript.final" for event in captured_events)
@pytest.mark.asyncio
async def test_streaming_eou_falls_back_to_latest_interim(monkeypatch):
asr = _FakeStreamingASR()
asr.wait_text = ""
pipeline = _build_pipeline(monkeypatch, asr)
pipeline._asr_mode = "streaming"
pipeline._asr_capture_active = True
pipeline._latest_asr_interim_text = "fallback interim text"
await pipeline.conversation.start_user_turn()
captured_events = []
captured_turns = []
async def _capture_event(event: Dict[str, Any], priority: int = 20):
_ = priority
captured_events.append(event)
async def _noop_stop_current_speech() -> None:
return None
async def _capture_turn(user_text: str, *args, **kwargs) -> None:
_ = (args, kwargs)
captured_turns.append(user_text)
monkeypatch.setattr(pipeline, "_send_event", _capture_event)
monkeypatch.setattr(pipeline, "_stop_current_speech", _noop_stop_current_speech)
monkeypatch.setattr(pipeline, "_handle_turn", _capture_turn)
await pipeline._on_end_of_utterance()
await asyncio.sleep(0.05)
assert asr.end_calls == 1
assert asr.wait_calls == 1
assert captured_turns == ["fallback interim text"]
assert any(event.get("type") == "transcript.final" for event in captured_events)

View File

@@ -52,9 +52,33 @@ class _FakeTTS:
class _FakeASR:
mode = "offline"
async def connect(self) -> None:
return None
async def disconnect(self) -> None:
return None
async def send_audio(self, _audio: bytes) -> None:
return None
async def receive_transcripts(self):
if False:
yield None
def clear_buffer(self) -> None:
return None
async def start_interim_transcription(self) -> None:
return None
async def stop_interim_transcription(self) -> None:
return None
async def get_final_transcription(self) -> str:
return ""
class _FakeLLM:
def __init__(self, rounds: List[List[LLMStreamEvent]]):

View File

@@ -0,0 +1,45 @@
from providers.factory.default import DefaultRealtimeServiceFactory
from providers.tts.mock import MockTTSService
from providers.tts.openai_compatible import OpenAICompatibleTTSService
from providers.tts.volcengine import VolcengineTTSService
from runtime.ports import TTSServiceSpec
def test_create_tts_service_volcengine_returns_native_provider():
    """A fully-credentialed volcengine spec must build the native Volcengine TTS service."""
    spec = TTSServiceSpec(
        provider="volcengine",
        api_key="test-key",
        app_id="app-1",
        resource_id="seed-tts-2.0",
        voice="zh_female_shuangkuaisisi_moon_bigtts",
        sample_rate=16000,
    )

    created = DefaultRealtimeServiceFactory().create_tts_service(spec)

    assert isinstance(created, VolcengineTTSService)
def test_create_tts_service_openai_compatible_returns_provider():
    """An openai_compatible spec with an API key must build the OpenAI-compatible TTS service."""
    spec = TTSServiceSpec(
        provider="openai_compatible",
        api_key="test-key",
        voice="anna",
        sample_rate=16000,
    )

    created = DefaultRealtimeServiceFactory().create_tts_service(spec)

    assert isinstance(created, OpenAICompatibleTTSService)
def test_create_tts_service_fallbacks_to_mock_without_key():
    """Without an API key the factory must degrade to the mock TTS service."""
    spec = TTSServiceSpec(
        provider="volcengine",  # real provider requested, but no credentials supplied
        voice="anna",
        sample_rate=16000,
    )

    created = DefaultRealtimeServiceFactory().create_tts_service(spec)

    assert isinstance(created, MockTTSService)

View File

@@ -0,0 +1,86 @@
import gzip
import json
from providers.asr.volcengine import VolcengineRealtimeASRService
def test_volcengine_seed_protocol_defaults_and_headers():
    """A bigmodel sauc URL selects the seed protocol, default resource id, and X-Api-* headers."""
    service = VolcengineRealtimeASRService(
        api_key="access-token",
        api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
        app_id="app-1",
        uid="caller-1",
    )

    assert service.protocol == "seed"
    assert service.resource_id == "volc.bigasr.sauc.duration"

    expected = {
        "X-Api-App-Key": "app-1",
        "X-Api-Access-Key": "access-token",
        "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
        "X-Api-Request-Id": "req-1",
    }
    assert service._build_seed_headers("req-1") == expected
def test_volcengine_seed_start_payload_merges_request_params():
    """Caller-supplied request_params must merge into the seed start payload without losing defaults."""
    overrides = {
        "request": {
            "end_window_size": 800,
            "force_to_speech_time": 1000,
            "context": "{\"hotwords\":[{\"word\":\"doubao\"}]}",
        },
        "audio": {"codec": "raw"},
    }
    service = VolcengineRealtimeASRService(
        api_key="access-token",
        api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
        app_id="app-1",
        uid="caller-1",
        language="zh-CN",
        request_params=overrides,
    )

    payload = service._build_seed_start_payload()

    assert payload["user"] == {"uid": "caller-1"}
    # Audio defaults survive, with the caller's codec override merged in.
    assert payload["audio"] == {
        "format": "pcm",
        "rate": 16000,
        "bits": 16,
        "channels": 1,
        "codec": "raw",
        "language": "zh-CN",
    }
    request_section = payload["request"]
    assert request_section["model_name"] == "bigmodel"
    assert request_section["end_window_size"] == 800
    assert request_section["force_to_speech_time"] == 1000
    assert request_section["context"] == "{\"hotwords\":[{\"word\":\"doubao\"}]}"
def test_volcengine_seed_start_request_encodes_gzip_json_payload():
    """The seed start frame carries 0x11 0x11 header bytes, a big-endian size, then gzipped JSON."""
    service = VolcengineRealtimeASRService(
        api_key="access-token",
        api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
        app_id="app-1",
        uid="caller-1",
    )

    frame = service._build_seed_start_request()

    assert frame[0] == 0x11
    assert frame[1] == 0x11
    # Bytes 8..12 hold the payload length; the payload itself is gzip(JSON).
    size = int.from_bytes(frame[8:12], "big")
    decoded = json.loads(gzip.decompress(frame[12 : 12 + size]).decode("utf-8"))
    assert decoded["user"]["uid"] == "caller-1"
    assert decoded["request"]["model_name"] == "bigmodel"
def test_volcengine_gateway_protocol_keeps_model_query():
    """A VEI gateway URL selects the gateway protocol and appends the model as a query parameter."""
    service = VolcengineRealtimeASRService(
        api_key="access-token",
        api_url="wss://ai-gateway.vei.volces.com/v1/realtime",
        model="bigmodel",
    )

    assert service.protocol == "gateway"
    assert service.api_url == "wss://ai-gateway.vei.volces.com/v1/realtime?model=bigmodel"

View File

@@ -259,6 +259,7 @@ export const AssistantsPage: React.FC = () => {
speed: 1,
hotwords: [],
tools: [],
asrInterimEnabled: false,
botCannotBeInterrupted: false,
interruptionSensitivity: 180,
configMode: 'platform',
@@ -1358,6 +1359,41 @@ export const AssistantsPage: React.FC = () => {
</p>
</div>
<div className="space-y-3">
<div className="flex items-center justify-between gap-3">
<label className="text-sm font-medium text-white flex items-center">
<Mic className="w-4 h-4 mr-2 text-primary"/> 线 ASR
</label>
<div className="inline-flex rounded-lg border border-white/10 bg-white/5 p-1">
<button
type="button"
onClick={() => updateAssistant('asrInterimEnabled', false)}
className={`px-3 py-1 text-xs rounded-md transition-colors ${
selectedAssistant.asrInterimEnabled === true
? 'text-muted-foreground hover:text-foreground'
: 'bg-primary text-primary-foreground shadow-sm'
}`}
>
</button>
<button
type="button"
onClick={() => updateAssistant('asrInterimEnabled', true)}
className={`px-3 py-1 text-xs rounded-md transition-colors ${
selectedAssistant.asrInterimEnabled === true
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground'
}`}
>
</button>
</div>
</div>
<p className="text-xs text-muted-foreground">
线 ASR OpenAI Compatible / buffered
</p>
</div>
<div className="space-y-3">
<div className="flex items-center justify-between gap-3">
<label className="text-sm font-medium text-white flex items-center">

View File

@@ -87,6 +87,7 @@ const mapAssistant = (raw: AnyRecord): Assistant => ({
speed: Number(readField(raw, ['speed'], 1)),
hotwords: readField(raw, ['hotwords'], []),
tools: normalizeToolIdList(readField(raw, ['tools'], [])),
asrInterimEnabled: Boolean(readField(raw, ['asrInterimEnabled', 'asr_interim_enabled'], false)),
botCannotBeInterrupted: Boolean(readField(raw, ['botCannotBeInterrupted', 'bot_cannot_be_interrupted'], false)),
interruptionSensitivity: Number(readField(raw, ['interruptionSensitivity', 'interruption_sensitivity'], 500)),
configMode: readField(raw, ['configMode', 'config_mode'], 'platform') as 'platform' | 'dify' | 'fastgpt' | 'none',
@@ -284,6 +285,7 @@ export const createAssistant = async (data: Partial<Assistant>): Promise<Assista
speed: data.speed ?? 1,
hotwords: data.hotwords || [],
tools: normalizeToolIdList(data.tools || []),
asrInterimEnabled: data.asrInterimEnabled ?? false,
botCannotBeInterrupted: data.botCannotBeInterrupted ?? false,
interruptionSensitivity: data.interruptionSensitivity ?? 500,
configMode: data.configMode || 'platform',
@@ -316,6 +318,7 @@ export const updateAssistant = async (id: string, data: Partial<Assistant>): Pro
speed: data.speed,
hotwords: data.hotwords,
tools: data.tools === undefined ? undefined : normalizeToolIdList(data.tools),
asrInterimEnabled: data.asrInterimEnabled,
botCannotBeInterrupted: data.botCannotBeInterrupted,
interruptionSensitivity: data.interruptionSensitivity,
configMode: data.configMode,

View File

@@ -19,6 +19,7 @@ export interface Assistant {
speed: number;
hotwords: string[];
tools?: string[]; // IDs of enabled tools
asrInterimEnabled?: boolean;
botCannotBeInterrupted?: boolean;
interruptionSensitivity?: number; // In ms
configMode?: 'platform' | 'dify' | 'fastgpt' | 'none';