Files
AI-VideoAssistant/api/tests/test_assistants.py
Xin Wang da38157638 Add ASR interim results support in Assistant model and API
- Introduced `asr_interim_enabled` field in the Assistant model to control interim ASR results.
- Updated AssistantBase and AssistantUpdate schemas to include the new field.
- Modified the database schema to add the `asr_interim_enabled` column.
- Enhanced runtime metadata to reflect interim ASR settings.
- Updated API endpoints and tests to validate the new functionality.
- Adjusted documentation to include details about interim ASR results configuration.
2026-03-06 12:58:54 +08:00

422 lines
20 KiB
Python

"""Tests for Assistant API endpoints"""
import pytest
import uuid
class TestAssistantAPI:
    """Test cases for the Assistant endpoints under ``/api/assistants``.

    Covers CRUD, pagination, round-tripping of individual assistant fields,
    and the resolved ``/runtime-config`` and ``/config`` payloads. Every test
    drives the API through the ``client`` fixture; the ``sample_*_data``
    fixtures are presumably supplied by a conftest.py not visible here.
    """
def test_get_assistants_empty(self, client):
"""Test getting assistants when database is empty"""
response = client.get("/api/assistants")
assert response.status_code == 200
data = response.json()
assert "total" in data
assert "list" in data
def test_create_assistant(self, client, sample_assistant_data):
"""Test creating a new assistant"""
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == sample_assistant_data["name"]
assert data["opener"] == sample_assistant_data["opener"]
assert data["manualOpenerToolCalls"] == []
assert data["prompt"] == sample_assistant_data["prompt"]
assert data["language"] == sample_assistant_data["language"]
assert data["voiceOutputEnabled"] is True
assert data["firstTurnMode"] == "bot_first"
assert data["generatedOpenerEnabled"] is False
assert data["asrInterimEnabled"] is False
assert data["botCannotBeInterrupted"] is False
assert "id" in data
assert data["callCount"] == 0
def test_create_assistant_minimal(self, client):
"""Test creating an assistant with minimal required data"""
data = {"name": "Minimal Assistant"}
response = client.post("/api/assistants", json=data)
assert response.status_code == 200
assert response.json()["name"] == "Minimal Assistant"
assert response.json()["asrInterimEnabled"] is False
def test_get_assistant_by_id(self, client, sample_assistant_data):
"""Test getting a specific assistant by ID"""
# Create first
create_response = client.post("/api/assistants", json=sample_assistant_data)
assistant_id = create_response.json()["id"]
# Get by ID
response = client.get(f"/api/assistants/{assistant_id}")
assert response.status_code == 200
data = response.json()
assert data["id"] == assistant_id
assert data["name"] == sample_assistant_data["name"]
def test_get_assistant_not_found(self, client):
"""Test getting a non-existent assistant"""
response = client.get("/api/assistants/non-existent-id")
assert response.status_code == 404
def test_update_assistant(self, client, sample_assistant_data):
"""Test updating an assistant"""
# Create first
create_response = client.post("/api/assistants", json=sample_assistant_data)
assistant_id = create_response.json()["id"]
# Update
update_data = {
"name": "Updated Assistant",
"prompt": "You are an updated assistant.",
"speed": 1.5,
"voiceOutputEnabled": False,
"asrInterimEnabled": True,
"manualOpenerToolCalls": [
{"toolName": "text_msg_prompt", "arguments": {"msg": "请选择服务类型"}}
],
}
response = client.put(f"/api/assistants/{assistant_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == "Updated Assistant"
assert data["prompt"] == "You are an updated assistant."
assert data["speed"] == 1.5
assert data["voiceOutputEnabled"] is False
assert data["asrInterimEnabled"] is True
assert data["manualOpenerToolCalls"] == [
{"toolName": "text_msg_prompt", "arguments": {"msg": "请选择服务类型"}}
]
def test_delete_assistant(self, client, sample_assistant_data):
"""Test deleting an assistant"""
# Create first
create_response = client.post("/api/assistants", json=sample_assistant_data)
assistant_id = create_response.json()["id"]
# Delete
response = client.delete(f"/api/assistants/{assistant_id}")
assert response.status_code == 200
# Verify deleted
get_response = client.get(f"/api/assistants/{assistant_id}")
assert get_response.status_code == 404
def test_list_assistants_with_pagination(self, client, sample_assistant_data):
"""Test listing assistants with pagination"""
# Create multiple assistants
for i in range(3):
data = sample_assistant_data.copy()
data["name"] = f"Assistant {i}"
client.post("/api/assistants", json=data)
# Test pagination
response = client.get("/api/assistants?page=1&limit=2")
assert response.status_code == 200
data = response.json()
assert data["total"] == 3
assert len(data["list"]) == 2
def test_create_assistant_with_voice(self, client, sample_assistant_data, sample_voice_data):
"""Test creating an assistant with a voice reference"""
# Create a voice first
voice_response = client.post("/api/voices", json=sample_voice_data)
voice_id = voice_response.json()["id"]
# Create assistant with voice
sample_assistant_data["voice"] = voice_id
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
assert response.json()["voice"] == voice_id
def test_create_assistant_with_knowledge_base(self, client, sample_assistant_data):
"""Test creating an assistant with knowledge base reference"""
# Note: This test assumes knowledge base doesn't exist
sample_assistant_data["knowledgeBaseId"] = "non-existent-kb"
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
assert response.json()["knowledgeBaseId"] == "non-existent-kb"
assistant_id = response.json()["id"]
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
assert runtime_resp.status_code == 200
metadata = runtime_resp.json()["sessionStartMetadata"]
assert metadata["knowledgeBaseId"] == "non-existent-kb"
assert metadata["knowledge"]["enabled"] is True
assert metadata["knowledge"]["kbId"] == "non-existent-kb"
def test_assistant_with_model_references(self, client, sample_assistant_data):
"""Test creating assistant with model references"""
sample_assistant_data.update({
"llmModelId": "llm-001",
"asrModelId": "asr-001",
"embeddingModelId": "emb-001",
"rerankModelId": "rerank-001"
})
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
data = response.json()
assert data["llmModelId"] == "llm-001"
assert data["asrModelId"] == "asr-001"
assert data["embeddingModelId"] == "emb-001"
assert data["rerankModelId"] == "rerank-001"
def test_assistant_with_tools(self, client, sample_assistant_data):
"""Test creating assistant with tools"""
sample_assistant_data["tools"] = ["weather", "calculator", "search"]
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
assert response.json()["tools"] == ["weather", "calculator", "search"]
def test_assistant_with_hotwords(self, client, sample_assistant_data):
"""Test creating assistant with hotwords"""
sample_assistant_data["hotwords"] = ["hello", "help", "stop"]
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
assert response.json()["hotwords"] == ["hello", "help", "stop"]
def test_different_config_modes(self, client, sample_assistant_data):
"""Test creating assistants with different config modes"""
for mode in ["platform", "dify", "fastgpt", "none"]:
sample_assistant_data["name"] = f"Assistant {mode}"
sample_assistant_data["configMode"] = mode
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
assert response.json()["configMode"] == mode
def test_different_languages(self, client, sample_assistant_data):
"""Test creating assistants with different languages"""
for lang in ["zh", "en", "ja", "ko"]:
sample_assistant_data["name"] = f"Assistant {lang}"
sample_assistant_data["language"] = lang
response = client.post("/api/assistants", json=sample_assistant_data)
assert response.status_code == 200
assert response.json()["language"] == lang
def test_get_runtime_config(self, client, sample_assistant_data, sample_llm_model_data, sample_asr_model_data, sample_voice_data):
"""Test resolved runtime config endpoint for WS session.start metadata."""
sample_asr_model_data["vendor"] = "OpenAI Compatible"
llm_resp = client.post("/api/llm", json=sample_llm_model_data)
assert llm_resp.status_code == 200
llm_id = llm_resp.json()["id"]
asr_resp = client.post("/api/asr", json=sample_asr_model_data)
assert asr_resp.status_code == 200
asr_id = asr_resp.json()["id"]
sample_voice_data["vendor"] = "OpenAI Compatible"
sample_voice_data["base_url"] = "https://tts.example.com/v1/audio/speech"
sample_voice_data["api_key"] = "test-voice-key"
voice_resp = client.post("/api/voices", json=sample_voice_data)
assert voice_resp.status_code == 200
voice_id = voice_resp.json()["id"]
sample_assistant_data.update({
"llmModelId": llm_id,
"asrModelId": asr_id,
"voice": voice_id,
"prompt": "runtime prompt",
"opener": "runtime opener",
"manualOpenerToolCalls": [{"toolName": "text_msg_prompt", "arguments": {"msg": "欢迎"}}],
"asrInterimEnabled": True,
"speed": 1.1,
})
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
assert assistant_resp.status_code == 200
assistant_id = assistant_resp.json()["id"]
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
assert runtime_resp.status_code == 200
payload = runtime_resp.json()
assert payload["assistantId"] == assistant_id
metadata = payload["sessionStartMetadata"]
assert metadata["systemPrompt"].startswith("runtime prompt")
assert "Tool usage policy:" in metadata["systemPrompt"]
assert metadata["greeting"] == "runtime opener"
assert metadata["manualOpenerToolCalls"] == [{"toolName": "text_msg_prompt", "arguments": {"msg": "欢迎"}}]
assert metadata["services"]["llm"]["model"] == sample_llm_model_data["model_name"]
assert metadata["services"]["asr"]["model"] == sample_asr_model_data["model_name"]
assert metadata["services"]["asr"]["baseUrl"] == sample_asr_model_data["base_url"]
assert metadata["services"]["asr"]["enableInterim"] is True
expected_tts_voice = f"{sample_voice_data['model']}:{sample_voice_data['voice_key']}"
assert metadata["services"]["tts"]["voice"] == expected_tts_voice
assert metadata["services"]["tts"]["baseUrl"] == sample_voice_data["base_url"]
def test_get_engine_config_endpoint(self, client, sample_assistant_data):
"""Test canonical assistant config endpoint consumed by engine backend adapter."""
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
assert assistant_resp.status_code == 200
assistant_id = assistant_resp.json()["id"]
config_resp = client.get(f"/api/assistants/{assistant_id}/config")
assert config_resp.status_code == 200
payload = config_resp.json()
assert payload["assistantId"] == assistant_id
assert payload["assistant"]["assistantId"] == assistant_id
assert payload["assistant"]["configVersionId"].startswith(f"asst_{assistant_id}_")
assert payload["assistant"]["systemPrompt"].startswith(sample_assistant_data["prompt"])
assert "Tool usage policy:" in payload["assistant"]["systemPrompt"]
assert payload["sessionStartMetadata"]["systemPrompt"].startswith(sample_assistant_data["prompt"])
assert "Tool usage policy:" in payload["sessionStartMetadata"]["systemPrompt"]
assert payload["sessionStartMetadata"]["history"]["assistantId"] == assistant_id
def test_runtime_config_resolves_selected_tools_into_runtime_definitions(self, client, sample_assistant_data):
sample_assistant_data["tools"] = ["increase_volume", "calculator"]
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
assert assistant_resp.status_code == 200
assistant_id = assistant_resp.json()["id"]
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
assert runtime_resp.status_code == 200
metadata = runtime_resp.json()["sessionStartMetadata"]
tools = metadata["tools"]
assert isinstance(tools, list)
assert len(tools) == 2
by_name = {item["function"]["name"]: item for item in tools}
assert by_name["increase_volume"]["executor"] == "client"
assert by_name["increase_volume"]["defaultArgs"]["step"] == 1
assert by_name["calculator"]["executor"] == "server"
assert by_name["calculator"]["function"]["parameters"]["type"] == "object"
assert "expression" in by_name["calculator"]["function"]["parameters"]["properties"]
def test_runtime_config_normalizes_legacy_voice_message_prompt_tool_id(self, client, sample_assistant_data):
sample_assistant_data["tools"] = ["voice_message_prompt"]
sample_assistant_data["manualOpenerToolCalls"] = [
{"toolName": "voice_message_prompt", "arguments": {"msg": "您好"}}
]
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
assert assistant_resp.status_code == 200
assistant_payload = assistant_resp.json()
assistant_id = assistant_payload["id"]
assert assistant_payload["tools"] == ["voice_msg_prompt"]
assert assistant_payload["manualOpenerToolCalls"] == [
{"toolName": "voice_msg_prompt", "arguments": {"msg": "您好"}}
]
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
assert runtime_resp.status_code == 200
metadata = runtime_resp.json()["sessionStartMetadata"]
tools = metadata["tools"]
by_name = {item["function"]["name"]: item for item in tools}
assert "voice_msg_prompt" in by_name
assert metadata["manualOpenerToolCalls"] == [
{"toolName": "voice_msg_prompt", "arguments": {"msg": "您好"}}
]
def test_runtime_config_text_mode_when_voice_output_disabled(self, client, sample_assistant_data):
sample_assistant_data["voiceOutputEnabled"] = False
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
assert assistant_resp.status_code == 200
assistant_id = assistant_resp.json()["id"]
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
assert runtime_resp.status_code == 200
metadata = runtime_resp.json()["sessionStartMetadata"]
assert metadata["output"]["mode"] == "text"
assert metadata["services"]["asr"]["enableInterim"] is False
assert metadata["services"]["tts"]["enabled"] is False
def test_runtime_config_dashscope_voice_provider(self, client, sample_assistant_data):
"""DashScope voices should map to dashscope tts provider in runtime metadata."""
voice_resp = client.post("/api/voices", json={
"name": "DashScope Cherry",
"vendor": "DashScope",
"gender": "Female",
"language": "zh",
"description": "dashscope voice",
"api_key": "dashscope-key",
"base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
})
assert voice_resp.status_code == 200
voice_payload = voice_resp.json()
sample_assistant_data.update({
"voice": voice_payload["id"],
"voiceOutputEnabled": True,
})
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
assert assistant_resp.status_code == 200
assistant_id = assistant_resp.json()["id"]
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
assert runtime_resp.status_code == 200
metadata = runtime_resp.json()["sessionStartMetadata"]
tts = metadata["services"]["tts"]
assert tts["provider"] == "dashscope"
assert tts["voice"] == "Cherry"
assert tts["model"] == "qwen3-tts-flash-realtime"
assert tts["apiKey"] == "dashscope-key"
assert tts["baseUrl"] == "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
def test_runtime_config_dashscope_asr_provider(self, client, sample_assistant_data):
"""DashScope ASR models should map to dashscope asr provider in runtime metadata."""
asr_resp = client.post("/api/asr", json={
"name": "DashScope Realtime ASR",
"vendor": "DashScope",
"language": "zh",
"base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
"api_key": "dashscope-asr-key",
"model_name": "qwen3-asr-flash-realtime",
"hotwords": [],
"enable_punctuation": True,
"enable_normalization": True,
"enabled": True,
})
assert asr_resp.status_code == 200
asr_payload = asr_resp.json()
sample_assistant_data.update({
"asrModelId": asr_payload["id"],
})
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
assert assistant_resp.status_code == 200
assistant_id = assistant_resp.json()["id"]
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
assert runtime_resp.status_code == 200
metadata = runtime_resp.json()["sessionStartMetadata"]
asr = metadata["services"]["asr"]
assert asr["provider"] == "dashscope"
assert asr["baseUrl"] == "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
assert asr["enableInterim"] is False
def test_runtime_config_defaults_asr_interim_disabled_without_asr_model(self, client, sample_assistant_data):
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
assert assistant_resp.status_code == 200
assistant_id = assistant_resp.json()["id"]
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
assert runtime_resp.status_code == 200
metadata = runtime_resp.json()["sessionStartMetadata"]
assert metadata["services"]["asr"]["enableInterim"] is False
def test_assistant_interrupt_and_generated_opener_flags(self, client, sample_assistant_data):
sample_assistant_data.update({
"firstTurnMode": "user_first",
"generatedOpenerEnabled": True,
"botCannotBeInterrupted": True,
"interruptionSensitivity": 900,
})
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
assert assistant_resp.status_code == 200
assistant_id = assistant_resp.json()["id"]
get_resp = client.get(f"/api/assistants/{assistant_id}")
assert get_resp.status_code == 200
payload = get_resp.json()
assert payload["firstTurnMode"] == "user_first"
assert payload["generatedOpenerEnabled"] is True
assert payload["botCannotBeInterrupted"] is True
assert payload["interruptionSensitivity"] == 900
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
assert runtime_resp.status_code == 200
metadata = runtime_resp.json()["sessionStartMetadata"]
assert metadata["firstTurnMode"] == "user_first"
assert metadata["generatedOpenerEnabled"] is True
assert metadata["greeting"] == ""
assert metadata["bargeIn"]["enabled"] is False
assert metadata["bargeIn"]["minDurationMs"] == 900