Merge branch 'engine-v3' of https://gitea.xiaowang.eu.org/wx44wx/AI-VideoAssistant into engine-v3
This commit is contained in:
@@ -127,6 +127,7 @@ class Assistant(Base):
|
||||
speed: Mapped[float] = mapped_column(Float, default=1.0)
|
||||
hotwords: Mapped[dict] = mapped_column(JSON, default=list)
|
||||
tools: Mapped[dict] = mapped_column(JSON, default=list)
|
||||
asr_interim_enabled: Mapped[bool] = mapped_column(default=False)
|
||||
bot_cannot_be_interrupted: Mapped[bool] = mapped_column(default=False)
|
||||
interruption_sensitivity: Mapped[int] = mapped_column(Integer, default=500)
|
||||
config_mode: Mapped[str] = mapped_column(String(32), default="platform")
|
||||
|
||||
@@ -126,6 +126,9 @@ def _ensure_assistant_schema(db: Session) -> None:
|
||||
if "manual_opener_tool_calls" not in columns:
|
||||
db.execute(text("ALTER TABLE assistants ADD COLUMN manual_opener_tool_calls JSON"))
|
||||
altered = True
|
||||
if "asr_interim_enabled" not in columns:
|
||||
db.execute(text("ALTER TABLE assistants ADD COLUMN asr_interim_enabled BOOLEAN DEFAULT 0"))
|
||||
altered = True
|
||||
|
||||
if altered:
|
||||
db.commit()
|
||||
@@ -317,18 +320,27 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[s
|
||||
else:
|
||||
warnings.append(f"LLM model not found: {assistant.llm_model_id}")
|
||||
|
||||
asr_runtime: Dict[str, Any] = {
|
||||
"enableInterim": bool(assistant.asr_interim_enabled),
|
||||
}
|
||||
if assistant.asr_model_id:
|
||||
asr = db.query(ASRModel).filter(ASRModel.id == assistant.asr_model_id).first()
|
||||
if asr:
|
||||
asr_provider = "openai_compatible" if _is_openai_compatible_vendor(asr.vendor) else "buffered"
|
||||
metadata["services"]["asr"] = {
|
||||
if _is_dashscope_vendor(asr.vendor):
|
||||
asr_provider = "dashscope"
|
||||
elif _is_openai_compatible_vendor(asr.vendor):
|
||||
asr_provider = "openai_compatible"
|
||||
else:
|
||||
asr_provider = "buffered"
|
||||
asr_runtime.update({
|
||||
"provider": asr_provider,
|
||||
"model": asr.model_name or asr.name,
|
||||
"apiKey": asr.api_key if asr_provider == "openai_compatible" else None,
|
||||
"baseUrl": asr.base_url if asr_provider == "openai_compatible" else None,
|
||||
}
|
||||
"apiKey": asr.api_key if asr_provider in {"openai_compatible", "dashscope"} else None,
|
||||
"baseUrl": asr.base_url if asr_provider in {"openai_compatible", "dashscope"} else None,
|
||||
})
|
||||
else:
|
||||
warnings.append(f"ASR model not found: {assistant.asr_model_id}")
|
||||
metadata["services"]["asr"] = asr_runtime
|
||||
|
||||
if not assistant.voice_output_enabled:
|
||||
metadata["services"]["tts"] = {"enabled": False}
|
||||
@@ -432,6 +444,7 @@ def assistant_to_dict(assistant: Assistant) -> dict:
|
||||
"speed": assistant.speed,
|
||||
"hotwords": assistant.hotwords or [],
|
||||
"tools": _normalize_assistant_tool_ids(assistant.tools),
|
||||
"asrInterimEnabled": bool(assistant.asr_interim_enabled),
|
||||
"botCannotBeInterrupted": bool(assistant.bot_cannot_be_interrupted),
|
||||
"interruptionSensitivity": assistant.interruption_sensitivity,
|
||||
"configMode": assistant.config_mode,
|
||||
@@ -452,6 +465,7 @@ def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None:
|
||||
"firstTurnMode": "first_turn_mode",
|
||||
"manualOpenerToolCalls": "manual_opener_tool_calls",
|
||||
"interruptionSensitivity": "interruption_sensitivity",
|
||||
"asrInterimEnabled": "asr_interim_enabled",
|
||||
"botCannotBeInterrupted": "bot_cannot_be_interrupted",
|
||||
"configMode": "config_mode",
|
||||
"voiceOutputEnabled": "voice_output_enabled",
|
||||
@@ -646,6 +660,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
|
||||
speed=data.speed,
|
||||
hotwords=data.hotwords,
|
||||
tools=_normalize_assistant_tool_ids(data.tools),
|
||||
asr_interim_enabled=data.asrInterimEnabled,
|
||||
bot_cannot_be_interrupted=data.botCannotBeInterrupted,
|
||||
interruption_sensitivity=data.interruptionSensitivity,
|
||||
config_mode=data.configMode,
|
||||
|
||||
@@ -291,6 +291,7 @@ class AssistantBase(BaseModel):
|
||||
speed: float = 1.0
|
||||
hotwords: List[str] = []
|
||||
tools: List[str] = []
|
||||
asrInterimEnabled: bool = False
|
||||
botCannotBeInterrupted: bool = False
|
||||
interruptionSensitivity: int = 500
|
||||
configMode: str = "platform"
|
||||
@@ -322,6 +323,7 @@ class AssistantUpdate(BaseModel):
|
||||
speed: Optional[float] = None
|
||||
hotwords: Optional[List[str]] = None
|
||||
tools: Optional[List[str]] = None
|
||||
asrInterimEnabled: Optional[bool] = None
|
||||
botCannotBeInterrupted: Optional[bool] = None
|
||||
interruptionSensitivity: Optional[int] = None
|
||||
configMode: Optional[str] = None
|
||||
|
||||
@@ -27,6 +27,7 @@ class TestAssistantAPI:
|
||||
assert data["voiceOutputEnabled"] is True
|
||||
assert data["firstTurnMode"] == "bot_first"
|
||||
assert data["generatedOpenerEnabled"] is False
|
||||
assert data["asrInterimEnabled"] is False
|
||||
assert data["botCannotBeInterrupted"] is False
|
||||
assert "id" in data
|
||||
assert data["callCount"] == 0
|
||||
@@ -37,6 +38,7 @@ class TestAssistantAPI:
|
||||
response = client.post("/api/assistants", json=data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Minimal Assistant"
|
||||
assert response.json()["asrInterimEnabled"] is False
|
||||
|
||||
def test_get_assistant_by_id(self, client, sample_assistant_data):
|
||||
"""Test getting a specific assistant by ID"""
|
||||
@@ -68,6 +70,7 @@ class TestAssistantAPI:
|
||||
"prompt": "You are an updated assistant.",
|
||||
"speed": 1.5,
|
||||
"voiceOutputEnabled": False,
|
||||
"asrInterimEnabled": True,
|
||||
"manualOpenerToolCalls": [
|
||||
{"toolName": "text_msg_prompt", "arguments": {"msg": "请选择服务类型"}}
|
||||
],
|
||||
@@ -79,6 +82,7 @@ class TestAssistantAPI:
|
||||
assert data["prompt"] == "You are an updated assistant."
|
||||
assert data["speed"] == 1.5
|
||||
assert data["voiceOutputEnabled"] is False
|
||||
assert data["asrInterimEnabled"] is True
|
||||
assert data["manualOpenerToolCalls"] == [
|
||||
{"toolName": "text_msg_prompt", "arguments": {"msg": "请选择服务类型"}}
|
||||
]
|
||||
@@ -213,6 +217,7 @@ class TestAssistantAPI:
|
||||
"prompt": "runtime prompt",
|
||||
"opener": "runtime opener",
|
||||
"manualOpenerToolCalls": [{"toolName": "text_msg_prompt", "arguments": {"msg": "欢迎"}}],
|
||||
"asrInterimEnabled": True,
|
||||
"speed": 1.1,
|
||||
})
|
||||
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
|
||||
@@ -232,6 +237,7 @@ class TestAssistantAPI:
|
||||
assert metadata["services"]["llm"]["model"] == sample_llm_model_data["model_name"]
|
||||
assert metadata["services"]["asr"]["model"] == sample_asr_model_data["model_name"]
|
||||
assert metadata["services"]["asr"]["baseUrl"] == sample_asr_model_data["base_url"]
|
||||
assert metadata["services"]["asr"]["enableInterim"] is True
|
||||
expected_tts_voice = f"{sample_voice_data['model']}:{sample_voice_data['voice_key']}"
|
||||
assert metadata["services"]["tts"]["voice"] == expected_tts_voice
|
||||
assert metadata["services"]["tts"]["baseUrl"] == sample_voice_data["base_url"]
|
||||
@@ -309,6 +315,7 @@ class TestAssistantAPI:
|
||||
assert runtime_resp.status_code == 200
|
||||
metadata = runtime_resp.json()["sessionStartMetadata"]
|
||||
assert metadata["output"]["mode"] == "text"
|
||||
assert metadata["services"]["asr"]["enableInterim"] is False
|
||||
assert metadata["services"]["tts"]["enabled"] is False
|
||||
|
||||
def test_runtime_config_dashscope_voice_provider(self, client, sample_assistant_data):
|
||||
@@ -343,6 +350,48 @@ class TestAssistantAPI:
|
||||
assert tts["apiKey"] == "dashscope-key"
|
||||
assert tts["baseUrl"] == "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
|
||||
|
||||
def test_runtime_config_dashscope_asr_provider(self, client, sample_assistant_data):
|
||||
"""DashScope ASR models should map to dashscope asr provider in runtime metadata."""
|
||||
asr_resp = client.post("/api/asr", json={
|
||||
"name": "DashScope Realtime ASR",
|
||||
"vendor": "DashScope",
|
||||
"language": "zh",
|
||||
"base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
|
||||
"api_key": "dashscope-asr-key",
|
||||
"model_name": "qwen3-asr-flash-realtime",
|
||||
"hotwords": [],
|
||||
"enable_punctuation": True,
|
||||
"enable_normalization": True,
|
||||
"enabled": True,
|
||||
})
|
||||
assert asr_resp.status_code == 200
|
||||
asr_payload = asr_resp.json()
|
||||
|
||||
sample_assistant_data.update({
|
||||
"asrModelId": asr_payload["id"],
|
||||
})
|
||||
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assert assistant_resp.status_code == 200
|
||||
assistant_id = assistant_resp.json()["id"]
|
||||
|
||||
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
|
||||
assert runtime_resp.status_code == 200
|
||||
metadata = runtime_resp.json()["sessionStartMetadata"]
|
||||
asr = metadata["services"]["asr"]
|
||||
assert asr["provider"] == "dashscope"
|
||||
assert asr["baseUrl"] == "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
|
||||
assert asr["enableInterim"] is False
|
||||
|
||||
def test_runtime_config_defaults_asr_interim_disabled_without_asr_model(self, client, sample_assistant_data):
|
||||
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assert assistant_resp.status_code == 200
|
||||
assistant_id = assistant_resp.json()["id"]
|
||||
|
||||
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
|
||||
assert runtime_resp.status_code == 200
|
||||
metadata = runtime_resp.json()["sessionStartMetadata"]
|
||||
assert metadata["services"]["asr"]["enableInterim"] is False
|
||||
|
||||
def test_assistant_interrupt_and_generated_opener_flags(self, client, sample_assistant_data):
|
||||
sample_assistant_data.update({
|
||||
"firstTurnMode": "user_first",
|
||||
|
||||
@@ -12,6 +12,24 @@
|
||||
| **热词** | 提高业务词汇、品牌词、专有名词识别率 |
|
||||
| **标点与规范化** | 自动补全标点、规范数字和日期等 |
|
||||
|
||||
## 模式
|
||||
|
||||
- `offline`:引擎本地缓冲音频后触发识别(适用于 OpenAI-compatible / SiliconFlow)。
|
||||
- `streaming`:音频分片实时发送到服务端,服务端持续返回转写事件(适用于 DashScope Realtime ASR、Volcengine BigASR)。
|
||||
|
||||
## 配置项
|
||||
|
||||
| 配置项 | 说明 |
|
||||
|---|---|
|
||||
| ASR 引擎 | 选择语音识别服务提供商 |
|
||||
| 模型 | 识别模型名称 |
|
||||
| `enable_interim` | 是否开启离线 ASR 中间结果(默认 `false`,仅离线模式生效) |
|
||||
| `app_id` / `resource_id` | Volcengine 等厂商的应用标识与资源标识 |
|
||||
| `request_params` | 厂商原生请求参数透传,例如 `end_window_size`、`force_to_speech_time`、`context` |
|
||||
| 语言 | 中文/英文/多语言 |
|
||||
| 热词 | 提升特定词汇识别准确率 |
|
||||
| 标点与规范化 | 是否自动补全标点、文本规范化 |
|
||||
|
||||
## 选择建议
|
||||
|
||||
- 客服、外呼等业务场景建议维护热词表,并按业务线持续更新
|
||||
@@ -29,3 +47,7 @@
|
||||
|
||||
- [声音资源](voices.md) - 完整语音输入输出链路中的 TTS 侧配置
|
||||
- [快速开始](../quickstart/index.md) - 以任务路径接入第一个 ASR 资源
|
||||
- 客服场景建议开启热词并维护业务词表
|
||||
- 多语言场景建议按会话入口显式指定语言
|
||||
- 对延迟敏感场景优先选择流式识别模型
|
||||
- 当前支持提供商:`openai_compatible`、`siliconflow`、`dashscope`、`volcengine`、`buffered`(回退)
|
||||
|
||||
@@ -230,6 +230,14 @@ class LocalYamlAssistantConfigAdapter(NullBackendAdapter):
|
||||
tts_runtime["baseUrl"] = cls._as_str(tts.get("api_url"))
|
||||
if cls._as_str(tts.get("voice")):
|
||||
tts_runtime["voice"] = cls._as_str(tts.get("voice"))
|
||||
if cls._as_str(tts.get("app_id")):
|
||||
tts_runtime["appId"] = cls._as_str(tts.get("app_id"))
|
||||
if cls._as_str(tts.get("resource_id")):
|
||||
tts_runtime["resourceId"] = cls._as_str(tts.get("resource_id"))
|
||||
if cls._as_str(tts.get("cluster")):
|
||||
tts_runtime["cluster"] = cls._as_str(tts.get("cluster"))
|
||||
if cls._as_str(tts.get("uid")):
|
||||
tts_runtime["uid"] = cls._as_str(tts.get("uid"))
|
||||
if tts.get("speed") is not None:
|
||||
tts_runtime["speed"] = tts.get("speed")
|
||||
dashscope_mode = cls._as_str(tts.get("dashscope_mode")) or cls._as_str(tts.get("mode"))
|
||||
@@ -249,6 +257,18 @@ class LocalYamlAssistantConfigAdapter(NullBackendAdapter):
|
||||
asr_runtime["apiKey"] = cls._as_str(asr.get("api_key"))
|
||||
if cls._as_str(asr.get("api_url")):
|
||||
asr_runtime["baseUrl"] = cls._as_str(asr.get("api_url"))
|
||||
if cls._as_str(asr.get("app_id")):
|
||||
asr_runtime["appId"] = cls._as_str(asr.get("app_id"))
|
||||
if cls._as_str(asr.get("resource_id")):
|
||||
asr_runtime["resourceId"] = cls._as_str(asr.get("resource_id"))
|
||||
if cls._as_str(asr.get("cluster")):
|
||||
asr_runtime["cluster"] = cls._as_str(asr.get("cluster"))
|
||||
if cls._as_str(asr.get("uid")):
|
||||
asr_runtime["uid"] = cls._as_str(asr.get("uid"))
|
||||
if isinstance(asr.get("request_params"), dict):
|
||||
asr_runtime["requestParams"] = dict(asr.get("request_params") or {})
|
||||
if asr.get("enable_interim") is not None:
|
||||
asr_runtime["enableInterim"] = asr.get("enable_interim")
|
||||
if asr.get("interim_interval_ms") is not None:
|
||||
asr_runtime["interimIntervalMs"] = asr.get("interim_interval_ms")
|
||||
if asr.get("min_audio_ms") is not None:
|
||||
|
||||
@@ -71,11 +71,15 @@ class Settings(BaseSettings):
|
||||
# TTS Configuration
|
||||
tts_provider: str = Field(
|
||||
default="openai_compatible",
|
||||
description="TTS provider (openai_compatible, siliconflow, dashscope)"
|
||||
description="TTS provider (openai_compatible, siliconflow, dashscope, volcengine)"
|
||||
)
|
||||
tts_api_url: Optional[str] = Field(default=None, description="TTS provider API URL")
|
||||
tts_model: Optional[str] = Field(default=None, description="TTS model name")
|
||||
tts_voice: str = Field(default="anna", description="TTS voice name")
|
||||
tts_app_id: Optional[str] = Field(default=None, description="Provider-specific TTS app ID")
|
||||
tts_resource_id: Optional[str] = Field(default=None, description="Provider-specific TTS resource ID")
|
||||
tts_cluster: Optional[str] = Field(default=None, description="Provider-specific TTS cluster")
|
||||
tts_uid: Optional[str] = Field(default=None, description="Provider-specific TTS user ID")
|
||||
tts_mode: str = Field(
|
||||
default="commit",
|
||||
description="DashScope-only TTS mode (commit, server_commit). Ignored for non-dashscope providers."
|
||||
@@ -85,10 +89,19 @@ class Settings(BaseSettings):
|
||||
# ASR Configuration
|
||||
asr_provider: str = Field(
|
||||
default="openai_compatible",
|
||||
description="ASR provider (openai_compatible, buffered, siliconflow)"
|
||||
description="ASR provider (openai_compatible, buffered, siliconflow, dashscope, volcengine)"
|
||||
)
|
||||
asr_api_url: Optional[str] = Field(default=None, description="ASR provider API URL")
|
||||
asr_model: Optional[str] = Field(default=None, description="ASR model name")
|
||||
asr_app_id: Optional[str] = Field(default=None, description="Provider-specific ASR app ID")
|
||||
asr_resource_id: Optional[str] = Field(default=None, description="Provider-specific ASR resource ID")
|
||||
asr_cluster: Optional[str] = Field(default=None, description="Provider-specific ASR cluster")
|
||||
asr_uid: Optional[str] = Field(default=None, description="Provider-specific ASR user ID")
|
||||
asr_request_params_json: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Provider-specific ASR request params as JSON string"
|
||||
)
|
||||
asr_enable_interim: bool = Field(default=False, description="Enable interim transcripts for offline ASR")
|
||||
asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms")
|
||||
asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result")
|
||||
asr_start_min_speech_ms: int = Field(
|
||||
|
||||
47
engine/config/agents/dashscope.yaml
Normal file
47
engine/config/agents/dashscope.yaml
Normal file
@@ -0,0 +1,47 @@
|
||||
# Agent behavior configuration for DashScope realtime ASR/TTS.
|
||||
# This file only controls agent-side behavior (VAD/LLM/TTS/ASR providers).
|
||||
# Infra/server/network settings should stay in .env.
|
||||
|
||||
agent:
|
||||
vad:
|
||||
type: silero
|
||||
model_path: data/vad/silero_vad.onnx
|
||||
threshold: 0.5
|
||||
min_speech_duration_ms: 100
|
||||
eou_threshold_ms: 800
|
||||
|
||||
llm:
|
||||
# provider: openai | openai_compatible | siliconflow
|
||||
provider: openai_compatible
|
||||
model: deepseek-v3
|
||||
temperature: 0.7
|
||||
api_key: your_llm_api_key
|
||||
api_url: https://api.qnaigc.com/v1
|
||||
|
||||
tts:
|
||||
provider: dashscope
|
||||
api_key: your_tts_api_key
|
||||
api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
|
||||
model: qwen3-tts-flash-realtime
|
||||
voice: Cherry
|
||||
dashscope_mode: commit
|
||||
speed: 1.0
|
||||
|
||||
asr:
|
||||
provider: dashscope
|
||||
api_key: your_asr_api_key
|
||||
api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
|
||||
model: qwen3-asr-flash-realtime
|
||||
interim_interval_ms: 500
|
||||
min_audio_ms: 300
|
||||
start_min_speech_ms: 160
|
||||
pre_speech_ms: 240
|
||||
final_tail_ms: 120
|
||||
|
||||
duplex:
|
||||
enabled: true
|
||||
system_prompt: You are a helpful, friendly voice assistant. Keep your responses concise and conversational.
|
||||
|
||||
barge_in:
|
||||
min_duration_ms: 200
|
||||
silence_tolerance_ms: 60
|
||||
@@ -21,12 +21,17 @@ agent:
|
||||
api_url: https://api.qnaigc.com/v1
|
||||
|
||||
tts:
|
||||
# provider: openai_compatible | siliconflow | dashscope
|
||||
# provider: openai_compatible | siliconflow | dashscope | volcengine
|
||||
# dashscope defaults (if omitted):
|
||||
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
|
||||
# model: qwen3-tts-flash-realtime
|
||||
# dashscope_mode: commit (engine splits) | server_commit (dashscope splits)
|
||||
# note: dashscope_mode/mode is ONLY used when provider=dashscope.
|
||||
# volcengine defaults (if omitted):
|
||||
# api_url: https://openspeech.bytedance.com/api/v3/tts/unidirectional
|
||||
# resource_id: seed-tts-2.0
|
||||
# app_id: your volcengine app key
|
||||
# api_key: your volcengine access key
|
||||
provider: openai_compatible
|
||||
api_key: your_tts_api_key
|
||||
api_url: https://api.siliconflow.cn/v1/audio/speech
|
||||
@@ -35,11 +40,26 @@ agent:
|
||||
speed: 1.0
|
||||
|
||||
asr:
|
||||
# provider: buffered | openai_compatible | siliconflow
|
||||
# provider: buffered | openai_compatible | siliconflow | dashscope | volcengine
|
||||
# dashscope defaults (if omitted):
|
||||
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
|
||||
# model: qwen3-asr-flash-realtime
|
||||
# note: dashscope uses streaming ASR mode (chunk-by-chunk).
|
||||
# volcengine defaults (if omitted):
|
||||
# api_url: wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
|
||||
# model: bigmodel
|
||||
# resource_id: volc.bigasr.sauc.duration
|
||||
# app_id: your volcengine app key
|
||||
# api_key: your volcengine access key
|
||||
# request_params:
|
||||
# end_window_size: 800
|
||||
# force_to_speech_time: 1000
|
||||
# note: volcengine uses streaming ASR mode (chunk-by-chunk).
|
||||
provider: openai_compatible
|
||||
api_key: you_asr_api_key
|
||||
api_url: https://api.siliconflow.cn/v1/audio/transcriptions
|
||||
model: FunAudioLLM/SenseVoiceSmall
|
||||
enable_interim: false
|
||||
interim_interval_ms: 500
|
||||
min_audio_ms: 300
|
||||
start_min_speech_ms: 160
|
||||
|
||||
@@ -18,12 +18,17 @@ agent:
|
||||
api_url: https://api.qnaigc.com/v1
|
||||
|
||||
tts:
|
||||
# provider: openai_compatible | siliconflow | dashscope
|
||||
# provider: openai_compatible | siliconflow | dashscope | volcengine
|
||||
# dashscope defaults (if omitted):
|
||||
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
|
||||
# model: qwen3-tts-flash-realtime
|
||||
# dashscope_mode: commit (engine splits) | server_commit (dashscope splits)
|
||||
# note: dashscope_mode/mode is ONLY used when provider=dashscope.
|
||||
# volcengine defaults (if omitted):
|
||||
# api_url: https://openspeech.bytedance.com/api/v3/tts/unidirectional
|
||||
# resource_id: seed-tts-2.0
|
||||
# app_id: your volcengine app key
|
||||
# api_key: your volcengine access key
|
||||
provider: openai_compatible
|
||||
api_key: your_tts_api_key
|
||||
api_url: https://api.siliconflow.cn/v1/audio/speech
|
||||
@@ -32,11 +37,26 @@ agent:
|
||||
speed: 1.0
|
||||
|
||||
asr:
|
||||
# provider: buffered | openai_compatible | siliconflow
|
||||
# provider: buffered | openai_compatible | siliconflow | dashscope | volcengine
|
||||
# dashscope defaults (if omitted):
|
||||
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
|
||||
# model: qwen3-asr-flash-realtime
|
||||
# note: dashscope uses streaming ASR mode (chunk-by-chunk).
|
||||
# volcengine defaults (if omitted):
|
||||
# api_url: wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
|
||||
# model: bigmodel
|
||||
# resource_id: volc.bigasr.sauc.duration
|
||||
# app_id: your volcengine app key
|
||||
# api_key: your volcengine access key
|
||||
# request_params:
|
||||
# end_window_size: 800
|
||||
# force_to_speech_time: 1000
|
||||
# note: volcengine uses streaming ASR mode (chunk-by-chunk).
|
||||
provider: openai_compatible
|
||||
api_key: your_asr_api_key
|
||||
api_url: https://api.siliconflow.cn/v1/audio/transcriptions
|
||||
model: FunAudioLLM/SenseVoiceSmall
|
||||
enable_interim: false
|
||||
interim_interval_ms: 500
|
||||
min_audio_ms: 300
|
||||
start_min_speech_ms: 160
|
||||
|
||||
68
engine/config/agents/volcengine.yaml
Normal file
68
engine/config/agents/volcengine.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
# Agent behavior configuration (safe to edit per profile)
|
||||
# This file only controls agent-side behavior (VAD/LLM/TTS/ASR providers).
|
||||
# Infra/server/network settings should stay in .env.
|
||||
|
||||
agent:
|
||||
vad:
|
||||
type: silero
|
||||
model_path: data/vad/silero_vad.onnx
|
||||
threshold: 0.5
|
||||
min_speech_duration_ms: 100
|
||||
eou_threshold_ms: 800
|
||||
|
||||
llm:
|
||||
# provider: openai | openai_compatible | siliconflow
|
||||
provider: openai_compatible
|
||||
model: deepseek-v3
|
||||
temperature: 0.7
|
||||
# Required: no fallback. You can still reference env explicitly.
|
||||
api_key: your_llm_api_key
|
||||
# Optional for OpenAI-compatible endpoints:
|
||||
api_url: https://api.qnaigc.com/v1
|
||||
|
||||
tts:
|
||||
# provider: edge | openai_compatible | siliconflow | dashscope
|
||||
# dashscope defaults (if omitted):
|
||||
# api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime
|
||||
# model: qwen3-tts-flash-realtime
|
||||
# dashscope_mode: commit (engine splits) | server_commit (dashscope splits)
|
||||
# note: dashscope_mode/mode is ONLY used when provider=dashscope.
|
||||
# volcengine defaults (if omitted):
|
||||
provider: volcengine
|
||||
api_url: https://openspeech.bytedance.com/api/v3/tts/unidirectional
|
||||
resource_id: seed-tts-2.0
|
||||
app_id: your_tts_app_id
|
||||
api_key: your_tts_api_key
|
||||
speed: 1.1
|
||||
voice: zh_female_vv_uranus_bigtts
|
||||
|
||||
asr:
|
||||
asr:
|
||||
provider: volcengine
|
||||
api_url: wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
|
||||
app_id: your_asr_app_id
|
||||
api_key: your_asr_api_key
|
||||
resource_id: volc.bigasr.sauc.duration
|
||||
uid: caller-1
|
||||
model: bigmodel
|
||||
request_params:
|
||||
end_window_size: 800
|
||||
force_to_speech_time: 1000
|
||||
enable_punc: true
|
||||
enable_itn: false
|
||||
enable_ddc: false
|
||||
show_utterance: true
|
||||
result_type: single
|
||||
interim_interval_ms: 500
|
||||
min_audio_ms: 300
|
||||
start_min_speech_ms: 160
|
||||
pre_speech_ms: 240
|
||||
final_tail_ms: 120
|
||||
|
||||
duplex:
|
||||
enabled: true
|
||||
system_prompt: 你是一个人工智能助手,你用简答语句回答,避免使用标点符号和emoji。
|
||||
|
||||
barge_in:
|
||||
min_duration_ms: 200
|
||||
silence_tolerance_ms: 60
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -20,7 +20,7 @@ This document defines the draft port set used to keep core runtime extensible.
|
||||
- `runtime/ports/asr.py`
|
||||
- `ASRServiceSpec`
|
||||
- `ASRPort`
|
||||
- optional extensions: `ASRInterimControl`, `ASRBufferControl`
|
||||
- explicit mode ports: `OfflineASRPort`, `StreamingASRPort`
|
||||
- `runtime/ports/service_factory.py`
|
||||
- `RealtimeServiceFactory`
|
||||
|
||||
@@ -36,10 +36,10 @@ This document defines the draft port set used to keep core runtime extensible.
|
||||
- supported providers: `openai`, `openai_compatible`, `openai-compatible`, `siliconflow`
|
||||
- fallback: `MockLLMService`
|
||||
- TTS:
|
||||
- supported providers: `dashscope`, `openai_compatible`, `openai-compatible`, `siliconflow`
|
||||
- supported providers: `dashscope`, `volcengine`, `openai_compatible`, `openai-compatible`, `siliconflow`
|
||||
- fallback: `MockTTSService`
|
||||
- ASR:
|
||||
- supported providers: `openai_compatible`, `openai-compatible`, `siliconflow`
|
||||
- supported providers: `openai_compatible`, `openai-compatible`, `siliconflow`, `dashscope`, `volcengine`
|
||||
- fallback: `BufferedASRService`
|
||||
|
||||
## Notes
|
||||
|
||||
@@ -1 +1,15 @@
|
||||
"""ASR providers."""
|
||||
|
||||
from providers.asr.buffered import BufferedASRService, MockASRService
|
||||
from providers.asr.dashscope import DashScopeRealtimeASRService
|
||||
from providers.asr.openai_compatible import OpenAICompatibleASRService, SiliconFlowASRService
|
||||
from providers.asr.volcengine import VolcengineRealtimeASRService
|
||||
|
||||
__all__ = [
|
||||
"BufferedASRService",
|
||||
"MockASRService",
|
||||
"DashScopeRealtimeASRService",
|
||||
"OpenAICompatibleASRService",
|
||||
"SiliconFlowASRService",
|
||||
"VolcengineRealtimeASRService",
|
||||
]
|
||||
|
||||
@@ -34,6 +34,7 @@ class BufferedASRService(BaseASRService):
|
||||
language: str = "en"
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, language=language)
|
||||
self.mode = "offline"
|
||||
|
||||
self._audio_buffer: bytes = b""
|
||||
self._current_text: str = ""
|
||||
@@ -87,6 +88,23 @@ class BufferedASRService(BaseASRService):
|
||||
self._audio_buffer = b""
|
||||
return text
|
||||
|
||||
async def get_final_transcription(self) -> str:
|
||||
"""Offline compatibility method used by DuplexPipeline."""
|
||||
return self.get_and_clear_text()
|
||||
|
||||
def clear_buffer(self) -> None:
|
||||
"""Offline compatibility method used by DuplexPipeline."""
|
||||
self._audio_buffer = b""
|
||||
self._current_text = ""
|
||||
|
||||
async def start_interim_transcription(self) -> None:
|
||||
"""No-op for plain buffered ASR."""
|
||||
return None
|
||||
|
||||
async def stop_interim_transcription(self) -> None:
|
||||
"""No-op for plain buffered ASR."""
|
||||
return None
|
||||
|
||||
def get_audio_buffer(self) -> bytes:
|
||||
"""Get accumulated audio buffer."""
|
||||
return self._audio_buffer
|
||||
@@ -103,6 +121,7 @@ class MockASRService(BaseASRService):
|
||||
|
||||
def __init__(self, sample_rate: int = 16000, language: str = "en"):
|
||||
super().__init__(sample_rate=sample_rate, language=language)
|
||||
self.mode = "offline"
|
||||
self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()
|
||||
self._mock_texts = [
|
||||
"Hello, how are you?",
|
||||
@@ -145,3 +164,18 @@ class MockASRService(BaseASRService):
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
def clear_buffer(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_final_transcription(self) -> str:
|
||||
return ""
|
||||
|
||||
def get_and_clear_text(self) -> str:
|
||||
return ""
|
||||
|
||||
async def start_interim_transcription(self) -> None:
|
||||
return None
|
||||
|
||||
async def stop_interim_transcription(self) -> None:
|
||||
return None
|
||||
|
||||
388
engine/providers/asr/dashscope.py
Normal file
388
engine/providers/asr/dashscope.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""DashScope realtime streaming ASR service.
|
||||
|
||||
Uses Qwen-ASR-Realtime via DashScope Python SDK.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from providers.common.base import ASRResult, BaseASRService, ServiceState
|
||||
|
||||
try:
|
||||
import dashscope
|
||||
from dashscope.audio.qwen_omni import MultiModality, OmniRealtimeCallback, OmniRealtimeConversation
|
||||
|
||||
# Some SDK builds keep TranscriptionParams under qwen_omni.omni_realtime.
|
||||
try:
|
||||
from dashscope.audio.qwen_omni import TranscriptionParams
|
||||
except ImportError:
|
||||
from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams
|
||||
|
||||
DASHSCOPE_SDK_AVAILABLE = True
|
||||
DASHSCOPE_IMPORT_ERROR = ""
|
||||
except Exception as exc:
|
||||
DASHSCOPE_IMPORT_ERROR = f"{type(exc).__name__}: {exc}"
|
||||
dashscope = None # type: ignore[assignment]
|
||||
MultiModality = None # type: ignore[assignment]
|
||||
OmniRealtimeConversation = None # type: ignore[assignment]
|
||||
TranscriptionParams = None # type: ignore[assignment]
|
||||
DASHSCOPE_SDK_AVAILABLE = False
|
||||
|
||||
class OmniRealtimeCallback: # type: ignore[no-redef]
|
||||
"""Fallback callback base when DashScope SDK is unavailable."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class _DashScopeASRCallback(OmniRealtimeCallback):
|
||||
"""Bridge DashScope SDK callbacks into asyncio loop-safe handlers."""
|
||||
|
||||
def __init__(self, owner: "DashScopeRealtimeASRService", loop: asyncio.AbstractEventLoop):
|
||||
super().__init__()
|
||||
self._owner = owner
|
||||
self._loop = loop
|
||||
|
||||
def _schedule(self, fn: Callable[[], None]) -> None:
|
||||
try:
|
||||
self._loop.call_soon_threadsafe(fn)
|
||||
except RuntimeError:
|
||||
return
|
||||
|
||||
def on_open(self) -> None:
|
||||
self._schedule(self._owner._on_ws_open)
|
||||
|
||||
def on_close(self, code: int, msg: str) -> None:
|
||||
self._schedule(lambda: self._owner._on_ws_close(code, msg))
|
||||
|
||||
def on_event(self, message: Any) -> None:
|
||||
self._schedule(lambda: self._owner._on_ws_event(message))
|
||||
|
||||
def on_error(self, message: Any) -> None:
|
||||
self._schedule(lambda: self._owner._on_ws_error(message))
|
||||
|
||||
|
||||
class DashScopeRealtimeASRService(BaseASRService):
|
||||
"""Realtime streaming ASR implementation for DashScope Qwen-ASR-Realtime."""
|
||||
|
||||
DEFAULT_WS_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
|
||||
DEFAULT_MODEL = "qwen3-asr-flash-realtime"
|
||||
DEFAULT_FINAL_TIMEOUT_MS = 800
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
sample_rate: int = 16000,
|
||||
language: str = "auto",
|
||||
on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None,
|
||||
) -> None:
|
||||
super().__init__(sample_rate=sample_rate, language=language)
|
||||
self.mode = "streaming"
|
||||
self.api_key = (
|
||||
api_key
|
||||
or os.getenv("DASHSCOPE_API_KEY")
|
||||
or os.getenv("ASR_API_KEY")
|
||||
)
|
||||
self.api_url = api_url or os.getenv("DASHSCOPE_ASR_API_URL") or self.DEFAULT_WS_URL
|
||||
self.model = model or os.getenv("DASHSCOPE_ASR_MODEL") or self.DEFAULT_MODEL
|
||||
self.on_transcript = on_transcript
|
||||
|
||||
self._client: Optional[Any] = None
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._callback: Optional[_DashScopeASRCallback] = None
|
||||
|
||||
self._running = False
|
||||
self._session_ready = asyncio.Event()
|
||||
self._transcript_queue: "asyncio.Queue[ASRResult]" = asyncio.Queue()
|
||||
self._final_queue: "asyncio.Queue[str]" = asyncio.Queue()
|
||||
|
||||
self._utterance_active = False
|
||||
self._audio_sent_in_utterance = False
|
||||
self._last_interim_text = ""
|
||||
self._last_error: Optional[str] = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
if not DASHSCOPE_SDK_AVAILABLE:
|
||||
py_exec = sys.executable
|
||||
hint = f"`{py_exec} -m pip install dashscope>=1.25.6`"
|
||||
detail = f"; import error: {DASHSCOPE_IMPORT_ERROR}" if DASHSCOPE_IMPORT_ERROR else ""
|
||||
raise RuntimeError(
|
||||
f"dashscope SDK unavailable in interpreter {py_exec}; install with {hint}{detail}"
|
||||
)
|
||||
if not self.api_key:
|
||||
raise ValueError("DashScope ASR API key not provided. Configure agent.asr.api_key in YAML.")
|
||||
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._callback = _DashScopeASRCallback(owner=self, loop=self._loop)
|
||||
|
||||
if dashscope is not None:
|
||||
dashscope.api_key = self.api_key
|
||||
|
||||
self._client = OmniRealtimeConversation( # type: ignore[misc]
|
||||
model=self.model,
|
||||
url=self.api_url,
|
||||
callback=self._callback,
|
||||
)
|
||||
await asyncio.to_thread(self._client.connect)
|
||||
await self._configure_session()
|
||||
|
||||
self._running = True
|
||||
self.state = ServiceState.CONNECTED
|
||||
logger.info(
|
||||
"DashScope realtime ASR connected: model={}, sample_rate={}, language={}",
|
||||
self.model,
|
||||
self.sample_rate,
|
||||
self.language,
|
||||
)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._running = False
|
||||
self._utterance_active = False
|
||||
self._audio_sent_in_utterance = False
|
||||
self._drain_queue(self._final_queue)
|
||||
self._drain_queue(self._transcript_queue)
|
||||
self._session_ready.clear()
|
||||
|
||||
if self._client is not None:
|
||||
close_fn = getattr(self._client, "close", None)
|
||||
if callable(close_fn):
|
||||
await asyncio.to_thread(close_fn)
|
||||
self._client = None
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("DashScope realtime ASR disconnected")
|
||||
|
||||
async def begin_utterance(self) -> None:
|
||||
self.clear_utterance()
|
||||
self._utterance_active = True
|
||||
|
||||
async def send_audio(self, audio: bytes) -> None:
|
||||
if not self._client:
|
||||
raise RuntimeError("DashScope ASR service not connected")
|
||||
if not audio:
|
||||
return
|
||||
|
||||
if not self._utterance_active:
|
||||
# Allow graceful fallback if caller sends before begin_utterance.
|
||||
self._utterance_active = True
|
||||
|
||||
audio_b64 = base64.b64encode(audio).decode("ascii")
|
||||
append_fn = getattr(self._client, "append_audio", None)
|
||||
if not callable(append_fn):
|
||||
raise RuntimeError("DashScope ASR SDK missing append_audio method")
|
||||
await asyncio.to_thread(append_fn, audio_b64)
|
||||
self._audio_sent_in_utterance = True
|
||||
|
||||
async def end_utterance(self) -> None:
|
||||
if not self._client:
|
||||
return
|
||||
if not self._utterance_active or not self._audio_sent_in_utterance:
|
||||
return
|
||||
|
||||
commit_fn = getattr(self._client, "commit", None)
|
||||
if not callable(commit_fn):
|
||||
raise RuntimeError("DashScope ASR SDK missing commit method")
|
||||
await asyncio.to_thread(commit_fn)
|
||||
self._utterance_active = False
|
||||
|
||||
async def wait_for_final_transcription(self, timeout_ms: int = DEFAULT_FINAL_TIMEOUT_MS) -> str:
|
||||
if not self._audio_sent_in_utterance:
|
||||
return ""
|
||||
timeout_sec = max(0.05, float(timeout_ms) / 1000.0)
|
||||
try:
|
||||
text = await asyncio.wait_for(self._final_queue.get(), timeout=timeout_sec)
|
||||
return str(text or "").strip()
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug("DashScope ASR final timeout ({}ms), fallback to last interim", timeout_ms)
|
||||
return str(self._last_interim_text or "").strip()
|
||||
|
||||
def clear_utterance(self) -> None:
|
||||
self._utterance_active = False
|
||||
self._audio_sent_in_utterance = False
|
||||
self._last_interim_text = ""
|
||||
self._last_error = None
|
||||
self._drain_queue(self._final_queue)
|
||||
|
||||
async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
|
||||
while self._running:
|
||||
try:
|
||||
result = await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
|
||||
yield result
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
async def _configure_session(self) -> None:
|
||||
if not self._client:
|
||||
raise RuntimeError("DashScope ASR client is not initialized")
|
||||
|
||||
text_modality: Any = "text"
|
||||
if MultiModality is not None and hasattr(MultiModality, "TEXT"):
|
||||
text_modality = MultiModality.TEXT
|
||||
|
||||
transcription_params: Optional[Any] = None
|
||||
if TranscriptionParams is not None:
|
||||
try:
|
||||
lang = "zh" if self.language == "auto" else self.language
|
||||
transcription_params = TranscriptionParams(
|
||||
language=lang,
|
||||
sample_rate=self.sample_rate,
|
||||
input_audio_format="pcm",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("DashScope ASR TranscriptionParams init failed: {}", exc)
|
||||
transcription_params = None
|
||||
|
||||
update_attempts = [
|
||||
{
|
||||
"output_modalities": [text_modality],
|
||||
"enable_turn_detection": False,
|
||||
"enable_input_audio_transcription": True,
|
||||
"transcription_params": transcription_params,
|
||||
},
|
||||
{
|
||||
"output_modalities": [text_modality],
|
||||
"enable_turn_detection": False,
|
||||
"enable_input_audio_transcription": True,
|
||||
},
|
||||
{
|
||||
"output_modalities": [text_modality],
|
||||
},
|
||||
]
|
||||
|
||||
update_fn = getattr(self._client, "update_session", None)
|
||||
if not callable(update_fn):
|
||||
raise RuntimeError("DashScope ASR SDK missing update_session method")
|
||||
|
||||
last_error: Optional[Exception] = None
|
||||
for params in update_attempts:
|
||||
if params.get("transcription_params") is None:
|
||||
params = {k: v for k, v in params.items() if k != "transcription_params"}
|
||||
try:
|
||||
await asyncio.to_thread(update_fn, **params)
|
||||
break
|
||||
except TypeError as exc:
|
||||
last_error = exc
|
||||
continue
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"DashScope ASR session.update failed: {last_error}")
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self._session_ready.wait(), timeout=6.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug("DashScope ASR session ready wait timeout; continuing")
|
||||
|
||||
def _on_ws_open(self) -> None:
|
||||
return None
|
||||
|
||||
def _on_ws_close(self, code: int, msg: str) -> None:
|
||||
self._last_error = f"DashScope ASR websocket closed: {code} {msg}"
|
||||
logger.debug(self._last_error)
|
||||
|
||||
def _on_ws_error(self, message: Any) -> None:
|
||||
self._last_error = str(message)
|
||||
logger.error("DashScope ASR error: {}", self._last_error)
|
||||
|
||||
def _on_ws_event(self, message: Any) -> None:
|
||||
payload = self._coerce_event(message)
|
||||
event_type = str(payload.get("type") or "").strip()
|
||||
if not event_type:
|
||||
return
|
||||
|
||||
if event_type in {"session.created", "session.updated"}:
|
||||
self._session_ready.set()
|
||||
return
|
||||
if event_type == "error" or event_type.endswith(".failed"):
|
||||
err_text = self._extract_text(payload, keys=("message", "error", "details"))
|
||||
self._last_error = err_text or event_type
|
||||
logger.error("DashScope ASR server event error: {}", self._last_error)
|
||||
return
|
||||
|
||||
if event_type == "conversation.item.input_audio_transcription.text":
|
||||
stash_text = self._extract_text(payload, keys=("stash", "text", "transcript"))
|
||||
self._emit_transcript(stash_text, is_final=False)
|
||||
return
|
||||
|
||||
if event_type == "conversation.item.input_audio_transcription.completed":
|
||||
final_text = self._extract_text(payload, keys=("transcript", "text", "stash"))
|
||||
self._emit_transcript(final_text, is_final=True)
|
||||
return
|
||||
|
||||
def _emit_transcript(self, text: str, *, is_final: bool) -> None:
|
||||
normalized = str(text or "").strip()
|
||||
if not normalized:
|
||||
return
|
||||
if not is_final and normalized == self._last_interim_text:
|
||||
return
|
||||
if not is_final:
|
||||
self._last_interim_text = normalized
|
||||
|
||||
if self._loop is None:
|
||||
return
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._publish_transcript(normalized, is_final=is_final),
|
||||
self._loop,
|
||||
)
|
||||
except RuntimeError:
|
||||
return
|
||||
|
||||
async def _publish_transcript(self, text: str, *, is_final: bool) -> None:
|
||||
await self._transcript_queue.put(ASRResult(text=text, is_final=is_final))
|
||||
if is_final:
|
||||
await self._final_queue.put(text)
|
||||
if self.on_transcript:
|
||||
try:
|
||||
await self.on_transcript(text, is_final)
|
||||
except Exception as exc:
|
||||
logger.warning("DashScope ASR transcript callback failed: {}", exc)
|
||||
|
||||
@staticmethod
|
||||
def _coerce_event(message: Any) -> Dict[str, Any]:
|
||||
if isinstance(message, dict):
|
||||
return message
|
||||
if isinstance(message, str):
|
||||
try:
|
||||
parsed = json.loads(message)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except json.JSONDecodeError:
|
||||
return {"type": "raw", "text": message}
|
||||
return {"type": "raw", "text": str(message)}
|
||||
|
||||
def _extract_text(self, payload: Dict[str, Any], *, keys: tuple[str, ...]) -> str:
|
||||
for key in keys:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
if isinstance(value, dict):
|
||||
nested = self._extract_text(value, keys=keys)
|
||||
if nested:
|
||||
return nested
|
||||
|
||||
for value in payload.values():
|
||||
if isinstance(value, dict):
|
||||
nested = self._extract_text(value, keys=keys)
|
||||
if nested:
|
||||
return nested
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _drain_queue(queue: "asyncio.Queue[Any]") -> None:
|
||||
while True:
|
||||
try:
|
||||
queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
@@ -53,6 +53,7 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
model: str = "FunAudioLLM/SenseVoiceSmall",
|
||||
sample_rate: int = 16000,
|
||||
language: str = "auto",
|
||||
enable_interim: bool = False,
|
||||
interim_interval_ms: int = 500, # How often to send interim results
|
||||
min_audio_for_interim_ms: int = 300, # Min audio before first interim
|
||||
on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None
|
||||
@@ -66,11 +67,13 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
model: ASR model name or alias
|
||||
sample_rate: Audio sample rate (16000 recommended)
|
||||
language: Language code (auto for automatic detection)
|
||||
enable_interim: Whether to generate interim transcriptions in offline mode
|
||||
interim_interval_ms: How often to generate interim transcriptions
|
||||
min_audio_for_interim_ms: Minimum audio duration before first interim
|
||||
on_transcript: Callback for transcription results (text, is_final)
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, language=language)
|
||||
self.mode = "offline"
|
||||
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
raise RuntimeError("aiohttp is required for OpenAICompatibleASRService")
|
||||
@@ -79,6 +82,7 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
raw_api_url = api_url or os.getenv("ASR_API_URL") or self.API_URL
|
||||
self.api_url = self._resolve_transcriptions_endpoint(raw_api_url)
|
||||
self.model = self.MODELS.get(model.lower(), model)
|
||||
self.enable_interim = bool(enable_interim)
|
||||
self.interim_interval_ms = interim_interval_ms
|
||||
self.min_audio_for_interim_ms = min_audio_for_interim_ms
|
||||
self.on_transcript = on_transcript
|
||||
@@ -181,6 +185,9 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
logger.warning("ASR session not connected")
|
||||
return None
|
||||
|
||||
if not is_final and not self.enable_interim:
|
||||
return None
|
||||
|
||||
# Check minimum audio duration
|
||||
audio_duration_ms = len(self._audio_buffer) / (self.sample_rate * 2) * 1000
|
||||
|
||||
@@ -309,6 +316,9 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
This periodically transcribes buffered audio for
|
||||
real-time feedback to the user.
|
||||
"""
|
||||
if not self.enable_interim:
|
||||
return
|
||||
|
||||
if self._interim_task and not self._interim_task.done():
|
||||
return
|
||||
|
||||
|
||||
666
engine/providers/asr/volcengine.py
Normal file
666
engine/providers/asr/volcengine.py
Normal file
@@ -0,0 +1,666 @@
|
||||
"""Volcengine realtime ASR service.
|
||||
|
||||
Supports both:
|
||||
- Volcengine Edge Gateway realtime transcription websocket, and
|
||||
- Volcengine BigASR Seed websocket at openspeech.bytedance.com/api/v3/sauc/bigmodel.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Literal, Optional
|
||||
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from providers.common.base import ASRResult, BaseASRService, ServiceState
|
||||
|
||||
VolcengineASRProtocol = Literal["gateway", "seed"]
|
||||
|
||||
|
||||
class VolcengineRealtimeASRService(BaseASRService):
|
||||
"""Realtime streaming ASR backed by Volcengine websocket APIs."""
|
||||
|
||||
DEFAULT_WS_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel"
|
||||
DEFAULT_GATEWAY_WS_URL = "wss://ai-gateway.vei.volces.com/v1/realtime"
|
||||
DEFAULT_MODEL = "bigmodel"
|
||||
DEFAULT_FINAL_TIMEOUT_MS = 1200
|
||||
DEFAULT_SEED_RESOURCE_ID = "volc.bigasr.sauc.duration"
|
||||
_SEED_FRAME_MS = 100
|
||||
_SEED_PROTOCOL_VERSION = 0b0001
|
||||
_SEED_FULL_CLIENT_REQUEST = 0b0001
|
||||
_SEED_AUDIO_ONLY_REQUEST = 0b0010
|
||||
_SEED_FULL_SERVER_RESPONSE = 0b1001
|
||||
_SEED_SERVER_ACK = 0b1011
|
||||
_SEED_SERVER_ERROR_RESPONSE = 0b1111
|
||||
_SEED_NO_SEQUENCE = 0b0000
|
||||
_SEED_POS_SEQUENCE = 0b0001
|
||||
_SEED_NEG_WITH_SEQUENCE = 0b0011
|
||||
_SEED_NO_SERIALIZATION = 0b0000
|
||||
_SEED_JSON = 0b0001
|
||||
_SEED_NO_COMPRESSION = 0b0000
|
||||
_SEED_GZIP = 0b0001
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
sample_rate: int = 16000,
|
||||
language: str = "auto",
|
||||
app_id: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
uid: Optional[str] = None,
|
||||
request_params: Optional[Dict[str, Any]] = None,
|
||||
on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None,
|
||||
) -> None:
|
||||
super().__init__(sample_rate=sample_rate, language=language)
|
||||
self.mode = "streaming"
|
||||
self.api_key = api_key or os.getenv("VOLCENGINE_ASR_API_KEY") or os.getenv("ASR_API_KEY")
|
||||
self.model = str(model or os.getenv("VOLCENGINE_ASR_MODEL") or self.DEFAULT_MODEL).strip()
|
||||
raw_api_url = api_url or os.getenv("VOLCENGINE_ASR_API_URL") or self.DEFAULT_WS_URL
|
||||
self.protocol = self._detect_protocol(raw_api_url)
|
||||
self.api_url = self._resolve_api_url(raw_api_url, self.model, self.protocol)
|
||||
self.app_id = app_id or os.getenv("VOLCENGINE_ASR_APP_ID") or os.getenv("ASR_APP_ID")
|
||||
self.resource_id = (
|
||||
resource_id
|
||||
or os.getenv("VOLCENGINE_ASR_RESOURCE_ID")
|
||||
or (self.DEFAULT_SEED_RESOURCE_ID if self.protocol == "seed" else None)
|
||||
)
|
||||
self.uid = uid or os.getenv("VOLCENGINE_ASR_UID")
|
||||
self.request_params = self._load_request_params(request_params)
|
||||
self.on_transcript = on_transcript
|
||||
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
|
||||
self._reader_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
self._running = False
|
||||
self._session_ready = asyncio.Event()
|
||||
self._transcript_queue: "asyncio.Queue[ASRResult]" = asyncio.Queue()
|
||||
self._final_queue: "asyncio.Queue[str]" = asyncio.Queue()
|
||||
|
||||
self._utterance_active = False
|
||||
self._audio_sent_in_utterance = False
|
||||
self._last_interim_text = ""
|
||||
self._last_error: Optional[str] = None
|
||||
|
||||
self._seed_audio_buffer = bytearray()
|
||||
self._seed_sequence = 1
|
||||
self._seed_request_id: Optional[str] = None
|
||||
self._seed_frame_bytes = max(2, int((self.sample_rate * self._SEED_FRAME_MS / 1000) * 2))
|
||||
|
||||
@classmethod
|
||||
def _detect_protocol(cls, api_url: str) -> VolcengineASRProtocol:
|
||||
parsed = urlparse(str(api_url or "").strip())
|
||||
host = parsed.netloc.lower()
|
||||
path = parsed.path.lower()
|
||||
if "openspeech.bytedance.com" in host and "/api/v3/sauc/bigmodel" in path:
|
||||
return "seed"
|
||||
return "gateway"
|
||||
|
||||
@classmethod
|
||||
def _resolve_api_url(cls, api_url: str, model: str, protocol: VolcengineASRProtocol) -> str:
|
||||
raw = str(api_url or "").strip()
|
||||
if not raw:
|
||||
raw = cls.DEFAULT_WS_URL if protocol == "seed" else cls.DEFAULT_GATEWAY_WS_URL
|
||||
if protocol != "gateway":
|
||||
return raw
|
||||
|
||||
parsed = urlparse(raw)
|
||||
query = dict(parse_qsl(parsed.query, keep_blank_values=True))
|
||||
query.setdefault("model", model or cls.DEFAULT_MODEL)
|
||||
return urlunparse(parsed._replace(query=urlencode(query)))
|
||||
|
||||
@staticmethod
|
||||
def _load_request_params(request_params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
if isinstance(request_params, dict):
|
||||
return dict(request_params)
|
||||
raw = os.getenv("VOLCENGINE_ASR_REQUEST_PARAMS_JSON", "").strip()
|
||||
if not raw:
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Ignoring invalid VOLCENGINE_ASR_REQUEST_PARAMS_JSON")
|
||||
return {}
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return {}
|
||||
|
||||
async def connect(self) -> None:
|
||||
if not self.api_key:
|
||||
raise ValueError("Volcengine ASR API key not provided. Configure agent.asr.api_key in YAML.")
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=None, sock_read=None, sock_connect=15)
|
||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||
self._running = True
|
||||
|
||||
if self.protocol == "gateway":
|
||||
await self._connect_gateway()
|
||||
logger.info(
|
||||
"Volcengine gateway ASR connected: model={}, sample_rate={}, url={}",
|
||||
self.model,
|
||||
self.sample_rate,
|
||||
self.api_url,
|
||||
)
|
||||
else:
|
||||
if not self.app_id:
|
||||
raise ValueError("Volcengine ASR app_id not provided. Configure agent.asr.app_id in YAML.")
|
||||
logger.info(
|
||||
"Volcengine BigASR Seed ready: model={}, sample_rate={}, resource_id={}",
|
||||
self.model,
|
||||
self.sample_rate,
|
||||
self.resource_id,
|
||||
)
|
||||
|
||||
self.state = ServiceState.CONNECTED
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._running = False
|
||||
self._utterance_active = False
|
||||
self._audio_sent_in_utterance = False
|
||||
self._session_ready.clear()
|
||||
self._seed_audio_buffer = bytearray()
|
||||
self._drain_queue(self._final_queue)
|
||||
self._drain_queue(self._transcript_queue)
|
||||
|
||||
await self._close_ws()
|
||||
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("Volcengine ASR disconnected")
|
||||
|
||||
async def begin_utterance(self) -> None:
|
||||
self.clear_utterance()
|
||||
if self.protocol == "seed":
|
||||
await self._open_seed_stream()
|
||||
self._utterance_active = True
|
||||
|
||||
async def send_audio(self, audio: bytes) -> None:
|
||||
if not audio:
|
||||
return
|
||||
|
||||
if self.protocol == "seed":
|
||||
await self._send_seed_audio(audio)
|
||||
return
|
||||
|
||||
if not self._ws:
|
||||
raise RuntimeError("Volcengine ASR websocket is not connected")
|
||||
if not self._utterance_active:
|
||||
self._utterance_active = True
|
||||
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"type": "input_audio_buffer.append",
|
||||
"audio": base64.b64encode(audio).decode("ascii"),
|
||||
}
|
||||
)
|
||||
self._audio_sent_in_utterance = True
|
||||
|
||||
async def end_utterance(self) -> None:
|
||||
if not self._utterance_active:
|
||||
return
|
||||
|
||||
if self.protocol == "seed":
|
||||
await self._end_seed_utterance()
|
||||
return
|
||||
|
||||
if not self._ws or not self._audio_sent_in_utterance:
|
||||
return
|
||||
await self._ws.send_json({"type": "input_audio_buffer.commit"})
|
||||
self._utterance_active = False
|
||||
|
||||
async def wait_for_final_transcription(self, timeout_ms: int = DEFAULT_FINAL_TIMEOUT_MS) -> str:
|
||||
if not self._audio_sent_in_utterance:
|
||||
return ""
|
||||
|
||||
timeout_sec = max(0.05, float(timeout_ms) / 1000.0)
|
||||
try:
|
||||
return str(await asyncio.wait_for(self._final_queue.get(), timeout=timeout_sec) or "").strip()
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug("Volcengine ASR final timeout ({}ms), fallback to last interim", timeout_ms)
|
||||
return str(self._last_interim_text or "").strip()
|
||||
finally:
|
||||
if self.protocol == "seed":
|
||||
await self._close_ws()
|
||||
|
||||
def clear_utterance(self) -> None:
|
||||
self._utterance_active = False
|
||||
self._audio_sent_in_utterance = False
|
||||
self._last_interim_text = ""
|
||||
self._last_error = None
|
||||
self._seed_audio_buffer = bytearray()
|
||||
self._seed_sequence = 1
|
||||
self._seed_request_id = None
|
||||
self._drain_queue(self._final_queue)
|
||||
|
||||
async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
|
||||
while self._running:
|
||||
try:
|
||||
yield await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
async def _connect_gateway(self) -> None:
|
||||
assert self._session is not None
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
if self.resource_id:
|
||||
headers["X-Api-Resource-Id"] = self.resource_id
|
||||
|
||||
self._ws = await self._session.ws_connect(self.api_url, headers=headers, heartbeat=20)
|
||||
self._reader_task = asyncio.create_task(self._reader_loop())
|
||||
await self._configure_gateway_session()
|
||||
|
||||
async def _configure_gateway_session(self) -> None:
|
||||
if not self._ws:
|
||||
raise RuntimeError("Volcengine ASR websocket is not initialized")
|
||||
|
||||
session_payload: Dict[str, Any] = {
|
||||
"input_audio_format": "pcm",
|
||||
"input_audio_codec": "raw",
|
||||
"input_audio_sample_rate": self.sample_rate,
|
||||
"input_audio_bits": 16,
|
||||
"input_audio_channel": 1,
|
||||
"result_type": 0,
|
||||
"input_audio_transcription": {
|
||||
"model": self.model,
|
||||
},
|
||||
}
|
||||
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"type": "transcription_session.update",
|
||||
"session": session_payload,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self._session_ready.wait(), timeout=8.0)
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise RuntimeError("Volcengine ASR session update timeout") from exc
|
||||
|
||||
async def _open_seed_stream(self) -> None:
|
||||
if not self._session:
|
||||
raise RuntimeError("Volcengine ASR session is not initialized")
|
||||
|
||||
await self._close_ws()
|
||||
self._seed_request_id = uuid.uuid4().hex
|
||||
headers = self._build_seed_headers(self._seed_request_id)
|
||||
self._ws = await self._session.ws_connect(
|
||||
self.api_url,
|
||||
headers=headers,
|
||||
heartbeat=20,
|
||||
max_msg_size=1_000_000_000,
|
||||
)
|
||||
self._reader_task = asyncio.create_task(self._reader_loop())
|
||||
await self._ws.send_bytes(self._build_seed_start_request())
|
||||
|
||||
async def _send_seed_audio(self, audio: bytes) -> None:
|
||||
if not self._utterance_active:
|
||||
await self.begin_utterance()
|
||||
if not self._ws:
|
||||
raise RuntimeError("Volcengine BigASR websocket is not connected")
|
||||
|
||||
self._seed_audio_buffer.extend(audio)
|
||||
while len(self._seed_audio_buffer) >= self._seed_frame_bytes:
|
||||
chunk = bytes(self._seed_audio_buffer[: self._seed_frame_bytes])
|
||||
del self._seed_audio_buffer[: self._seed_frame_bytes]
|
||||
self._seed_sequence += 1
|
||||
await self._ws.send_bytes(self._build_seed_audio_request(chunk, sequence=self._seed_sequence))
|
||||
self._audio_sent_in_utterance = True
|
||||
|
||||
async def _end_seed_utterance(self) -> None:
|
||||
if not self._ws:
|
||||
return
|
||||
if not self._audio_sent_in_utterance and not self._seed_audio_buffer:
|
||||
self._utterance_active = False
|
||||
return
|
||||
|
||||
final_chunk = bytes(self._seed_audio_buffer)
|
||||
self._seed_audio_buffer = bytearray()
|
||||
self._seed_sequence += 1
|
||||
await self._ws.send_bytes(
|
||||
self._build_seed_audio_request(final_chunk, sequence=-self._seed_sequence, is_last=True)
|
||||
)
|
||||
self._audio_sent_in_utterance = True
|
||||
self._utterance_active = False
|
||||
|
||||
async def _close_ws(self) -> None:
|
||||
reader_task = self._reader_task
|
||||
ws = self._ws
|
||||
self._reader_task = None
|
||||
self._ws = None
|
||||
|
||||
if reader_task:
|
||||
reader_task.cancel()
|
||||
try:
|
||||
await reader_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if ws is not None:
|
||||
await ws.close()
|
||||
|
||||
async def _reader_loop(self) -> None:
|
||||
ws = self._ws
|
||||
if ws is None:
|
||||
return
|
||||
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
if self.protocol == "gateway":
|
||||
self._handle_gateway_event(msg.data)
|
||||
else:
|
||||
self._handle_seed_text(msg.data)
|
||||
continue
|
||||
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||
if self.protocol == "seed":
|
||||
self._handle_seed_binary(msg.data)
|
||||
continue
|
||||
if msg.type == aiohttp.WSMsgType.ERROR:
|
||||
self._last_error = str(ws.exception())
|
||||
logger.error("Volcengine ASR websocket error: {}", self._last_error)
|
||||
break
|
||||
if msg.type in {aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE}:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
self._last_error = str(exc)
|
||||
logger.error("Volcengine ASR reader loop failed: {}", exc)
|
||||
finally:
|
||||
if self._ws is ws:
|
||||
self._ws = None
|
||||
|
||||
def _handle_gateway_event(self, message: str) -> None:
|
||||
payload = self._coerce_event(message)
|
||||
event_type = str(payload.get("type") or "").strip()
|
||||
if not event_type:
|
||||
return
|
||||
|
||||
if event_type in {"transcription_session.created", "transcription_session.updated"}:
|
||||
self._session_ready.set()
|
||||
return
|
||||
|
||||
if event_type == "error":
|
||||
self._last_error = self._extract_text(payload, ("message", "error"))
|
||||
logger.error("Volcengine ASR server error: {}", self._last_error or "unknown")
|
||||
return
|
||||
|
||||
if event_type.endswith(".failed"):
|
||||
self._last_error = self._extract_text(payload, ("message", "error", "transcript"))
|
||||
logger.error("Volcengine ASR failed event: {}", self._last_error or event_type)
|
||||
return
|
||||
|
||||
if event_type == "conversation.item.input_audio_transcription.result":
|
||||
transcript = self._extract_text(payload, ("transcript", "result"))
|
||||
self._emit_transcript_sync(transcript, is_final=False)
|
||||
return
|
||||
|
||||
if event_type == "conversation.item.input_audio_transcription.delta":
|
||||
transcript = self._extract_text(payload, ("delta",))
|
||||
self._emit_transcript_sync(transcript, is_final=False)
|
||||
return
|
||||
|
||||
if event_type == "conversation.item.input_audio_transcription.completed":
|
||||
transcript = self._extract_text(payload, ("transcript", "result"))
|
||||
self._emit_transcript_sync(transcript, is_final=True)
|
||||
|
||||
def _handle_seed_text(self, message: str) -> None:
|
||||
payload = self._coerce_event(message)
|
||||
if payload.get("type") == "error":
|
||||
self._last_error = self._extract_text(payload, ("message", "error"))
|
||||
logger.error("Volcengine BigASR error: {}", self._last_error or "unknown")
|
||||
|
||||
def _handle_seed_binary(self, message: bytes) -> None:
|
||||
payload = self._parse_seed_response(message)
|
||||
if payload.get("code"):
|
||||
self._last_error = self._extract_text(payload, ("payload_msg",))
|
||||
logger.error("Volcengine BigASR server error: {}", self._last_error or payload["code"])
|
||||
return
|
||||
|
||||
body = payload.get("payload_msg")
|
||||
if not isinstance(body, dict):
|
||||
return
|
||||
result = body.get("result")
|
||||
if not isinstance(result, dict):
|
||||
return
|
||||
|
||||
text = str(result.get("text") or "").strip()
|
||||
if not text:
|
||||
return
|
||||
|
||||
utterances = result.get("utterances")
|
||||
if not isinstance(utterances, list) or not utterances:
|
||||
return
|
||||
first_utterance = utterances[0] if isinstance(utterances[0], dict) else {}
|
||||
is_final = self._coerce_bool(first_utterance.get("definite")) is True
|
||||
self._emit_transcript_sync(text, is_final=is_final)
|
||||
|
||||
def _emit_transcript_sync(self, text: str, *, is_final: bool) -> None:
|
||||
cleaned = str(text or "").strip()
|
||||
if not cleaned:
|
||||
return
|
||||
|
||||
if not is_final:
|
||||
self._last_interim_text = cleaned
|
||||
else:
|
||||
self._last_interim_text = ""
|
||||
|
||||
result = ASRResult(text=cleaned, is_final=is_final)
|
||||
try:
|
||||
self._transcript_queue.put_nowait(result)
|
||||
except asyncio.QueueFull:
|
||||
logger.debug("Volcengine ASR transcript queue full; dropping transcript")
|
||||
|
||||
if is_final:
|
||||
try:
|
||||
self._final_queue.put_nowait(cleaned)
|
||||
except asyncio.QueueFull:
|
||||
logger.debug("Volcengine ASR final queue full; dropping transcript")
|
||||
|
||||
if self.on_transcript:
|
||||
asyncio.create_task(self.on_transcript(cleaned, is_final))
|
||||
|
||||
def _build_seed_headers(self, request_id: str) -> Dict[str, str]:
|
||||
if not self.app_id:
|
||||
raise ValueError("Volcengine ASR app_id not provided. Configure agent.asr.app_id in YAML.")
|
||||
if not self.api_key:
|
||||
raise ValueError("Volcengine ASR api_key not provided. Configure agent.asr.api_key in YAML.")
|
||||
|
||||
return {
|
||||
"X-Api-App-Key": str(self.app_id),
|
||||
"X-Api-Access-Key": str(self.api_key),
|
||||
"X-Api-Resource-Id": str(self.resource_id or self.DEFAULT_SEED_RESOURCE_ID),
|
||||
"X-Api-Request-Id": str(request_id),
|
||||
}
|
||||
|
||||
def _build_seed_start_payload(self) -> Dict[str, Any]:
|
||||
user_payload: Dict[str, Any] = {"uid": str(self.uid or self._seed_request_id or self.app_id or uuid.uuid4().hex)}
|
||||
audio_payload: Dict[str, Any] = {
|
||||
"format": "pcm",
|
||||
"rate": self.sample_rate,
|
||||
"bits": 16,
|
||||
"channels": 1,
|
||||
"codec": "raw",
|
||||
}
|
||||
if self.language and self.language != "auto":
|
||||
audio_payload["language"] = self.language
|
||||
|
||||
request_payload: Dict[str, Any] = {
|
||||
"model_name": self.model or self.DEFAULT_MODEL,
|
||||
"enable_itn": False,
|
||||
"enable_punc": True,
|
||||
"enable_ddc": False,
|
||||
"show_utterance": True,
|
||||
"result_type": "single",
|
||||
"vad_segment_duration": 3000,
|
||||
"end_window_size": 500,
|
||||
"force_to_speech_time": 1000,
|
||||
}
|
||||
|
||||
extra = dict(self.request_params)
|
||||
user_payload.update(self._as_dict(extra.pop("user", None)))
|
||||
audio_payload.update(self._as_dict(extra.pop("audio", None)))
|
||||
request_payload.update(self._as_dict(extra.pop("request", None)))
|
||||
request_payload.update(extra)
|
||||
|
||||
return {
|
||||
"user": user_payload,
|
||||
"audio": audio_payload,
|
||||
"request": request_payload,
|
||||
}
|
||||
|
||||
def _build_seed_start_request(self) -> bytes:
|
||||
payload = gzip.compress(json.dumps(self._build_seed_start_payload()).encode("utf-8"))
|
||||
frame = bytearray(
|
||||
self._build_seed_header(
|
||||
message_type=self._SEED_FULL_CLIENT_REQUEST,
|
||||
message_type_specific_flags=self._SEED_POS_SEQUENCE,
|
||||
)
|
||||
)
|
||||
frame.extend((1).to_bytes(4, "big", signed=True))
|
||||
frame.extend(len(payload).to_bytes(4, "big"))
|
||||
frame.extend(payload)
|
||||
return bytes(frame)
|
||||
|
||||
def _build_seed_audio_request(self, chunk: bytes, *, sequence: int, is_last: bool = False) -> bytes:
|
||||
payload = gzip.compress(chunk)
|
||||
frame = bytearray(
|
||||
self._build_seed_header(
|
||||
message_type=self._SEED_AUDIO_ONLY_REQUEST,
|
||||
message_type_specific_flags=self._SEED_NEG_WITH_SEQUENCE if is_last else self._SEED_POS_SEQUENCE,
|
||||
)
|
||||
)
|
||||
frame.extend(int(sequence).to_bytes(4, "big", signed=True))
|
||||
frame.extend(len(payload).to_bytes(4, "big"))
|
||||
frame.extend(payload)
|
||||
return bytes(frame)
|
||||
|
||||
@classmethod
|
||||
def _build_seed_header(
|
||||
cls,
|
||||
*,
|
||||
message_type: int,
|
||||
message_type_specific_flags: int,
|
||||
serial_method: int = _SEED_JSON,
|
||||
compression_type: int = _SEED_GZIP,
|
||||
reserved_data: int = 0x00,
|
||||
) -> bytes:
|
||||
header = bytearray()
|
||||
header.append((cls._SEED_PROTOCOL_VERSION << 4) | 0b0001)
|
||||
header.append((message_type << 4) | message_type_specific_flags)
|
||||
header.append((serial_method << 4) | compression_type)
|
||||
header.append(reserved_data)
|
||||
return bytes(header)
|
||||
|
||||
@classmethod
|
||||
def _parse_seed_response(cls, response: bytes) -> Dict[str, Any]:
|
||||
header_size = response[0] & 0x0F
|
||||
message_type = response[1] >> 4
|
||||
message_type_specific_flags = response[1] & 0x0F
|
||||
serialization_method = response[2] >> 4
|
||||
compression_type = response[2] & 0x0F
|
||||
payload = response[header_size * 4 :]
|
||||
|
||||
result: Dict[str, Any] = {"is_last_package": False}
|
||||
payload_message: Any = None
|
||||
|
||||
if message_type_specific_flags & 0x01:
|
||||
result["payload_sequence"] = int.from_bytes(payload[:4], "big", signed=True)
|
||||
payload = payload[4:]
|
||||
|
||||
if message_type_specific_flags & 0x02:
|
||||
result["is_last_package"] = True
|
||||
|
||||
if message_type == cls._SEED_FULL_SERVER_RESPONSE:
|
||||
result["payload_size"] = int.from_bytes(payload[:4], "big", signed=True)
|
||||
payload_message = payload[4:]
|
||||
elif message_type == cls._SEED_SERVER_ACK:
|
||||
result["seq"] = int.from_bytes(payload[:4], "big", signed=True)
|
||||
if len(payload) >= 8:
|
||||
result["payload_size"] = int.from_bytes(payload[4:8], "big", signed=False)
|
||||
payload_message = payload[8:]
|
||||
elif message_type == cls._SEED_SERVER_ERROR_RESPONSE:
|
||||
result["code"] = int.from_bytes(payload[:4], "big", signed=False)
|
||||
result["payload_size"] = int.from_bytes(payload[4:8], "big", signed=False)
|
||||
payload_message = payload[8:]
|
||||
|
||||
if payload_message is None:
|
||||
return result
|
||||
if compression_type == cls._SEED_GZIP:
|
||||
payload_message = gzip.decompress(payload_message)
|
||||
if serialization_method == cls._SEED_JSON:
|
||||
payload_message = json.loads(payload_message.decode("utf-8"))
|
||||
elif serialization_method != cls._SEED_NO_SERIALIZATION:
|
||||
payload_message = payload_message.decode("utf-8")
|
||||
|
||||
result["payload_msg"] = payload_message
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _coerce_event(message: Any) -> Dict[str, Any]:
|
||||
if isinstance(message, dict):
|
||||
return message
|
||||
if isinstance(message, str):
|
||||
try:
|
||||
loaded = json.loads(message)
|
||||
if isinstance(loaded, dict):
|
||||
return loaded
|
||||
except json.JSONDecodeError:
|
||||
return {"type": "raw", "message": message}
|
||||
return {"type": "raw", "message": str(message)}
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(payload: Dict[str, Any], keys: tuple[str, ...]) -> str:
|
||||
for key in keys:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
if isinstance(value, dict):
|
||||
for nested_key in ("message", "text", "transcript", "result", "delta"):
|
||||
nested = value.get(nested_key)
|
||||
if isinstance(nested, str) and nested.strip():
|
||||
return nested.strip()
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _coerce_bool(value: Any) -> Optional[bool]:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, (int, float)):
|
||||
return bool(value)
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip().lower()
|
||||
if normalized in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
if normalized in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _as_dict(value: Any) -> Dict[str, Any]:
|
||||
if isinstance(value, dict):
|
||||
return dict(value)
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _drain_queue(queue: "asyncio.Queue[Any]") -> None:
|
||||
while True:
|
||||
try:
|
||||
queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
@@ -16,13 +16,18 @@ from runtime.ports import (
|
||||
TTSServiceSpec,
|
||||
)
|
||||
from providers.asr.buffered import BufferedASRService
|
||||
from providers.asr.dashscope import DashScopeRealtimeASRService
|
||||
from providers.asr.volcengine import VolcengineRealtimeASRService
|
||||
from providers.tts.dashscope import DashScopeTTSService
|
||||
from providers.llm.openai import MockLLMService, OpenAILLMService
|
||||
from providers.asr.openai_compatible import OpenAICompatibleASRService
|
||||
from providers.tts.openai_compatible import OpenAICompatibleTTSService
|
||||
from providers.tts.mock import MockTTSService
|
||||
from providers.tts.volcengine import VolcengineTTSService
|
||||
|
||||
_OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"}
|
||||
_DASHSCOPE_PROVIDERS = {"dashscope"}
|
||||
_VOLCENGINE_PROVIDERS = {"volcengine"}
|
||||
_SUPPORTED_LLM_PROVIDERS = {"openai", *_OPENAI_COMPATIBLE_PROVIDERS}
|
||||
|
||||
|
||||
@@ -31,8 +36,14 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
|
||||
|
||||
_DEFAULT_DASHSCOPE_TTS_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
|
||||
_DEFAULT_DASHSCOPE_TTS_MODEL = "qwen3-tts-flash-realtime"
|
||||
_DEFAULT_DASHSCOPE_ASR_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
|
||||
_DEFAULT_DASHSCOPE_ASR_MODEL = "qwen3-asr-flash-realtime"
|
||||
_DEFAULT_OPENAI_COMPATIBLE_TTS_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
|
||||
_DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall"
|
||||
_DEFAULT_VOLCENGINE_TTS_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional"
|
||||
_DEFAULT_VOLCENGINE_TTS_RESOURCE_ID = "seed-tts-2.0"
|
||||
_DEFAULT_VOLCENGINE_ASR_REALTIME_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel"
|
||||
_DEFAULT_VOLCENGINE_ASR_MODEL = "bigmodel"
|
||||
|
||||
@staticmethod
|
||||
def _normalize_provider(provider: Any) -> str:
|
||||
@@ -77,6 +88,19 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
|
||||
speed=spec.speed,
|
||||
)
|
||||
|
||||
if provider in _VOLCENGINE_PROVIDERS and spec.api_key:
|
||||
return VolcengineTTSService(
|
||||
api_key=spec.api_key,
|
||||
api_url=spec.api_url or self._DEFAULT_VOLCENGINE_TTS_URL,
|
||||
voice=spec.voice,
|
||||
model=spec.model,
|
||||
app_id=spec.app_id,
|
||||
resource_id=spec.resource_id or self._DEFAULT_VOLCENGINE_TTS_RESOURCE_ID,
|
||||
uid=spec.uid,
|
||||
sample_rate=spec.sample_rate,
|
||||
speed=spec.speed,
|
||||
)
|
||||
|
||||
if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key:
|
||||
return OpenAICompatibleTTSService(
|
||||
api_key=spec.api_key,
|
||||
@@ -96,6 +120,30 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
|
||||
def create_asr_service(self, spec: ASRServiceSpec) -> ASRPort:
|
||||
provider = self._normalize_provider(spec.provider)
|
||||
|
||||
if provider in _DASHSCOPE_PROVIDERS and spec.api_key:
|
||||
return DashScopeRealtimeASRService(
|
||||
api_key=spec.api_key,
|
||||
api_url=spec.api_url or self._DEFAULT_DASHSCOPE_ASR_REALTIME_URL,
|
||||
model=spec.model or self._DEFAULT_DASHSCOPE_ASR_MODEL,
|
||||
sample_rate=spec.sample_rate,
|
||||
language=spec.language,
|
||||
on_transcript=spec.on_transcript,
|
||||
)
|
||||
|
||||
if provider in _VOLCENGINE_PROVIDERS and spec.api_key:
|
||||
return VolcengineRealtimeASRService(
|
||||
api_key=spec.api_key,
|
||||
api_url=spec.api_url or self._DEFAULT_VOLCENGINE_ASR_REALTIME_URL,
|
||||
model=spec.model or self._DEFAULT_VOLCENGINE_ASR_MODEL,
|
||||
sample_rate=spec.sample_rate,
|
||||
language=spec.language,
|
||||
app_id=spec.app_id,
|
||||
resource_id=spec.resource_id,
|
||||
uid=spec.uid,
|
||||
request_params=spec.request_params,
|
||||
on_transcript=spec.on_transcript,
|
||||
)
|
||||
|
||||
if provider in _OPENAI_COMPATIBLE_PROVIDERS and spec.api_key:
|
||||
return OpenAICompatibleASRService(
|
||||
api_key=spec.api_key,
|
||||
@@ -103,6 +151,7 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory):
|
||||
model=spec.model or self._DEFAULT_OPENAI_COMPATIBLE_ASR_MODEL,
|
||||
sample_rate=spec.sample_rate,
|
||||
language=spec.language,
|
||||
enable_interim=spec.enable_interim,
|
||||
interim_interval_ms=spec.interim_interval_ms,
|
||||
min_audio_for_interim_ms=spec.min_audio_for_interim_ms,
|
||||
on_transcript=spec.on_transcript,
|
||||
|
||||
@@ -1 +1,5 @@
|
||||
"""TTS providers."""
|
||||
|
||||
from providers.tts.volcengine import VolcengineTTSService
|
||||
|
||||
__all__ = ["VolcengineTTSService"]
|
||||
|
||||
219
engine/providers/tts/volcengine.py
Normal file
219
engine/providers/tts/volcengine.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""Volcengine TTS service.
|
||||
|
||||
Uses Volcengine's unidirectional HTTP streaming TTS API and adapts streamed
|
||||
base64 audio chunks into engine-native ``TTSChunk`` events.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import codecs
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, AsyncIterator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from providers.common.base import BaseTTSService, ServiceState, TTSChunk
|
||||
|
||||
|
||||
class VolcengineTTSService(BaseTTSService):
|
||||
"""Streaming TTS adapter for Volcengine's HTTP v3 API."""
|
||||
|
||||
DEFAULT_API_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional"
|
||||
DEFAULT_RESOURCE_ID = "seed-tts-2.0"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
voice: str = "zh_female_shuangkuaisisi_moon_bigtts",
|
||||
model: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
uid: Optional[str] = None,
|
||||
sample_rate: int = 16000,
|
||||
speed: float = 1.0,
|
||||
) -> None:
|
||||
super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
|
||||
self.api_key = api_key or os.getenv("VOLCENGINE_TTS_API_KEY") or os.getenv("TTS_API_KEY")
|
||||
self.api_url = api_url or os.getenv("VOLCENGINE_TTS_API_URL") or self.DEFAULT_API_URL
|
||||
self.model = str(model or os.getenv("VOLCENGINE_TTS_MODEL") or "").strip() or None
|
||||
self.app_id = app_id or os.getenv("VOLCENGINE_TTS_APP_ID") or os.getenv("TTS_APP_ID")
|
||||
self.resource_id = resource_id or os.getenv("VOLCENGINE_TTS_RESOURCE_ID") or self.DEFAULT_RESOURCE_ID
|
||||
self.uid = uid or os.getenv("VOLCENGINE_TTS_UID")
|
||||
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._cancel_event = asyncio.Event()
|
||||
self._synthesis_lock = asyncio.Lock()
|
||||
self._pending_audio: list[bytes] = []
|
||||
|
||||
async def connect(self) -> None:
|
||||
if not self.api_key:
|
||||
raise ValueError("Volcengine TTS API key not provided. Configure agent.tts.api_key in YAML.")
|
||||
if not self.app_id:
|
||||
raise ValueError("Volcengine TTS app_id not provided. Configure agent.tts.app_id in YAML.")
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=None, sock_read=None, sock_connect=15)
|
||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||
self.state = ServiceState.CONNECTED
|
||||
logger.info(
|
||||
"Volcengine TTS service ready: speaker={}, sample_rate={}, resource_id={}",
|
||||
self.voice,
|
||||
self.sample_rate,
|
||||
self.resource_id,
|
||||
)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._cancel_event.set()
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("Volcengine TTS service disconnected")
|
||||
|
||||
async def synthesize(self, text: str) -> bytes:
|
||||
audio = b""
|
||||
async for chunk in self.synthesize_stream(text):
|
||||
audio += chunk.audio
|
||||
return audio
|
||||
|
||||
async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
|
||||
if not self._session:
|
||||
raise RuntimeError("Volcengine TTS service not connected")
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
async with self._synthesis_lock:
|
||||
self._cancel_event.clear()
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-Api-App-Key": str(self.app_id),
|
||||
"X-Api-Access-Key": str(self.api_key),
|
||||
"X-Api-Resource-Id": str(self.resource_id),
|
||||
"X-Api-Request-Id": str(uuid.uuid4()),
|
||||
}
|
||||
payload = {
|
||||
"user": {
|
||||
"uid": str(self.uid or self.app_id),
|
||||
},
|
||||
"req_params": {
|
||||
"text": text,
|
||||
"speaker": self.voice,
|
||||
"audio_params": {
|
||||
"format": "pcm",
|
||||
"sample_rate": self.sample_rate,
|
||||
"speech_rate": self._speech_rate_percent(self.speed),
|
||||
},
|
||||
},
|
||||
}
|
||||
if self.model:
|
||||
payload["req_params"]["model"] = self.model
|
||||
|
||||
chunk_size = max(1, self.sample_rate * 2 // 10)
|
||||
audio_buffer = b""
|
||||
pending_chunk: Optional[bytes] = None
|
||||
|
||||
try:
|
||||
async with self._session.post(self.api_url, headers=headers, json=payload) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(f"Volcengine TTS error {response.status}: {error_text}")
|
||||
|
||||
async for audio_bytes in self._iter_audio_bytes(response):
|
||||
if self._cancel_event.is_set():
|
||||
logger.info("Volcengine TTS synthesis cancelled")
|
||||
return
|
||||
|
||||
audio_buffer += audio_bytes
|
||||
while len(audio_buffer) >= chunk_size:
|
||||
emitted = audio_buffer[:chunk_size]
|
||||
audio_buffer = audio_buffer[chunk_size:]
|
||||
if pending_chunk is not None:
|
||||
yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False)
|
||||
pending_chunk = emitted
|
||||
|
||||
if self._cancel_event.is_set():
|
||||
return
|
||||
|
||||
if pending_chunk is not None:
|
||||
if audio_buffer:
|
||||
yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=False)
|
||||
pending_chunk = None
|
||||
else:
|
||||
yield TTSChunk(audio=pending_chunk, sample_rate=self.sample_rate, is_final=True)
|
||||
pending_chunk = None
|
||||
|
||||
if audio_buffer:
|
||||
yield TTSChunk(audio=audio_buffer, sample_rate=self.sample_rate, is_final=True)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Volcengine TTS synthesis cancelled via asyncio")
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("Volcengine TTS synthesis error: {}", exc)
|
||||
raise
|
||||
|
||||
async def cancel(self) -> None:
|
||||
self._cancel_event.set()
|
||||
|
||||
async def _iter_audio_bytes(self, response: aiohttp.ClientResponse) -> AsyncIterator[bytes]:
|
||||
decoder = json.JSONDecoder()
|
||||
utf8_decoder = codecs.getincrementaldecoder("utf-8")()
|
||||
text_buffer = ""
|
||||
self._pending_audio.clear()
|
||||
|
||||
async for raw_chunk in response.content.iter_any():
|
||||
text_buffer += utf8_decoder.decode(raw_chunk)
|
||||
text_buffer = self._yield_audio_payloads(decoder, text_buffer)
|
||||
while self._pending_audio:
|
||||
yield self._pending_audio.pop(0)
|
||||
|
||||
text_buffer += utf8_decoder.decode(b"", final=True)
|
||||
text_buffer = self._yield_audio_payloads(decoder, text_buffer)
|
||||
while self._pending_audio:
|
||||
yield self._pending_audio.pop(0)
|
||||
|
||||
def _yield_audio_payloads(self, decoder: json.JSONDecoder, text_buffer: str) -> str:
|
||||
while True:
|
||||
stripped = text_buffer.lstrip()
|
||||
if not stripped:
|
||||
return ""
|
||||
if len(stripped) != len(text_buffer):
|
||||
text_buffer = stripped
|
||||
|
||||
try:
|
||||
payload, idx = decoder.raw_decode(text_buffer)
|
||||
except json.JSONDecodeError:
|
||||
return text_buffer
|
||||
|
||||
text_buffer = text_buffer[idx:]
|
||||
audio = self._extract_audio_bytes(payload)
|
||||
if audio:
|
||||
self._pending_audio.append(audio)
|
||||
|
||||
def _extract_audio_bytes(self, payload: Any) -> bytes:
|
||||
if not isinstance(payload, dict):
|
||||
return b""
|
||||
|
||||
code = payload.get("code")
|
||||
if code not in (None, 0, 20000000):
|
||||
message = str(payload.get("message") or "unknown error")
|
||||
raise RuntimeError(f"Volcengine TTS stream error {code}: {message}")
|
||||
|
||||
encoded = payload.get("data")
|
||||
if isinstance(encoded, str) and encoded.strip():
|
||||
try:
|
||||
return base64.b64decode(encoded)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to decode Volcengine TTS audio chunk: {}", exc)
|
||||
return b""
|
||||
|
||||
@staticmethod
|
||||
def _speech_rate_percent(speed: float) -> int:
|
||||
clamped = max(0.5, min(2.0, float(speed or 1.0)))
|
||||
return int(round((clamped - 1.0) * 100))
|
||||
@@ -30,11 +30,14 @@ from providers.factory.default import DefaultRealtimeServiceFactory
|
||||
from runtime.conversation import ConversationManager, ConversationState
|
||||
from runtime.events import get_event_bus
|
||||
from runtime.ports import (
|
||||
ASRMode,
|
||||
ASRPort,
|
||||
ASRServiceSpec,
|
||||
LLMPort,
|
||||
LLMServiceSpec,
|
||||
OfflineASRPort,
|
||||
RealtimeServiceFactory,
|
||||
StreamingASRPort,
|
||||
TTSPort,
|
||||
TTSServiceSpec,
|
||||
)
|
||||
@@ -77,6 +80,7 @@ class DuplexPipeline:
|
||||
_ASR_DELTA_THROTTLE_MS = 500
|
||||
_LLM_DELTA_THROTTLE_MS = 80
|
||||
_ASR_CAPTURE_MAX_MS = 15000
|
||||
_ASR_STREAM_FINAL_TIMEOUT_MS = 800
|
||||
_OPENER_PRE_ROLL_MS = 180
|
||||
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
|
||||
"current_time": {
|
||||
@@ -317,6 +321,10 @@ class DuplexPipeline:
|
||||
self.llm_service = llm_service
|
||||
self.tts_service = tts_service
|
||||
self.asr_service = asr_service # Will be initialized in start()
|
||||
self._asr_mode: ASRMode = self._resolve_asr_mode(
|
||||
settings.asr_provider,
|
||||
getattr(asr_service, "mode", None),
|
||||
)
|
||||
self._service_factory = service_factory or DefaultRealtimeServiceFactory()
|
||||
self._knowledge_searcher = knowledge_searcher
|
||||
self._tool_resource_resolver = tool_resource_resolver
|
||||
@@ -324,6 +332,7 @@ class DuplexPipeline:
|
||||
|
||||
# Track last sent transcript to avoid duplicates
|
||||
self._last_sent_transcript = ""
|
||||
self._latest_asr_interim_text = ""
|
||||
self._pending_transcript_delta: str = ""
|
||||
self._last_transcript_delta_emit_ms: float = 0.0
|
||||
|
||||
@@ -588,7 +597,9 @@ class DuplexPipeline:
|
||||
},
|
||||
"asr": {
|
||||
"provider": asr_provider,
|
||||
"mode": self._resolve_asr_mode(asr_provider, self._runtime_asr.get("mode")),
|
||||
"model": str(self._runtime_asr.get("model") or settings.asr_model or ""),
|
||||
"enableInterim": self._asr_interim_enabled(),
|
||||
"interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
|
||||
"minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
|
||||
},
|
||||
@@ -782,11 +793,44 @@ class DuplexPipeline:
|
||||
return False
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _coerce_json_object(value: Any) -> Optional[Dict[str, Any]]:
|
||||
if isinstance(value, dict):
|
||||
return dict(value)
|
||||
if isinstance(value, str):
|
||||
raw = value.strip()
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Ignoring invalid JSON object config: {}", raw[:120])
|
||||
return None
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_dashscope_tts_provider(provider: Any) -> bool:
|
||||
normalized = str(provider or "").strip().lower()
|
||||
return normalized == "dashscope"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_asr_mode(provider: Any, raw_mode: Any = None) -> ASRMode:
|
||||
normalized_mode = str(raw_mode or "").strip().lower()
|
||||
if normalized_mode in {"offline", "streaming"}:
|
||||
return normalized_mode # type: ignore[return-value]
|
||||
normalized_provider = str(provider or "").strip().lower()
|
||||
if normalized_provider in {"dashscope", "volcengine"}:
|
||||
return "streaming"
|
||||
return "offline"
|
||||
|
||||
def _offline_asr(self) -> OfflineASRPort:
|
||||
return self.asr_service # type: ignore[return-value]
|
||||
|
||||
def _streaming_asr(self) -> StreamingASRPort:
|
||||
return self.asr_service # type: ignore[return-value]
|
||||
|
||||
@staticmethod
|
||||
def _default_llm_base_url(provider: Any) -> Optional[str]:
|
||||
normalized = str(provider or "").strip().lower()
|
||||
@@ -839,6 +883,20 @@ class DuplexPipeline:
|
||||
return self._runtime_barge_in_min_duration_ms
|
||||
return self._barge_in_min_duration_ms
|
||||
|
||||
def _asr_interim_enabled(self) -> bool:
|
||||
current_mode = self._asr_mode
|
||||
if not self.asr_service:
|
||||
current_mode = self._resolve_asr_mode(
|
||||
self._runtime_asr.get("provider") or settings.asr_provider,
|
||||
self._runtime_asr.get("mode"),
|
||||
)
|
||||
if current_mode != "offline":
|
||||
return True
|
||||
enabled = self._coerce_bool(self._runtime_asr.get("enableInterim"))
|
||||
if enabled is not None:
|
||||
return enabled
|
||||
return bool(settings.asr_enable_interim)
|
||||
|
||||
def _barge_in_silence_tolerance_frames(self) -> int:
|
||||
"""Convert silence tolerance from ms to frame count using current chunk size."""
|
||||
chunk_ms = max(1, settings.chunk_size_ms)
|
||||
@@ -922,6 +980,10 @@ class DuplexPipeline:
|
||||
tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url
|
||||
tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
|
||||
tts_model = self._runtime_tts.get("model") or settings.tts_model
|
||||
tts_app_id = self._runtime_tts.get("appId") or settings.tts_app_id
|
||||
tts_resource_id = self._runtime_tts.get("resourceId") or settings.tts_resource_id
|
||||
tts_cluster = self._runtime_tts.get("cluster") or settings.tts_cluster
|
||||
tts_uid = self._runtime_tts.get("uid") or settings.tts_uid
|
||||
tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
|
||||
tts_mode = self._resolved_dashscope_tts_mode()
|
||||
runtime_mode = str(self._runtime_tts.get("mode") or "").strip()
|
||||
@@ -937,6 +999,10 @@ class DuplexPipeline:
|
||||
api_url=str(tts_api_url).strip() if tts_api_url else None,
|
||||
voice=str(tts_voice),
|
||||
model=str(tts_model).strip() if tts_model else None,
|
||||
app_id=str(tts_app_id).strip() if tts_app_id else None,
|
||||
resource_id=str(tts_resource_id).strip() if tts_resource_id else None,
|
||||
cluster=str(tts_cluster).strip() if tts_cluster else None,
|
||||
uid=str(tts_uid).strip() if tts_uid else None,
|
||||
sample_rate=settings.sample_rate,
|
||||
speed=tts_speed,
|
||||
mode=str(tts_mode),
|
||||
@@ -965,26 +1031,48 @@ class DuplexPipeline:
|
||||
asr_api_key = self._runtime_asr.get("apiKey")
|
||||
asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url
|
||||
asr_model = self._runtime_asr.get("model") or settings.asr_model
|
||||
asr_app_id = self._runtime_asr.get("appId") or settings.asr_app_id
|
||||
asr_resource_id = self._runtime_asr.get("resourceId") or settings.asr_resource_id
|
||||
asr_cluster = self._runtime_asr.get("cluster") or settings.asr_cluster
|
||||
asr_uid = self._runtime_asr.get("uid") or settings.asr_uid
|
||||
asr_request_params = self._coerce_json_object(self._runtime_asr.get("requestParams"))
|
||||
if asr_request_params is None:
|
||||
asr_request_params = self._coerce_json_object(settings.asr_request_params_json)
|
||||
asr_enable_interim = self._coerce_bool(self._runtime_asr.get("enableInterim"))
|
||||
if asr_enable_interim is None:
|
||||
asr_enable_interim = bool(settings.asr_enable_interim)
|
||||
asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
|
||||
asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
|
||||
asr_mode = self._resolve_asr_mode(asr_provider, self._runtime_asr.get("mode"))
|
||||
|
||||
self.asr_service = self._service_factory.create_asr_service(
|
||||
ASRServiceSpec(
|
||||
provider=asr_provider,
|
||||
sample_rate=settings.sample_rate,
|
||||
mode=asr_mode,
|
||||
language="auto",
|
||||
api_key=str(asr_api_key).strip() if asr_api_key else None,
|
||||
api_url=str(asr_api_url).strip() if asr_api_url else None,
|
||||
model=str(asr_model).strip() if asr_model else None,
|
||||
app_id=str(asr_app_id).strip() if asr_app_id else None,
|
||||
resource_id=str(asr_resource_id).strip() if asr_resource_id else None,
|
||||
cluster=str(asr_cluster).strip() if asr_cluster else None,
|
||||
uid=str(asr_uid).strip() if asr_uid else None,
|
||||
request_params=asr_request_params,
|
||||
enable_interim=asr_enable_interim,
|
||||
interim_interval_ms=asr_interim_interval,
|
||||
min_audio_for_interim_ms=asr_min_audio_ms,
|
||||
on_transcript=self._on_transcript_callback,
|
||||
)
|
||||
)
|
||||
self._asr_mode = self._resolve_asr_mode(
|
||||
self._runtime_asr.get("provider") or settings.asr_provider,
|
||||
getattr(self.asr_service, "mode", self._runtime_asr.get("mode")),
|
||||
)
|
||||
|
||||
await self.asr_service.connect()
|
||||
|
||||
logger.info("DuplexPipeline services connected")
|
||||
logger.info("DuplexPipeline services connected (asr_mode={})", self._asr_mode)
|
||||
if not self._outbound_task or self._outbound_task.done():
|
||||
self._outbound_task = asyncio.create_task(self._outbound_loop())
|
||||
|
||||
@@ -1449,6 +1537,9 @@ class DuplexPipeline:
|
||||
text: Transcribed text
|
||||
is_final: Whether this is the final transcription
|
||||
"""
|
||||
if not is_final and not self._asr_interim_enabled():
|
||||
return
|
||||
|
||||
# Avoid sending duplicate transcripts
|
||||
if text == self._last_sent_transcript and not is_final:
|
||||
return
|
||||
@@ -1457,6 +1548,7 @@ class DuplexPipeline:
|
||||
self._last_sent_transcript = text
|
||||
|
||||
if is_final:
|
||||
self._latest_asr_interim_text = ""
|
||||
self._pending_transcript_delta = ""
|
||||
self._last_transcript_delta_emit_ms = 0.0
|
||||
await self._send_event(
|
||||
@@ -1472,6 +1564,7 @@ class DuplexPipeline:
|
||||
logger.debug(f"Sent transcript (final): {text[:50]}...")
|
||||
return
|
||||
|
||||
self._latest_asr_interim_text = text
|
||||
self._pending_transcript_delta = text
|
||||
should_emit = (
|
||||
self._last_transcript_delta_emit_ms <= 0.0
|
||||
@@ -1495,14 +1588,16 @@ class DuplexPipeline:
|
||||
await self.conversation.start_user_turn()
|
||||
self._audio_buffer = b""
|
||||
self._last_sent_transcript = ""
|
||||
self._latest_asr_interim_text = ""
|
||||
self.eou_detector.reset()
|
||||
self._asr_capture_active = False
|
||||
self._asr_capture_started_ms = 0.0
|
||||
self._pending_speech_audio = b""
|
||||
|
||||
# Clear ASR buffer. Interim starts only after ASR capture is activated.
|
||||
if hasattr(self.asr_service, 'clear_buffer'):
|
||||
self.asr_service.clear_buffer()
|
||||
if self._asr_mode == "streaming":
|
||||
self._streaming_asr().clear_utterance()
|
||||
else:
|
||||
self._offline_asr().clear_buffer()
|
||||
|
||||
logger.debug("User speech started")
|
||||
|
||||
@@ -1511,8 +1606,11 @@ class DuplexPipeline:
|
||||
if self._asr_capture_active:
|
||||
return
|
||||
|
||||
if hasattr(self.asr_service, 'start_interim_transcription'):
|
||||
await self.asr_service.start_interim_transcription()
|
||||
if self._asr_mode == "streaming":
|
||||
await self._streaming_asr().begin_utterance()
|
||||
else:
|
||||
if self._asr_interim_enabled():
|
||||
await self._offline_asr().start_interim_transcription()
|
||||
|
||||
# Prime ASR with a short pre-speech context window so the utterance
|
||||
# start isn't lost while waiting for VAD to transition to Speech.
|
||||
@@ -1545,24 +1643,22 @@ class DuplexPipeline:
|
||||
self._pending_speech_audio = b""
|
||||
return
|
||||
|
||||
# Add a tiny trailing silence tail to stabilize final-token decoding.
|
||||
if self._asr_final_tail_bytes > 0:
|
||||
final_tail = b"\x00" * self._asr_final_tail_bytes
|
||||
await self.asr_service.send_audio(final_tail)
|
||||
|
||||
# Stop interim transcriptions
|
||||
if hasattr(self.asr_service, 'stop_interim_transcription'):
|
||||
await self.asr_service.stop_interim_transcription()
|
||||
|
||||
# Get final transcription from ASR service
|
||||
user_text = ""
|
||||
|
||||
if hasattr(self.asr_service, 'get_final_transcription'):
|
||||
# SiliconFlow ASR - get final transcription
|
||||
user_text = await self.asr_service.get_final_transcription()
|
||||
elif hasattr(self.asr_service, 'get_and_clear_text'):
|
||||
# Buffered ASR - get accumulated text
|
||||
user_text = self.asr_service.get_and_clear_text()
|
||||
if self._asr_mode == "streaming":
|
||||
streaming_asr = self._streaming_asr()
|
||||
await streaming_asr.end_utterance()
|
||||
user_text = await streaming_asr.wait_for_final_transcription(
|
||||
timeout_ms=self._ASR_STREAM_FINAL_TIMEOUT_MS
|
||||
)
|
||||
if not user_text.strip():
|
||||
user_text = self._latest_asr_interim_text
|
||||
else:
|
||||
# Add a tiny trailing silence tail to stabilize final-token decoding.
|
||||
if self._asr_final_tail_bytes > 0:
|
||||
final_tail = b"\x00" * self._asr_final_tail_bytes
|
||||
await self.asr_service.send_audio(final_tail)
|
||||
await self._offline_asr().stop_interim_transcription()
|
||||
user_text = await self._offline_asr().get_final_transcription()
|
||||
|
||||
# Skip if no meaningful text
|
||||
if not user_text or not user_text.strip():
|
||||
@@ -1570,6 +1666,7 @@ class DuplexPipeline:
|
||||
# Reset for next utterance
|
||||
self._audio_buffer = b""
|
||||
self._last_sent_transcript = ""
|
||||
self._latest_asr_interim_text = ""
|
||||
self._asr_capture_active = False
|
||||
self._asr_capture_started_ms = 0.0
|
||||
self._pending_speech_audio = b""
|
||||
@@ -1594,6 +1691,7 @@ class DuplexPipeline:
|
||||
# Clear buffers
|
||||
self._audio_buffer = b""
|
||||
self._last_sent_transcript = ""
|
||||
self._latest_asr_interim_text = ""
|
||||
self._pending_transcript_delta = ""
|
||||
self._last_transcript_delta_emit_ms = 0.0
|
||||
self._asr_capture_active = False
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
"""Port interfaces for runtime integration boundaries."""
|
||||
|
||||
from runtime.ports.asr import ASRBufferControl, ASRInterimControl, ASRPort, ASRServiceSpec
|
||||
from runtime.ports.asr import (
|
||||
ASRMode,
|
||||
ASRPort,
|
||||
ASRServiceSpec,
|
||||
OfflineASRPort,
|
||||
StreamingASRPort,
|
||||
)
|
||||
from runtime.ports.control_plane import (
|
||||
AssistantRuntimeConfigProvider,
|
||||
ControlPlaneGateway,
|
||||
@@ -13,10 +19,11 @@ from runtime.ports.service_factory import RealtimeServiceFactory
|
||||
from runtime.ports.tts import TTSPort, TTSServiceSpec
|
||||
|
||||
__all__ = [
|
||||
"ASRMode",
|
||||
"ASRPort",
|
||||
"ASRServiceSpec",
|
||||
"ASRInterimControl",
|
||||
"ASRBufferControl",
|
||||
"OfflineASRPort",
|
||||
"StreamingASRPort",
|
||||
"AssistantRuntimeConfigProvider",
|
||||
"ControlPlaneGateway",
|
||||
"ConversationHistoryStore",
|
||||
|
||||
@@ -3,11 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncIterator, Awaitable, Callable, Optional, Protocol
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Literal, Optional, Protocol
|
||||
|
||||
from providers.common.base import ASRResult
|
||||
|
||||
TranscriptCallback = Callable[[str, bool], Awaitable[None]]
|
||||
ASRMode = Literal["offline", "streaming"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -16,10 +17,17 @@ class ASRServiceSpec:
|
||||
|
||||
provider: str
|
||||
sample_rate: int
|
||||
mode: Optional[ASRMode] = None
|
||||
language: str = "auto"
|
||||
api_key: Optional[str] = None
|
||||
api_url: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
app_id: Optional[str] = None
|
||||
resource_id: Optional[str] = None
|
||||
cluster: Optional[str] = None
|
||||
uid: Optional[str] = None
|
||||
request_params: Optional[Dict[str, Any]] = None
|
||||
enable_interim: bool = False
|
||||
interim_interval_ms: int = 500
|
||||
min_audio_for_interim_ms: int = 300
|
||||
on_transcript: Optional[TranscriptCallback] = None
|
||||
@@ -28,6 +36,8 @@ class ASRServiceSpec:
|
||||
class ASRPort(Protocol):
|
||||
"""Port for speech recognition providers."""
|
||||
|
||||
mode: ASRMode
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish connection to ASR provider."""
|
||||
|
||||
@@ -41,18 +51,16 @@ class ASRPort(Protocol):
|
||||
"""Stream partial/final recognition results."""
|
||||
|
||||
|
||||
class ASRInterimControl(Protocol):
|
||||
"""Optional extension for explicit interim transcription control."""
|
||||
class OfflineASRPort(ASRPort, Protocol):
|
||||
"""Port for offline/buffered ASR providers."""
|
||||
|
||||
mode: Literal["offline"]
|
||||
|
||||
async def start_interim_transcription(self) -> None:
|
||||
"""Start interim transcription loop if supported."""
|
||||
"""Start interim transcription loop."""
|
||||
|
||||
async def stop_interim_transcription(self) -> None:
|
||||
"""Stop interim transcription loop if supported."""
|
||||
|
||||
|
||||
class ASRBufferControl(Protocol):
|
||||
"""Optional extension for explicit ASR buffer lifecycle control."""
|
||||
"""Stop interim transcription loop."""
|
||||
|
||||
def clear_buffer(self) -> None:
|
||||
"""Clear provider-side ASR buffer."""
|
||||
@@ -62,3 +70,21 @@ class ASRBufferControl(Protocol):
|
||||
|
||||
def get_and_clear_text(self) -> str:
|
||||
"""Return buffered text and clear internal state."""
|
||||
|
||||
|
||||
class StreamingASRPort(ASRPort, Protocol):
|
||||
"""Port for streaming ASR providers."""
|
||||
|
||||
mode: Literal["streaming"]
|
||||
|
||||
async def begin_utterance(self) -> None:
|
||||
"""Start a new utterance stream."""
|
||||
|
||||
async def end_utterance(self) -> None:
|
||||
"""Signal end of current utterance stream."""
|
||||
|
||||
async def wait_for_final_transcription(self, timeout_ms: int = 800) -> str:
|
||||
"""Wait for final transcript after utterance end."""
|
||||
|
||||
def clear_utterance(self) -> None:
|
||||
"""Reset utterance-local state."""
|
||||
|
||||
@@ -19,6 +19,10 @@ class TTSServiceSpec:
|
||||
api_key: Optional[str] = None
|
||||
api_url: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
app_id: Optional[str] = None
|
||||
resource_id: Optional[str] = None
|
||||
cluster: Optional[str] = None
|
||||
uid: Optional[str] = None
|
||||
mode: str = "commit"
|
||||
|
||||
|
||||
|
||||
71
engine/tests/test_asr_factory_modes.py
Normal file
71
engine/tests/test_asr_factory_modes.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from providers.asr.buffered import BufferedASRService
|
||||
from providers.asr.dashscope import DashScopeRealtimeASRService
|
||||
from providers.asr.openai_compatible import OpenAICompatibleASRService
|
||||
from providers.asr.volcengine import VolcengineRealtimeASRService
|
||||
from providers.factory.default import DefaultRealtimeServiceFactory
|
||||
from runtime.ports import ASRServiceSpec
|
||||
|
||||
|
||||
def test_create_asr_service_dashscope_returns_streaming_provider():
|
||||
factory = DefaultRealtimeServiceFactory()
|
||||
service = factory.create_asr_service(
|
||||
ASRServiceSpec(
|
||||
provider="dashscope",
|
||||
mode="streaming",
|
||||
sample_rate=16000,
|
||||
api_key="test-key",
|
||||
model="qwen3-asr-flash-realtime",
|
||||
)
|
||||
)
|
||||
assert isinstance(service, DashScopeRealtimeASRService)
|
||||
assert service.mode == "streaming"
|
||||
|
||||
|
||||
def test_create_asr_service_openai_compatible_returns_offline_provider():
|
||||
factory = DefaultRealtimeServiceFactory()
|
||||
service = factory.create_asr_service(
|
||||
ASRServiceSpec(
|
||||
provider="openai_compatible",
|
||||
sample_rate=16000,
|
||||
api_key="test-key",
|
||||
model="FunAudioLLM/SenseVoiceSmall",
|
||||
)
|
||||
)
|
||||
assert isinstance(service, OpenAICompatibleASRService)
|
||||
assert service.mode == "offline"
|
||||
assert service.enable_interim is False
|
||||
|
||||
|
||||
def test_create_asr_service_volcengine_returns_streaming_provider():
|
||||
factory = DefaultRealtimeServiceFactory()
|
||||
service = factory.create_asr_service(
|
||||
ASRServiceSpec(
|
||||
provider="volcengine",
|
||||
mode="streaming",
|
||||
sample_rate=16000,
|
||||
api_key="test-key",
|
||||
api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
|
||||
model="bigmodel",
|
||||
app_id="app-1",
|
||||
uid="caller-1",
|
||||
request_params={"end_window_size": 800},
|
||||
)
|
||||
)
|
||||
assert isinstance(service, VolcengineRealtimeASRService)
|
||||
assert service.mode == "streaming"
|
||||
assert service.protocol == "seed"
|
||||
assert service.app_id == "app-1"
|
||||
assert service.uid == "caller-1"
|
||||
assert service.request_params["end_window_size"] == 800
|
||||
|
||||
|
||||
def test_create_asr_service_fallback_buffered_for_unsupported_provider():
|
||||
factory = DefaultRealtimeServiceFactory()
|
||||
service = factory.create_asr_service(
|
||||
ASRServiceSpec(
|
||||
provider="unknown_provider",
|
||||
sample_rate=16000,
|
||||
)
|
||||
)
|
||||
assert isinstance(service, BufferedASRService)
|
||||
assert service.mode == "offline"
|
||||
@@ -227,6 +227,62 @@ async def test_with_backend_url_uses_backend_for_assistant_config(monkeypatch, t
|
||||
assert payload["assistant"]["systemPrompt"] == "backend prompt"
|
||||
|
||||
|
||||
def test_translate_agent_schema_maps_volcengine_fields():
|
||||
payload = {
|
||||
"agent": {
|
||||
"tts": {
|
||||
"provider": "volcengine",
|
||||
"api_key": "tts-key",
|
||||
"api_url": "https://openspeech.bytedance.com/api/v3/tts/unidirectional",
|
||||
"app_id": "app-123",
|
||||
"resource_id": "seed-tts-2.0",
|
||||
"uid": "caller-1",
|
||||
"voice": "zh_female_shuangkuaisisi_moon_bigtts",
|
||||
"speed": 1.1,
|
||||
},
|
||||
"asr": {
|
||||
"provider": "volcengine",
|
||||
"api_key": "asr-key",
|
||||
"api_url": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
|
||||
"model": "bigmodel",
|
||||
"app_id": "app-123",
|
||||
"resource_id": "volc.bigasr.sauc.duration",
|
||||
"uid": "caller-1",
|
||||
"request_params": {
|
||||
"end_window_size": 800,
|
||||
"force_to_speech_time": 1000,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
translated = LocalYamlAssistantConfigAdapter._translate_agent_schema("assistant_demo", payload)
|
||||
assert translated is not None
|
||||
assert translated["services"]["tts"] == {
|
||||
"provider": "volcengine",
|
||||
"apiKey": "tts-key",
|
||||
"baseUrl": "https://openspeech.bytedance.com/api/v3/tts/unidirectional",
|
||||
"voice": "zh_female_shuangkuaisisi_moon_bigtts",
|
||||
"appId": "app-123",
|
||||
"resourceId": "seed-tts-2.0",
|
||||
"uid": "caller-1",
|
||||
"speed": 1.1,
|
||||
}
|
||||
assert translated["services"]["asr"] == {
|
||||
"provider": "volcengine",
|
||||
"model": "bigmodel",
|
||||
"apiKey": "asr-key",
|
||||
"baseUrl": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
|
||||
"appId": "app-123",
|
||||
"resourceId": "volc.bigasr.sauc.duration",
|
||||
"uid": "caller-1",
|
||||
"requestParams": {
|
||||
"end_window_size": 800,
|
||||
"force_to_speech_time": 1000,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backend_mode_disabled_uses_local_assistant_config_even_with_url(monkeypatch, tmp_path):
|
||||
class _FailIfCalledClientSession:
|
||||
@@ -282,7 +338,7 @@ async def test_local_yaml_adapter_rejects_path_traversal_like_assistant_id(tmp_p
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_yaml_translates_agent_schema_to_runtime_services(tmp_path):
|
||||
async def test_local_yaml_translates_agent_schema_with_asr_interim_flag(tmp_path):
|
||||
config_dir = tmp_path / "assistants"
|
||||
config_dir.mkdir(parents=True, exist_ok=True)
|
||||
(config_dir / "default.yaml").write_text(
|
||||
@@ -305,6 +361,7 @@ async def test_local_yaml_translates_agent_schema_to_runtime_services(tmp_path):
|
||||
" model: asr-model",
|
||||
" api_key: sk-asr",
|
||||
" api_url: https://asr.example.com/v1/audio/transcriptions",
|
||||
" enable_interim: false",
|
||||
" duplex:",
|
||||
" system_prompt: You are test assistant",
|
||||
]
|
||||
@@ -321,4 +378,5 @@ async def test_local_yaml_translates_agent_schema_to_runtime_services(tmp_path):
|
||||
assert services.get("llm", {}).get("apiKey") == "sk-llm"
|
||||
assert services.get("tts", {}).get("apiKey") == "sk-tts"
|
||||
assert services.get("asr", {}).get("apiKey") == "sk-asr"
|
||||
assert services.get("asr", {}).get("enableInterim") is False
|
||||
assert assistant.get("systemPrompt") == "You are test assistant"
|
||||
|
||||
67
engine/tests/test_dashscope_asr_provider.py
Normal file
67
engine/tests/test_dashscope_asr_provider.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.asr.dashscope import DashScopeRealtimeASRService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dashscope_asr_interim_event_emits_interim_transcript():
|
||||
received = []
|
||||
|
||||
async def _on_transcript(text: str, is_final: bool) -> None:
|
||||
received.append((text, is_final))
|
||||
|
||||
service = DashScopeRealtimeASRService(api_key="test-key", on_transcript=_on_transcript)
|
||||
service._loop = asyncio.get_running_loop()
|
||||
service._running = True
|
||||
|
||||
service._on_ws_event(
|
||||
{
|
||||
"type": "conversation.item.input_audio_transcription.text",
|
||||
"stash": "你好世界",
|
||||
}
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
result = service._transcript_queue.get_nowait()
|
||||
assert result.text == "你好世界"
|
||||
assert result.is_final is False
|
||||
assert received == [("你好世界", False)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dashscope_asr_final_event_emits_final_transcript_and_final_queue():
|
||||
received = []
|
||||
|
||||
async def _on_transcript(text: str, is_final: bool) -> None:
|
||||
received.append((text, is_final))
|
||||
|
||||
service = DashScopeRealtimeASRService(api_key="test-key", on_transcript=_on_transcript)
|
||||
service._loop = asyncio.get_running_loop()
|
||||
service._running = True
|
||||
service._audio_sent_in_utterance = True
|
||||
|
||||
service._on_ws_event(
|
||||
{
|
||||
"type": "conversation.item.input_audio_transcription.completed",
|
||||
"transcript": "最终识别结果",
|
||||
}
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
result = service._transcript_queue.get_nowait()
|
||||
assert result.text == "最终识别结果"
|
||||
assert result.is_final is True
|
||||
assert service._final_queue.get_nowait() == "最终识别结果"
|
||||
assert received == [("最终识别结果", True)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dashscope_wait_for_final_falls_back_to_latest_interim_on_timeout():
|
||||
service = DashScopeRealtimeASRService(api_key="test-key")
|
||||
service._audio_sent_in_utterance = True
|
||||
service._last_interim_text = "部分结果"
|
||||
|
||||
text = await service.wait_for_final_transcription(timeout_ms=10)
|
||||
assert text == "部分结果"
|
||||
260
engine/tests/test_duplex_asr_modes.py
Normal file
260
engine/tests/test_duplex_asr_modes.py
Normal file
@@ -0,0 +1,260 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from runtime.pipeline.duplex import DuplexPipeline
|
||||
|
||||
|
||||
class _DummySileroVAD:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def process_audio(self, _pcm: bytes) -> float:
|
||||
return 0.0
|
||||
|
||||
|
||||
class _DummyVADProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def process(self, _speech_prob: float):
|
||||
return "Silence", 0.0
|
||||
|
||||
|
||||
class _DummyEouDetector:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.is_speaking = True
|
||||
|
||||
def process(self, _vad_status: str, force_eligible: bool = False) -> bool:
|
||||
_ = force_eligible
|
||||
return False
|
||||
|
||||
def reset(self) -> None:
|
||||
self.is_speaking = False
|
||||
|
||||
|
||||
class _FakeTransport:
|
||||
async def send_event(self, _event: Dict[str, Any]) -> None:
|
||||
return None
|
||||
|
||||
async def send_audio(self, _audio: bytes) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _FakeStreamingASR:
|
||||
mode = "streaming"
|
||||
|
||||
def __init__(self):
|
||||
self.begin_calls = 0
|
||||
self.end_calls = 0
|
||||
self.wait_calls = 0
|
||||
self.sent_audio: List[bytes] = []
|
||||
self.wait_text = ""
|
||||
|
||||
async def connect(self) -> None:
|
||||
return None
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
return None
|
||||
|
||||
async def send_audio(self, audio: bytes) -> None:
|
||||
self.sent_audio.append(audio)
|
||||
|
||||
async def receive_transcripts(self):
|
||||
if False:
|
||||
yield None
|
||||
|
||||
async def begin_utterance(self) -> None:
|
||||
self.begin_calls += 1
|
||||
|
||||
async def end_utterance(self) -> None:
|
||||
self.end_calls += 1
|
||||
|
||||
async def wait_for_final_transcription(self, timeout_ms: int = 800) -> str:
|
||||
_ = timeout_ms
|
||||
self.wait_calls += 1
|
||||
return self.wait_text
|
||||
|
||||
def clear_utterance(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _FakeOfflineASR:
|
||||
mode = "offline"
|
||||
|
||||
def __init__(self):
|
||||
self.start_interim_calls = 0
|
||||
self.stop_interim_calls = 0
|
||||
self.sent_audio: List[bytes] = []
|
||||
self.final_text = "offline final"
|
||||
|
||||
async def connect(self) -> None:
|
||||
return None
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
return None
|
||||
|
||||
async def send_audio(self, audio: bytes) -> None:
|
||||
self.sent_audio.append(audio)
|
||||
|
||||
async def receive_transcripts(self):
|
||||
if False:
|
||||
yield None
|
||||
|
||||
async def start_interim_transcription(self) -> None:
|
||||
self.start_interim_calls += 1
|
||||
|
||||
async def stop_interim_transcription(self) -> None:
|
||||
self.stop_interim_calls += 1
|
||||
|
||||
async def get_final_transcription(self) -> str:
|
||||
return self.final_text
|
||||
|
||||
def clear_buffer(self) -> None:
|
||||
return None
|
||||
|
||||
def get_and_clear_text(self) -> str:
|
||||
return self.final_text
|
||||
|
||||
|
||||
def _build_pipeline(monkeypatch, asr_service):
|
||||
monkeypatch.setattr("runtime.pipeline.duplex.SileroVAD", _DummySileroVAD)
|
||||
monkeypatch.setattr("runtime.pipeline.duplex.VADProcessor", _DummyVADProcessor)
|
||||
monkeypatch.setattr("runtime.pipeline.duplex.EouDetector", _DummyEouDetector)
|
||||
return DuplexPipeline(
|
||||
transport=_FakeTransport(),
|
||||
session_id="asr_mode_test",
|
||||
asr_service=asr_service,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_asr_capture_uses_streaming_begin(monkeypatch):
|
||||
asr = _FakeStreamingASR()
|
||||
pipeline = _build_pipeline(monkeypatch, asr)
|
||||
pipeline._asr_mode = "streaming"
|
||||
pipeline._pending_speech_audio = b"\x00" * 320
|
||||
pipeline._pre_speech_buffer = b"\x00" * 640
|
||||
|
||||
await pipeline._start_asr_capture()
|
||||
|
||||
assert asr.begin_calls == 1
|
||||
assert asr.sent_audio
|
||||
assert pipeline._asr_capture_active is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_asr_capture_uses_offline_interim_control_when_enabled(monkeypatch):
|
||||
asr = _FakeOfflineASR()
|
||||
pipeline = _build_pipeline(monkeypatch, asr)
|
||||
pipeline._asr_mode = "offline"
|
||||
pipeline._runtime_asr["enableInterim"] = True
|
||||
pipeline._pending_speech_audio = b"\x00" * 320
|
||||
pipeline._pre_speech_buffer = b"\x00" * 640
|
||||
|
||||
await pipeline._start_asr_capture()
|
||||
|
||||
assert asr.start_interim_calls == 1
|
||||
assert asr.sent_audio
|
||||
assert pipeline._asr_capture_active is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_asr_capture_skips_offline_interim_control_when_disabled(monkeypatch):
|
||||
asr = _FakeOfflineASR()
|
||||
pipeline = _build_pipeline(monkeypatch, asr)
|
||||
pipeline._asr_mode = "offline"
|
||||
pipeline._runtime_asr["enableInterim"] = False
|
||||
pipeline._pending_speech_audio = b"\x00" * 320
|
||||
pipeline._pre_speech_buffer = b"\x00" * 640
|
||||
|
||||
await pipeline._start_asr_capture()
|
||||
|
||||
assert asr.start_interim_calls == 0
|
||||
assert asr.sent_audio
|
||||
assert pipeline._asr_capture_active is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_offline_interim_callback_ignored_when_disabled(monkeypatch):
|
||||
asr = _FakeOfflineASR()
|
||||
pipeline = _build_pipeline(monkeypatch, asr)
|
||||
pipeline._asr_mode = "offline"
|
||||
pipeline._runtime_asr["enableInterim"] = False
|
||||
|
||||
captured_events = []
|
||||
captured_deltas = []
|
||||
|
||||
async def _capture_event(event: Dict[str, Any], priority: int = 20):
|
||||
_ = priority
|
||||
captured_events.append(event)
|
||||
|
||||
async def _capture_delta(text: str):
|
||||
captured_deltas.append(text)
|
||||
|
||||
monkeypatch.setattr(pipeline, "_send_event", _capture_event)
|
||||
monkeypatch.setattr(pipeline, "_emit_transcript_delta", _capture_delta)
|
||||
|
||||
await pipeline._on_transcript_callback("ignored interim", is_final=False)
|
||||
|
||||
assert captured_events == []
|
||||
assert captured_deltas == []
|
||||
assert pipeline._latest_asr_interim_text == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_offline_final_callback_emits_when_interim_disabled(monkeypatch):
|
||||
asr = _FakeOfflineASR()
|
||||
pipeline = _build_pipeline(monkeypatch, asr)
|
||||
pipeline._asr_mode = "offline"
|
||||
pipeline._runtime_asr["enableInterim"] = False
|
||||
|
||||
captured_events = []
|
||||
|
||||
async def _capture_event(event: Dict[str, Any], priority: int = 20):
|
||||
_ = priority
|
||||
captured_events.append(event)
|
||||
|
||||
monkeypatch.setattr(pipeline, "_send_event", _capture_event)
|
||||
|
||||
await pipeline._on_transcript_callback("final only", is_final=True)
|
||||
|
||||
assert any(event.get("type") == "transcript.final" for event in captured_events)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_eou_falls_back_to_latest_interim(monkeypatch):
|
||||
asr = _FakeStreamingASR()
|
||||
asr.wait_text = ""
|
||||
pipeline = _build_pipeline(monkeypatch, asr)
|
||||
pipeline._asr_mode = "streaming"
|
||||
pipeline._asr_capture_active = True
|
||||
pipeline._latest_asr_interim_text = "fallback interim text"
|
||||
await pipeline.conversation.start_user_turn()
|
||||
|
||||
captured_events = []
|
||||
captured_turns = []
|
||||
|
||||
async def _capture_event(event: Dict[str, Any], priority: int = 20):
|
||||
_ = priority
|
||||
captured_events.append(event)
|
||||
|
||||
async def _noop_stop_current_speech() -> None:
|
||||
return None
|
||||
|
||||
async def _capture_turn(user_text: str, *args, **kwargs) -> None:
|
||||
_ = (args, kwargs)
|
||||
captured_turns.append(user_text)
|
||||
|
||||
monkeypatch.setattr(pipeline, "_send_event", _capture_event)
|
||||
monkeypatch.setattr(pipeline, "_stop_current_speech", _noop_stop_current_speech)
|
||||
monkeypatch.setattr(pipeline, "_handle_turn", _capture_turn)
|
||||
|
||||
await pipeline._on_end_of_utterance()
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert asr.end_calls == 1
|
||||
assert asr.wait_calls == 1
|
||||
assert captured_turns == ["fallback interim text"]
|
||||
assert any(event.get("type") == "transcript.final" for event in captured_events)
|
||||
@@ -52,9 +52,33 @@ class _FakeTTS:
|
||||
|
||||
|
||||
class _FakeASR:
|
||||
mode = "offline"
|
||||
|
||||
async def connect(self) -> None:
|
||||
return None
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
return None
|
||||
|
||||
async def send_audio(self, _audio: bytes) -> None:
|
||||
return None
|
||||
|
||||
async def receive_transcripts(self):
|
||||
if False:
|
||||
yield None
|
||||
|
||||
def clear_buffer(self) -> None:
|
||||
return None
|
||||
|
||||
async def start_interim_transcription(self) -> None:
|
||||
return None
|
||||
|
||||
async def stop_interim_transcription(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_final_transcription(self) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
class _FakeLLM:
|
||||
def __init__(self, rounds: List[List[LLMStreamEvent]]):
|
||||
|
||||
45
engine/tests/test_tts_factory_modes.py
Normal file
45
engine/tests/test_tts_factory_modes.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from providers.factory.default import DefaultRealtimeServiceFactory
|
||||
from providers.tts.mock import MockTTSService
|
||||
from providers.tts.openai_compatible import OpenAICompatibleTTSService
|
||||
from providers.tts.volcengine import VolcengineTTSService
|
||||
from runtime.ports import TTSServiceSpec
|
||||
|
||||
|
||||
def test_create_tts_service_volcengine_returns_native_provider():
|
||||
factory = DefaultRealtimeServiceFactory()
|
||||
service = factory.create_tts_service(
|
||||
TTSServiceSpec(
|
||||
provider="volcengine",
|
||||
api_key="test-key",
|
||||
app_id="app-1",
|
||||
resource_id="seed-tts-2.0",
|
||||
voice="zh_female_shuangkuaisisi_moon_bigtts",
|
||||
sample_rate=16000,
|
||||
)
|
||||
)
|
||||
assert isinstance(service, VolcengineTTSService)
|
||||
|
||||
|
||||
def test_create_tts_service_openai_compatible_returns_provider():
|
||||
factory = DefaultRealtimeServiceFactory()
|
||||
service = factory.create_tts_service(
|
||||
TTSServiceSpec(
|
||||
provider="openai_compatible",
|
||||
api_key="test-key",
|
||||
voice="anna",
|
||||
sample_rate=16000,
|
||||
)
|
||||
)
|
||||
assert isinstance(service, OpenAICompatibleTTSService)
|
||||
|
||||
|
||||
def test_create_tts_service_fallbacks_to_mock_without_key():
|
||||
factory = DefaultRealtimeServiceFactory()
|
||||
service = factory.create_tts_service(
|
||||
TTSServiceSpec(
|
||||
provider="volcengine",
|
||||
voice="anna",
|
||||
sample_rate=16000,
|
||||
)
|
||||
)
|
||||
assert isinstance(service, MockTTSService)
|
||||
86
engine/tests/test_volcengine_asr_provider.py
Normal file
86
engine/tests/test_volcengine_asr_provider.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import gzip
|
||||
import json
|
||||
|
||||
from providers.asr.volcengine import VolcengineRealtimeASRService
|
||||
|
||||
|
||||
def test_volcengine_seed_protocol_defaults_and_headers():
|
||||
service = VolcengineRealtimeASRService(
|
||||
api_key="access-token",
|
||||
api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
|
||||
app_id="app-1",
|
||||
uid="caller-1",
|
||||
)
|
||||
|
||||
assert service.protocol == "seed"
|
||||
assert service.resource_id == "volc.bigasr.sauc.duration"
|
||||
|
||||
headers = service._build_seed_headers("req-1")
|
||||
assert headers == {
|
||||
"X-Api-App-Key": "app-1",
|
||||
"X-Api-Access-Key": "access-token",
|
||||
"X-Api-Resource-Id": "volc.bigasr.sauc.duration",
|
||||
"X-Api-Request-Id": "req-1",
|
||||
}
|
||||
|
||||
|
||||
def test_volcengine_seed_start_payload_merges_request_params():
|
||||
service = VolcengineRealtimeASRService(
|
||||
api_key="access-token",
|
||||
api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
|
||||
app_id="app-1",
|
||||
uid="caller-1",
|
||||
language="zh-CN",
|
||||
request_params={
|
||||
"request": {
|
||||
"end_window_size": 800,
|
||||
"force_to_speech_time": 1000,
|
||||
"context": "{\"hotwords\":[{\"word\":\"doubao\"}]}",
|
||||
},
|
||||
"audio": {"codec": "raw"},
|
||||
},
|
||||
)
|
||||
|
||||
payload = service._build_seed_start_payload()
|
||||
assert payload["user"] == {"uid": "caller-1"}
|
||||
assert payload["audio"] == {
|
||||
"format": "pcm",
|
||||
"rate": 16000,
|
||||
"bits": 16,
|
||||
"channels": 1,
|
||||
"codec": "raw",
|
||||
"language": "zh-CN",
|
||||
}
|
||||
assert payload["request"]["model_name"] == "bigmodel"
|
||||
assert payload["request"]["end_window_size"] == 800
|
||||
assert payload["request"]["force_to_speech_time"] == 1000
|
||||
assert payload["request"]["context"] == "{\"hotwords\":[{\"word\":\"doubao\"}]}"
|
||||
|
||||
|
||||
def test_volcengine_seed_start_request_encodes_gzip_json_payload():
|
||||
service = VolcengineRealtimeASRService(
|
||||
api_key="access-token",
|
||||
api_url="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
|
||||
app_id="app-1",
|
||||
uid="caller-1",
|
||||
)
|
||||
|
||||
frame = service._build_seed_start_request()
|
||||
assert frame[0] == 0x11
|
||||
assert frame[1] == 0x11
|
||||
|
||||
payload_length = int.from_bytes(frame[8:12], "big")
|
||||
payload = json.loads(gzip.decompress(frame[12 : 12 + payload_length]).decode("utf-8"))
|
||||
assert payload["user"]["uid"] == "caller-1"
|
||||
assert payload["request"]["model_name"] == "bigmodel"
|
||||
|
||||
|
||||
def test_volcengine_gateway_protocol_keeps_model_query():
|
||||
service = VolcengineRealtimeASRService(
|
||||
api_key="access-token",
|
||||
api_url="wss://ai-gateway.vei.volces.com/v1/realtime",
|
||||
model="bigmodel",
|
||||
)
|
||||
|
||||
assert service.protocol == "gateway"
|
||||
assert service.api_url == "wss://ai-gateway.vei.volces.com/v1/realtime?model=bigmodel"
|
||||
@@ -259,6 +259,7 @@ export const AssistantsPage: React.FC = () => {
|
||||
speed: 1,
|
||||
hotwords: [],
|
||||
tools: [],
|
||||
asrInterimEnabled: false,
|
||||
botCannotBeInterrupted: false,
|
||||
interruptionSensitivity: 180,
|
||||
configMode: 'platform',
|
||||
@@ -1358,6 +1359,41 @@ export const AssistantsPage: React.FC = () => {
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-3">
|
||||
<div className="flex items-center justify-between gap-3">
|
||||
<label className="text-sm font-medium text-white flex items-center">
|
||||
<Mic className="w-4 h-4 mr-2 text-primary"/> 离线 ASR 中间结果
|
||||
</label>
|
||||
<div className="inline-flex rounded-lg border border-white/10 bg-white/5 p-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => updateAssistant('asrInterimEnabled', false)}
|
||||
className={`px-3 py-1 text-xs rounded-md transition-colors ${
|
||||
selectedAssistant.asrInterimEnabled === true
|
||||
? 'text-muted-foreground hover:text-foreground'
|
||||
: 'bg-primary text-primary-foreground shadow-sm'
|
||||
}`}
|
||||
>
|
||||
关闭
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => updateAssistant('asrInterimEnabled', true)}
|
||||
className={`px-3 py-1 text-xs rounded-md transition-colors ${
|
||||
selectedAssistant.asrInterimEnabled === true
|
||||
? 'bg-primary text-primary-foreground shadow-sm'
|
||||
: 'text-muted-foreground hover:text-foreground'
|
||||
}`}
|
||||
>
|
||||
开启
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
仅影响离线 ASR 模式(OpenAI Compatible / buffered)。默认关闭。
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-3">
|
||||
<div className="flex items-center justify-between gap-3">
|
||||
<label className="text-sm font-medium text-white flex items-center">
|
||||
|
||||
@@ -87,6 +87,7 @@ const mapAssistant = (raw: AnyRecord): Assistant => ({
|
||||
speed: Number(readField(raw, ['speed'], 1)),
|
||||
hotwords: readField(raw, ['hotwords'], []),
|
||||
tools: normalizeToolIdList(readField(raw, ['tools'], [])),
|
||||
asrInterimEnabled: Boolean(readField(raw, ['asrInterimEnabled', 'asr_interim_enabled'], false)),
|
||||
botCannotBeInterrupted: Boolean(readField(raw, ['botCannotBeInterrupted', 'bot_cannot_be_interrupted'], false)),
|
||||
interruptionSensitivity: Number(readField(raw, ['interruptionSensitivity', 'interruption_sensitivity'], 500)),
|
||||
configMode: readField(raw, ['configMode', 'config_mode'], 'platform') as 'platform' | 'dify' | 'fastgpt' | 'none',
|
||||
@@ -284,6 +285,7 @@ export const createAssistant = async (data: Partial<Assistant>): Promise<Assista
|
||||
speed: data.speed ?? 1,
|
||||
hotwords: data.hotwords || [],
|
||||
tools: normalizeToolIdList(data.tools || []),
|
||||
asrInterimEnabled: data.asrInterimEnabled ?? false,
|
||||
botCannotBeInterrupted: data.botCannotBeInterrupted ?? false,
|
||||
interruptionSensitivity: data.interruptionSensitivity ?? 500,
|
||||
configMode: data.configMode || 'platform',
|
||||
@@ -316,6 +318,7 @@ export const updateAssistant = async (id: string, data: Partial<Assistant>): Pro
|
||||
speed: data.speed,
|
||||
hotwords: data.hotwords,
|
||||
tools: data.tools === undefined ? undefined : normalizeToolIdList(data.tools),
|
||||
asrInterimEnabled: data.asrInterimEnabled,
|
||||
botCannotBeInterrupted: data.botCannotBeInterrupted,
|
||||
interruptionSensitivity: data.interruptionSensitivity,
|
||||
configMode: data.configMode,
|
||||
|
||||
@@ -19,6 +19,7 @@ export interface Assistant {
|
||||
speed: number;
|
||||
hotwords: string[];
|
||||
tools?: string[]; // IDs of enabled tools
|
||||
asrInterimEnabled?: boolean;
|
||||
botCannotBeInterrupted?: boolean;
|
||||
interruptionSensitivity?: number; // In ms
|
||||
configMode?: 'platform' | 'dify' | 'fastgpt' | 'none';
|
||||
|
||||
Reference in New Issue
Block a user