Add ASR interim results support in Assistant model and API
- Introduced `asr_interim_enabled` field in the Assistant model to control interim ASR results. - Updated AssistantBase and AssistantUpdate schemas to include the new field. - Modified the database schema to add the `asr_interim_enabled` column. - Enhanced runtime metadata to reflect interim ASR settings. - Updated API endpoints and tests to validate the new functionality. - Adjusted documentation to include details about interim ASR results configuration.
This commit is contained in:
@@ -127,6 +127,7 @@ class Assistant(Base):
|
||||
speed: Mapped[float] = mapped_column(Float, default=1.0)
|
||||
hotwords: Mapped[dict] = mapped_column(JSON, default=list)
|
||||
tools: Mapped[dict] = mapped_column(JSON, default=list)
|
||||
asr_interim_enabled: Mapped[bool] = mapped_column(default=False)
|
||||
bot_cannot_be_interrupted: Mapped[bool] = mapped_column(default=False)
|
||||
interruption_sensitivity: Mapped[int] = mapped_column(Integer, default=500)
|
||||
config_mode: Mapped[str] = mapped_column(String(32), default="platform")
|
||||
|
||||
@@ -126,6 +126,9 @@ def _ensure_assistant_schema(db: Session) -> None:
|
||||
if "manual_opener_tool_calls" not in columns:
|
||||
db.execute(text("ALTER TABLE assistants ADD COLUMN manual_opener_tool_calls JSON"))
|
||||
altered = True
|
||||
if "asr_interim_enabled" not in columns:
|
||||
db.execute(text("ALTER TABLE assistants ADD COLUMN asr_interim_enabled BOOLEAN DEFAULT 0"))
|
||||
altered = True
|
||||
|
||||
if altered:
|
||||
db.commit()
|
||||
@@ -317,6 +320,9 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[s
|
||||
else:
|
||||
warnings.append(f"LLM model not found: {assistant.llm_model_id}")
|
||||
|
||||
asr_runtime: Dict[str, Any] = {
|
||||
"enableInterim": bool(assistant.asr_interim_enabled),
|
||||
}
|
||||
if assistant.asr_model_id:
|
||||
asr = db.query(ASRModel).filter(ASRModel.id == assistant.asr_model_id).first()
|
||||
if asr:
|
||||
@@ -326,14 +332,15 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[s
|
||||
asr_provider = "openai_compatible"
|
||||
else:
|
||||
asr_provider = "buffered"
|
||||
metadata["services"]["asr"] = {
|
||||
asr_runtime.update({
|
||||
"provider": asr_provider,
|
||||
"model": asr.model_name or asr.name,
|
||||
"apiKey": asr.api_key if asr_provider in {"openai_compatible", "dashscope"} else None,
|
||||
"baseUrl": asr.base_url if asr_provider in {"openai_compatible", "dashscope"} else None,
|
||||
}
|
||||
})
|
||||
else:
|
||||
warnings.append(f"ASR model not found: {assistant.asr_model_id}")
|
||||
metadata["services"]["asr"] = asr_runtime
|
||||
|
||||
if not assistant.voice_output_enabled:
|
||||
metadata["services"]["tts"] = {"enabled": False}
|
||||
@@ -437,6 +444,7 @@ def assistant_to_dict(assistant: Assistant) -> dict:
|
||||
"speed": assistant.speed,
|
||||
"hotwords": assistant.hotwords or [],
|
||||
"tools": _normalize_assistant_tool_ids(assistant.tools),
|
||||
"asrInterimEnabled": bool(assistant.asr_interim_enabled),
|
||||
"botCannotBeInterrupted": bool(assistant.bot_cannot_be_interrupted),
|
||||
"interruptionSensitivity": assistant.interruption_sensitivity,
|
||||
"configMode": assistant.config_mode,
|
||||
@@ -457,6 +465,7 @@ def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None:
|
||||
"firstTurnMode": "first_turn_mode",
|
||||
"manualOpenerToolCalls": "manual_opener_tool_calls",
|
||||
"interruptionSensitivity": "interruption_sensitivity",
|
||||
"asrInterimEnabled": "asr_interim_enabled",
|
||||
"botCannotBeInterrupted": "bot_cannot_be_interrupted",
|
||||
"configMode": "config_mode",
|
||||
"voiceOutputEnabled": "voice_output_enabled",
|
||||
@@ -651,6 +660,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
|
||||
speed=data.speed,
|
||||
hotwords=data.hotwords,
|
||||
tools=_normalize_assistant_tool_ids(data.tools),
|
||||
asr_interim_enabled=data.asrInterimEnabled,
|
||||
bot_cannot_be_interrupted=data.botCannotBeInterrupted,
|
||||
interruption_sensitivity=data.interruptionSensitivity,
|
||||
config_mode=data.configMode,
|
||||
|
||||
@@ -291,6 +291,7 @@ class AssistantBase(BaseModel):
|
||||
speed: float = 1.0
|
||||
hotwords: List[str] = []
|
||||
tools: List[str] = []
|
||||
asrInterimEnabled: bool = False
|
||||
botCannotBeInterrupted: bool = False
|
||||
interruptionSensitivity: int = 500
|
||||
configMode: str = "platform"
|
||||
@@ -322,6 +323,7 @@ class AssistantUpdate(BaseModel):
|
||||
speed: Optional[float] = None
|
||||
hotwords: Optional[List[str]] = None
|
||||
tools: Optional[List[str]] = None
|
||||
asrInterimEnabled: Optional[bool] = None
|
||||
botCannotBeInterrupted: Optional[bool] = None
|
||||
interruptionSensitivity: Optional[int] = None
|
||||
configMode: Optional[str] = None
|
||||
|
||||
@@ -27,6 +27,7 @@ class TestAssistantAPI:
|
||||
assert data["voiceOutputEnabled"] is True
|
||||
assert data["firstTurnMode"] == "bot_first"
|
||||
assert data["generatedOpenerEnabled"] is False
|
||||
assert data["asrInterimEnabled"] is False
|
||||
assert data["botCannotBeInterrupted"] is False
|
||||
assert "id" in data
|
||||
assert data["callCount"] == 0
|
||||
@@ -37,6 +38,7 @@ class TestAssistantAPI:
|
||||
response = client.post("/api/assistants", json=data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Minimal Assistant"
|
||||
assert response.json()["asrInterimEnabled"] is False
|
||||
|
||||
def test_get_assistant_by_id(self, client, sample_assistant_data):
|
||||
"""Test getting a specific assistant by ID"""
|
||||
@@ -68,6 +70,7 @@ class TestAssistantAPI:
|
||||
"prompt": "You are an updated assistant.",
|
||||
"speed": 1.5,
|
||||
"voiceOutputEnabled": False,
|
||||
"asrInterimEnabled": True,
|
||||
"manualOpenerToolCalls": [
|
||||
{"toolName": "text_msg_prompt", "arguments": {"msg": "请选择服务类型"}}
|
||||
],
|
||||
@@ -79,6 +82,7 @@ class TestAssistantAPI:
|
||||
assert data["prompt"] == "You are an updated assistant."
|
||||
assert data["speed"] == 1.5
|
||||
assert data["voiceOutputEnabled"] is False
|
||||
assert data["asrInterimEnabled"] is True
|
||||
assert data["manualOpenerToolCalls"] == [
|
||||
{"toolName": "text_msg_prompt", "arguments": {"msg": "请选择服务类型"}}
|
||||
]
|
||||
@@ -213,6 +217,7 @@ class TestAssistantAPI:
|
||||
"prompt": "runtime prompt",
|
||||
"opener": "runtime opener",
|
||||
"manualOpenerToolCalls": [{"toolName": "text_msg_prompt", "arguments": {"msg": "欢迎"}}],
|
||||
"asrInterimEnabled": True,
|
||||
"speed": 1.1,
|
||||
})
|
||||
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
|
||||
@@ -232,6 +237,7 @@ class TestAssistantAPI:
|
||||
assert metadata["services"]["llm"]["model"] == sample_llm_model_data["model_name"]
|
||||
assert metadata["services"]["asr"]["model"] == sample_asr_model_data["model_name"]
|
||||
assert metadata["services"]["asr"]["baseUrl"] == sample_asr_model_data["base_url"]
|
||||
assert metadata["services"]["asr"]["enableInterim"] is True
|
||||
expected_tts_voice = f"{sample_voice_data['model']}:{sample_voice_data['voice_key']}"
|
||||
assert metadata["services"]["tts"]["voice"] == expected_tts_voice
|
||||
assert metadata["services"]["tts"]["baseUrl"] == sample_voice_data["base_url"]
|
||||
@@ -309,6 +315,7 @@ class TestAssistantAPI:
|
||||
assert runtime_resp.status_code == 200
|
||||
metadata = runtime_resp.json()["sessionStartMetadata"]
|
||||
assert metadata["output"]["mode"] == "text"
|
||||
assert metadata["services"]["asr"]["enableInterim"] is False
|
||||
assert metadata["services"]["tts"]["enabled"] is False
|
||||
|
||||
def test_runtime_config_dashscope_voice_provider(self, client, sample_assistant_data):
|
||||
@@ -373,6 +380,17 @@ class TestAssistantAPI:
|
||||
asr = metadata["services"]["asr"]
|
||||
assert asr["provider"] == "dashscope"
|
||||
assert asr["baseUrl"] == "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
|
||||
assert asr["enableInterim"] is False
|
||||
|
||||
def test_runtime_config_defaults_asr_interim_disabled_without_asr_model(self, client, sample_assistant_data):
|
||||
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assert assistant_resp.status_code == 200
|
||||
assistant_id = assistant_resp.json()["id"]
|
||||
|
||||
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
|
||||
assert runtime_resp.status_code == 200
|
||||
metadata = runtime_resp.json()["sessionStartMetadata"]
|
||||
assert metadata["services"]["asr"]["enableInterim"] is False
|
||||
|
||||
def test_assistant_interrupt_and_generated_opener_flags(self, client, sample_assistant_data):
|
||||
sample_assistant_data.update({
|
||||
|
||||
Reference in New Issue
Block a user