- Introduced DashScope as a new ASR model in the database initialization. - Updated ASRModel schema to include vendor information. - Enhanced ASR router to support DashScope-specific functionality, including connection testing and preview capabilities. - Modified frontend components to accommodate DashScope as a selectable vendor with appropriate default settings. - Added tests to validate DashScope ASR model creation, updates, and connectivity. - Updated backend API to handle DashScope-specific base URLs and vendor normalization.
431 lines
17 KiB
Python
431 lines
17 KiB
Python
"""Tests for ASR Model API endpoints"""
|
|
import io
|
|
import wave
|
|
|
|
import pytest
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
|
|
def _make_wav_bytes(sample_rate: int = 16000) -> bytes:
|
|
with io.BytesIO() as buffer:
|
|
with wave.open(buffer, "wb") as wav_file:
|
|
wav_file.setnchannels(1)
|
|
wav_file.setsampwidth(2)
|
|
wav_file.setframerate(sample_rate)
|
|
wav_file.writeframes(b"\x00\x00" * sample_rate)
|
|
return buffer.getvalue()
|
|
|
|
|
|
class TestASRModelAPI:
    """Test cases for ASR Model endpoints.

    Every test drives the ``/api/asr`` routes through the ``client`` fixture.
    Most payloads are derived from the ``sample_asr_model_data`` fixture
    (neither fixture is defined in this file). Tests that need to tweak the
    sample payload work on a copy so the fixture dict itself is never mutated.
    """

    @staticmethod
    def _mock_http_client(response):
        """Build a MagicMock standing in for an ``httpx.Client`` instance.

        The mock works as a context manager that yields itself, and its
        ``get`` call returns *response*.
        """
        http_client = MagicMock()
        http_client.get.return_value = response
        http_client.__enter__ = MagicMock(return_value=http_client)
        http_client.__exit__ = MagicMock(return_value=False)
        return http_client

    def test_get_asr_models_empty(self, client):
        """Test getting ASR models when database is empty"""
        response = client.get("/api/asr")
        assert response.status_code == 200
        data = response.json()
        assert "total" in data
        assert "list" in data
        assert data["total"] == 0

    def test_create_asr_model(self, client, sample_asr_model_data):
        """Test creating a new ASR model"""
        response = client.post("/api/asr", json=sample_asr_model_data)
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == sample_asr_model_data["name"]
        assert data["vendor"] == sample_asr_model_data["vendor"]
        assert data["language"] == sample_asr_model_data["language"]
        assert "id" in data

    def test_create_asr_model_minimal(self, client):
        """Test creating an ASR model with minimal required data"""
        data = {
            "name": "Minimal ASR",
            "vendor": "Test",
            "language": "zh",
            "base_url": "https://api.test.com",
            "api_key": "test-key",
        }
        response = client.post("/api/asr", json=data)
        assert response.status_code == 200
        assert response.json()["name"] == "Minimal ASR"

    def test_get_asr_model_by_id(self, client, sample_asr_model_data):
        """Test getting a specific ASR model by ID"""
        # Create first
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        model_id = create_response.json()["id"]

        # Get by ID
        response = client.get(f"/api/asr/{model_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == model_id
        assert data["name"] == sample_asr_model_data["name"]

    def test_get_asr_model_not_found(self, client):
        """Test getting a non-existent ASR model"""
        response = client.get("/api/asr/non-existent-id")
        assert response.status_code == 404

    def test_update_asr_model(self, client, sample_asr_model_data):
        """Test updating an ASR model"""
        # Create first
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        model_id = create_response.json()["id"]

        # Update
        update_data = {
            "name": "Updated ASR Model",
            "language": "en",
            "enable_punctuation": False,
        }
        response = client.put(f"/api/asr/{model_id}", json=update_data)
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "Updated ASR Model"
        assert data["language"] == "en"
        # Identity check: the flag must be a real JSON ``false``, not merely falsy.
        assert data["enable_punctuation"] is False

    def test_update_asr_model_vendor(self, client, sample_asr_model_data):
        """Test updating ASR vendor metadata."""
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        model_id = create_response.json()["id"]

        response = client.put(
            f"/api/asr/{model_id}",
            json={
                "vendor": "DashScope",
                "model_name": "qwen3-asr-flash-realtime",
                "base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
            },
        )
        assert response.status_code == 200
        data = response.json()
        assert data["vendor"] == "DashScope"
        assert data["model_name"] == "qwen3-asr-flash-realtime"

    def test_delete_asr_model(self, client, sample_asr_model_data):
        """Test deleting an ASR model"""
        # Create first
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        model_id = create_response.json()["id"]

        # Delete
        response = client.delete(f"/api/asr/{model_id}")
        assert response.status_code == 200

        # Verify deleted
        get_response = client.get(f"/api/asr/{model_id}")
        assert get_response.status_code == 404

    def test_list_asr_models_with_pagination(self, client, sample_asr_model_data):
        """Test listing ASR models with pagination"""
        # Create multiple models
        for i in range(3):
            data = sample_asr_model_data.copy()
            data["id"] = f"test-asr-{i}"
            data["name"] = f"ASR Model {i}"
            client.post("/api/asr", json=data)

        # Test pagination: three models exist, but only two fit on the page.
        response = client.get("/api/asr?page=1&limit=2")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 3
        assert len(data["list"]) == 2

    def test_filter_asr_models_by_language(self, client, sample_asr_model_data):
        """Test filtering ASR models by language"""
        # Create models with different languages
        for lang in ("zh", "en", "Multi-lingual"):
            data = sample_asr_model_data.copy()
            data["id"] = f"test-asr-{lang}"
            data["name"] = f"ASR {lang}"
            data["language"] = lang
            client.post("/api/asr", json=data)

        # Filter by language
        response = client.get("/api/asr?language=zh")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] >= 1
        for model in data["list"]:
            assert model["language"] == "zh"

    def test_filter_asr_models_by_enabled(self, client, sample_asr_model_data):
        """Test filtering ASR models by enabled status"""
        # Create enabled and disabled models
        for model_id, name, enabled in (
            ("test-asr-enabled", "Enabled ASR", True),
            ("test-asr-disabled", "Disabled ASR", False),
        ):
            payload = sample_asr_model_data.copy()
            payload["id"] = model_id
            payload["name"] = name
            payload["enabled"] = enabled
            client.post("/api/asr", json=payload)

        # Filter by enabled
        response = client.get("/api/asr?enabled=true")
        assert response.status_code == 200
        data = response.json()
        # Guard against a vacuous pass: the loop below checks nothing when the
        # filtered list is empty, yet one enabled model was just created.
        assert data["list"]
        for model in data["list"]:
            assert model["enabled"] is True

    def test_create_asr_model_with_hotwords(self, client):
        """Test creating an ASR model with hotwords"""
        data = {
            "id": "asr-hotwords",
            "name": "ASR with Hotwords",
            "vendor": "SiliconFlow",
            "language": "zh",
            "base_url": "https://api.siliconflow.cn/v1",
            "api_key": "test-key",
            "model_name": "paraformer-v2",
            "hotwords": ["你好", "查询", "帮助"],
            "enable_punctuation": True,
            "enable_normalization": True,
        }
        response = client.post("/api/asr", json=data)
        assert response.status_code == 200
        result = response.json()
        assert result["hotwords"] == ["你好", "查询", "帮助"]

    def test_create_asr_model_with_all_fields(self, client):
        """Test creating an ASR model with all fields"""
        data = {
            "id": "full-asr",
            "name": "Full ASR Model",
            "vendor": "SiliconFlow",
            "language": "zh",
            "base_url": "https://api.siliconflow.cn/v1",
            "api_key": "sk-test",
            "model_name": "paraformer-v2",
            "hotwords": ["测试"],
            "enable_punctuation": True,
            "enable_normalization": True,
            "enabled": True,
        }
        response = client.post("/api/asr", json=data)
        assert response.status_code == 200
        result = response.json()
        assert result["name"] == "Full ASR Model"
        assert result["enable_punctuation"] is True
        assert result["enable_normalization"] is True

    @patch('httpx.Client')
    def test_test_asr_model_siliconflow(self, mock_client_class, client, sample_asr_model_data):
        """Test testing an ASR model with SiliconFlow vendor"""
        # NOTE: ``mock_client_class`` (injected by the decorator) is not used
        # directly; the response is controlled by the context-manager patch below.
        # Create model first, from a copy so the fixture dict stays untouched.
        payload = sample_asr_model_data.copy()
        payload["vendor"] = "SiliconFlow"
        create_response = client.post("/api/asr", json=payload)
        model_id = create_response.json()["id"]

        # Mock the HTTP response
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "results": [{"transcript": "测试文本", "language": "zh"}]
        }
        mock_response.raise_for_status = MagicMock()
        mock_client = self._mock_http_client(mock_response)

        with patch('app.routers.asr.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/asr/{model_id}/test")
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True

    @patch('httpx.Client')
    def test_test_asr_model_openai(self, mock_client_class, client, sample_asr_model_data):
        """Test testing an ASR model with OpenAI vendor"""
        # NOTE: ``mock_client_class`` (injected by the decorator) is not used
        # directly; the response is controlled by the context-manager patch below.
        # Create model with OpenAI vendor, from a copy of the fixture dict.
        payload = sample_asr_model_data.copy()
        payload["vendor"] = "OpenAI"
        payload["id"] = "test-asr-openai"
        create_response = client.post("/api/asr", json=payload)
        model_id = create_response.json()["id"]

        # Mock the HTTP response
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"text": "Test transcript"}
        mock_response.raise_for_status = MagicMock()
        mock_client = self._mock_http_client(mock_response)

        with patch('app.routers.asr.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/asr/{model_id}/test")
        assert response.status_code == 200

    def test_test_asr_model_dashscope(self, client, sample_asr_model_data, monkeypatch):
        """Test DashScope ASR connectivity probe."""
        from app.routers import asr as asr_router

        # Build the DashScope payload on a copy of the fixture dict.
        payload = sample_asr_model_data.copy()
        payload["vendor"] = "DashScope"
        payload["base_url"] = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
        payload["model_name"] = "qwen3-asr-flash-realtime"
        create_response = client.post("/api/asr", json=payload)
        model_id = create_response.json()["id"]

        def fake_probe(**kwargs):
            # Replaces the router's probe helper; only checks what it is called with.
            assert kwargs["api_key"] == payload["api_key"]
            assert kwargs["model"] == "qwen3-asr-flash-realtime"

        monkeypatch.setattr(asr_router, "_probe_dashscope_asr_connection", fake_probe)

        response = client.post(f"/api/asr/{model_id}/test")
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert data["message"] == "DashScope realtime ASR connected"

    @patch('httpx.Client')
    def test_test_asr_model_failure(self, mock_client_class, client, sample_asr_model_data):
        """Test testing an ASR model with failed connection"""
        # NOTE: ``mock_client_class`` (injected by the decorator) is not used
        # directly; the response is controlled by the context-manager patch below.
        # Create model first
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        model_id = create_response.json()["id"]

        # Mock HTTP error
        mock_response = MagicMock()
        mock_response.status_code = 401
        mock_response.text = "Unauthorized"
        mock_response.raise_for_status = MagicMock(side_effect=Exception("401 Unauthorized"))
        mock_client = self._mock_http_client(mock_response)

        with patch('app.routers.asr.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/asr/{model_id}/test")
        # The endpoint is expected to answer 200 and report the failure in the body.
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is False

    def test_different_asr_languages(self, client):
        """Test creating ASR models with different languages"""
        for lang in ["zh", "en", "Multi-lingual"]:
            data = {
                "id": f"asr-lang-{lang}",
                "name": f"ASR {lang}",
                "vendor": "SiliconFlow",
                "language": lang,
                "base_url": "https://api.siliconflow.cn/v1",
                "api_key": "test-key",
            }
            response = client.post("/api/asr", json=data)
            assert response.status_code == 200
            assert response.json()["language"] == lang

    def test_different_asr_vendors(self, client):
        """Test creating ASR models with different vendors"""
        vendors = ["SiliconFlow", "OpenAI", "Azure", "DashScope"]
        for vendor in vendors:
            data = {
                "id": f"asr-vendor-{vendor.lower()}",
                "name": f"ASR {vendor}",
                "vendor": vendor,
                "language": "zh",
                "base_url": f"https://api.{vendor.lower()}.com/v1",
                "api_key": "test-key",
            }
            response = client.post("/api/asr", json=data)
            assert response.status_code == 200
            assert response.json()["vendor"] == vendor

    def test_preview_asr_model_success(self, client, sample_asr_model_data, monkeypatch):
        """Test ASR preview endpoint with OpenAI-compatible transcriptions API."""
        from app.routers import asr as asr_router

        create_response = client.post("/api/asr", json=sample_asr_model_data)
        model_id = create_response.json()["id"]

        class DummyResponse:
            """Canned successful transcription response."""

            status_code = 200

            def json(self):
                return {"text": "你好,这是测试转写", "language": "zh", "confidence": 0.98}

            @property
            def text(self):
                return '{"text":"ok"}'

        class DummyClient:
            """Stand-in for ``httpx.Client`` that validates the outgoing request."""

            def __init__(self, *args, **kwargs):
                pass

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                return False

            def post(self, url, headers=None, data=None, files=None):
                assert url.endswith("/audio/transcriptions")
                assert headers["Authorization"] == f"Bearer {sample_asr_model_data['api_key']}"
                assert data["model"] == sample_asr_model_data["model_name"]
                assert files["file"][0] == "sample.wav"
                return DummyResponse()

        monkeypatch.setattr(asr_router.httpx, "Client", DummyClient)

        response = client.post(
            f"/api/asr/{model_id}/preview",
            files={"file": ("sample.wav", b"fake-wav-bytes", "audio/wav")},
        )
        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is True
        assert payload["transcript"] == "你好,这是测试转写"
        assert payload["language"] == "zh"

    def test_preview_asr_model_reject_non_audio(self, client, sample_asr_model_data):
        """Test ASR preview endpoint rejects non-audio file."""
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        model_id = create_response.json()["id"]

        response = client.post(
            f"/api/asr/{model_id}/preview",
            files={"file": ("sample.txt", b"text-data", "text/plain")},
        )
        assert response.status_code == 400
        assert "Only audio files are supported" in response.text

    def test_preview_asr_model_dashscope(self, client, sample_asr_model_data, monkeypatch):
        """Test ASR preview endpoint with DashScope realtime helper."""
        from app.routers import asr as asr_router

        # Build the DashScope payload on a copy of the fixture dict.
        model_payload = sample_asr_model_data.copy()
        model_payload["vendor"] = "DashScope"
        model_payload["base_url"] = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
        model_payload["model_name"] = "qwen3-asr-flash-realtime"
        create_response = client.post("/api/asr", json=model_payload)
        model_id = create_response.json()["id"]

        def fake_preview(**kwargs):
            # Replaces the router's transcription helper and returns a canned result.
            assert kwargs["base_url"] == model_payload["base_url"]
            assert kwargs["model"] == model_payload["model_name"]
            return {
                "transcript": "你好,这是实时识别",
                "language": "zh",
                "confidence": None,
            }

        monkeypatch.setattr(asr_router, "_transcribe_dashscope_preview", fake_preview)

        response = client.post(
            f"/api/asr/{model_id}/preview",
            files={"file": ("sample.wav", _make_wav_bytes(), "audio/wav")},
        )
        assert response.status_code == 200
        payload = response.json()
        assert payload["success"] is True
        assert payload["transcript"] == "你好,这是实时识别"
|