"""Tests for ASR Model API endpoints"""
import io
import wave
from unittest.mock import MagicMock, patch

import pytest


def _make_wav_bytes(sample_rate: int = 16000) -> bytes:
    """Return the bytes of an in-memory WAV file containing silence.

    The file is mono with 2-byte samples, and holds ``sample_rate`` zero
    frames written at a frame rate of ``sample_rate`` (i.e. one second).
    """
    with io.BytesIO() as buffer:
        with wave.open(buffer, "wb") as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(b"\x00\x00" * sample_rate)
        # Read the buffer only after the wave writer has been closed, but
        # while the BytesIO object itself is still open.
        return buffer.getvalue()


def _create_asr_model(client, payload: dict) -> str:
    """POST *payload* to ``/api/asr`` and return the created model's ``id``.

    Asserts that creation succeeded so that a broken setup step fails with
    the response body instead of a bare ``KeyError`` further down the test.
    """
    response = client.post("/api/asr", json=payload)
    assert response.status_code == 200, response.text
    return response.json()["id"]


def _mock_http_client(mock_response: MagicMock) -> MagicMock:
    """Return a MagicMock usable as a context-manager HTTP client.

    ``get`` returns *mock_response*; entering the context yields the mock
    itself and exiting does not suppress exceptions.
    """
    mock_client = MagicMock()
    mock_client.get.return_value = mock_response
    mock_client.__enter__ = MagicMock(return_value=mock_client)
    mock_client.__exit__ = MagicMock(return_value=False)
    return mock_client


class TestASRModelAPI:
    """Test cases for ASR Model endpoints"""

    def test_get_asr_models_empty(self, client):
        """Test getting ASR models when database is empty"""
        response = client.get("/api/asr")
        assert response.status_code == 200

        data = response.json()
        assert "total" in data
        assert "list" in data
        assert data["total"] == 0

    def test_create_asr_model(self, client, sample_asr_model_data):
        """Test creating a new ASR model"""
        response = client.post("/api/asr", json=sample_asr_model_data)
        assert response.status_code == 200

        data = response.json()
        assert data["name"] == sample_asr_model_data["name"]
        assert data["vendor"] == sample_asr_model_data["vendor"]
        assert data["language"] == sample_asr_model_data["language"]
        assert "id" in data

    def test_create_asr_model_minimal(self, client):
        """Test creating an ASR model with minimal required data"""
        data = {
            "name": "Minimal ASR",
            "vendor": "Test",
            "language": "zh",
            "base_url": "https://api.test.com",
            "api_key": "test-key",
        }
        response = client.post("/api/asr", json=data)
        assert response.status_code == 200
        assert response.json()["name"] == "Minimal ASR"

    def test_get_asr_model_by_id(self, client, sample_asr_model_data):
        """Test getting a specific ASR model by ID"""
        model_id = _create_asr_model(client, sample_asr_model_data)

        response = client.get(f"/api/asr/{model_id}")
        assert response.status_code == 200

        data = response.json()
        assert data["id"] == model_id
        assert data["name"] == sample_asr_model_data["name"]

    def test_get_asr_model_not_found(self, client):
        """Test getting a non-existent ASR model"""
        response = client.get("/api/asr/non-existent-id")
        assert response.status_code == 404

    def test_update_asr_model(self, client, sample_asr_model_data):
        """Test updating an ASR model"""
        model_id = _create_asr_model(client, sample_asr_model_data)

        update_data = {
            "name": "Updated ASR Model",
            "language": "en",
            "enable_punctuation": False,
        }
        response = client.put(f"/api/asr/{model_id}", json=update_data)
        assert response.status_code == 200

        data = response.json()
        assert data["name"] == "Updated ASR Model"
        assert data["language"] == "en"
        assert data["enable_punctuation"] is False

    def test_update_asr_model_vendor(self, client, sample_asr_model_data):
        """Test updating ASR vendor metadata."""
        model_id = _create_asr_model(client, sample_asr_model_data)

        response = client.put(
            f"/api/asr/{model_id}",
            json={
                "vendor": "DashScope",
                "model_name": "qwen3-asr-flash-realtime",
                "base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
            },
        )
        assert response.status_code == 200

        data = response.json()
        assert data["vendor"] == "DashScope"
        assert data["model_name"] == "qwen3-asr-flash-realtime"

    def test_delete_asr_model(self, client, sample_asr_model_data):
        """Test deleting an ASR model"""
        model_id = _create_asr_model(client, sample_asr_model_data)

        response = client.delete(f"/api/asr/{model_id}")
        assert response.status_code == 200

        # A deleted model must no longer be retrievable.
        get_response = client.get(f"/api/asr/{model_id}")
        assert get_response.status_code == 404

    def test_list_asr_models_with_pagination(self, client, sample_asr_model_data):
        """Test listing ASR models with pagination"""
        for i in range(3):
            _create_asr_model(
                client,
                {
                    **sample_asr_model_data,
                    "id": f"test-asr-{i}",
                    "name": f"ASR Model {i}",
                },
            )

        response = client.get("/api/asr?page=1&limit=2")
        assert response.status_code == 200

        # ``total`` counts every model; ``list`` holds only the requested page.
        data = response.json()
        assert data["total"] == 3
        assert len(data["list"]) == 2

    def test_filter_asr_models_by_language(self, client, sample_asr_model_data):
        """Test filtering ASR models by language"""
        for lang in ["zh", "en", "Multi-lingual"]:
            _create_asr_model(
                client,
                {
                    **sample_asr_model_data,
                    "id": f"test-asr-{lang}",
                    "name": f"ASR {lang}",
                    "language": lang,
                },
            )

        response = client.get("/api/asr?language=zh")
        assert response.status_code == 200

        data = response.json()
        assert data["total"] >= 1
        for model in data["list"]:
            assert model["language"] == "zh"

    def test_filter_asr_models_by_enabled(self, client, sample_asr_model_data):
        """Test filtering ASR models by enabled status"""
        enabled_model = {
            **sample_asr_model_data,
            "id": "test-asr-enabled",
            "name": "Enabled ASR",
            "enabled": True,
        }
        disabled_model = {
            **enabled_model,
            "id": "test-asr-disabled",
            "name": "Disabled ASR",
            "enabled": False,
        }
        _create_asr_model(client, enabled_model)
        _create_asr_model(client, disabled_model)

        response = client.get("/api/asr?enabled=true")
        assert response.status_code == 200

        data = response.json()
        # Guard against the loop below passing vacuously on an empty list.
        assert data["total"] >= 1
        for model in data["list"]:
            assert model["enabled"] is True

    def test_create_asr_model_with_hotwords(self, client):
        """Test creating an ASR model with hotwords"""
        data = {
            "id": "asr-hotwords",
            "name": "ASR with Hotwords",
            "vendor": "SiliconFlow",
            "language": "zh",
            "base_url": "https://api.siliconflow.cn/v1",
            "api_key": "test-key",
            "model_name": "paraformer-v2",
            "hotwords": ["你好", "查询", "帮助"],
            "enable_punctuation": True,
            "enable_normalization": True,
        }
        response = client.post("/api/asr", json=data)
        assert response.status_code == 200

        result = response.json()
        assert result["hotwords"] == ["你好", "查询", "帮助"]

    def test_create_asr_model_with_all_fields(self, client):
        """Test creating an ASR model with all fields"""
        data = {
            "id": "full-asr",
            "name": "Full ASR Model",
            "vendor": "SiliconFlow",
            "language": "zh",
            "base_url": "https://api.siliconflow.cn/v1",
            "api_key": "sk-test",
            "model_name": "paraformer-v2",
            "hotwords": ["测试"],
            "enable_punctuation": True,
            "enable_normalization": True,
            "enabled": True,
        }
        response = client.post("/api/asr", json=data)
        assert response.status_code == 200

        result = response.json()
        assert result["name"] == "Full ASR Model"
        assert result["enable_punctuation"] is True
        assert result["enable_normalization"] is True

    @patch('httpx.Client')
    def test_test_asr_model_siliconflow(self, mock_client_class, client, sample_asr_model_data):
        """Test testing an ASR model with SiliconFlow vendor"""
        # Work on a copy so the shared fixture dict is left untouched.
        payload = {**sample_asr_model_data, "vendor": "SiliconFlow"}
        model_id = _create_asr_model(client, payload)

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "results": [{"transcript": "测试文本", "language": "zh"}]
        }
        mock_response.raise_for_status = MagicMock()
        mock_client = _mock_http_client(mock_response)

        with patch('app.routers.asr.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/asr/{model_id}/test")

        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True

    @patch('httpx.Client')
    def test_test_asr_model_openai(self, mock_client_class, client, sample_asr_model_data):
        """Test testing an ASR model with OpenAI vendor"""
        # Work on a copy so the shared fixture dict is left untouched.
        payload = {
            **sample_asr_model_data,
            "vendor": "OpenAI",
            "id": "test-asr-openai",
        }
        model_id = _create_asr_model(client, payload)

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"text": "Test transcript"}
        mock_response.raise_for_status = MagicMock()
        mock_client = _mock_http_client(mock_response)

        with patch('app.routers.asr.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/asr/{model_id}/test")

        assert response.status_code == 200

    def test_test_asr_model_dashscope(self, client, sample_asr_model_data, monkeypatch):
        """Test DashScope ASR connectivity probe."""
        from app.routers import asr as asr_router

        # Work on a copy so the shared fixture dict is left untouched.
        payload = {
            **sample_asr_model_data,
            "vendor": "DashScope",
            "base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
            "model_name": "qwen3-asr-flash-realtime",
        }
        model_id = _create_asr_model(client, payload)

        def fake_probe(**kwargs):
            # Stand-in for the real probe: only checks the arguments it is
            # called with and returns None.
            assert kwargs["api_key"] == payload["api_key"]
            assert kwargs["model"] == "qwen3-asr-flash-realtime"

        monkeypatch.setattr(asr_router, "_probe_dashscope_asr_connection", fake_probe)

        response = client.post(f"/api/asr/{model_id}/test")
        assert response.status_code == 200

        data = response.json()
        assert data["success"] is True
        assert data["message"] == "DashScope realtime ASR connected"

    @patch('httpx.Client')
    def test_test_asr_model_failure(self, mock_client_class, client, sample_asr_model_data):
        """Test testing an ASR model with failed connection"""
        model_id = _create_asr_model(client, sample_asr_model_data)

        # Upstream response that raises when the caller checks its status.
        mock_response = MagicMock()
        mock_response.status_code = 401
        mock_response.text = "Unauthorized"
        mock_response.raise_for_status = MagicMock(side_effect=Exception("401 Unauthorized"))
        mock_client = _mock_http_client(mock_response)

        with patch('app.routers.asr.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/asr/{model_id}/test")

        # The endpoint is expected to answer 200 and report the failure in
        # the response body rather than through the HTTP status.
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is False

    def test_different_asr_languages(self, client):
        """Test creating ASR models with different languages"""
        for lang in ["zh", "en", "Multi-lingual"]:
            data = {
                "id": f"asr-lang-{lang}",
                "name": f"ASR {lang}",
                "vendor": "SiliconFlow",
                "language": lang,
                "base_url": "https://api.siliconflow.cn/v1",
                "api_key": "test-key",
            }
            response = client.post("/api/asr", json=data)
            assert response.status_code == 200
            assert response.json()["language"] == lang

    def test_different_asr_vendors(self, client):
        """Test creating ASR models with different vendors"""
        vendors = ["SiliconFlow", "OpenAI", "Azure", "DashScope"]
        for vendor in vendors:
            data = {
                "id": f"asr-vendor-{vendor.lower()}",
                "name": f"ASR {vendor}",
                "vendor": vendor,
                "language": "zh",
                "base_url": f"https://api.{vendor.lower()}.com/v1",
                "api_key": "test-key",
            }
            response = client.post("/api/asr", json=data)
            assert response.status_code == 200
            assert response.json()["vendor"] == vendor

    def test_preview_asr_model_success(self, client, sample_asr_model_data, monkeypatch):
        """Test ASR preview endpoint with OpenAI-compatible transcriptions API."""
        from app.routers import asr as asr_router

        model_id = _create_asr_model(client, sample_asr_model_data)

        class DummyResponse:
            """Canned successful transcription response."""

            status_code = 200

            def json(self):
                return {"text": "你好,这是测试转写", "language": "zh", "confidence": 0.98}

            @property
            def text(self):
                return '{"text":"ok"}'

        class DummyClient:
            """Context-manager client whose ``post`` validates the outgoing request."""

            def __init__(self, *args, **kwargs):
                pass

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                return False

            def post(self, url, headers=None, data=None, files=None):
                assert url.endswith("/audio/transcriptions")
                assert headers["Authorization"] == f"Bearer {sample_asr_model_data['api_key']}"
                assert data["model"] == sample_asr_model_data["model_name"]
                assert files["file"][0] == "sample.wav"
                return DummyResponse()

        monkeypatch.setattr(asr_router.httpx, "Client", DummyClient)

        response = client.post(
            f"/api/asr/{model_id}/preview",
            files={"file": ("sample.wav", b"fake-wav-bytes", "audio/wav")},
        )
        assert response.status_code == 200

        payload = response.json()
        assert payload["success"] is True
        assert payload["transcript"] == "你好,这是测试转写"
        assert payload["language"] == "zh"

    def test_preview_asr_model_reject_non_audio(self, client, sample_asr_model_data):
        """Test ASR preview endpoint rejects non-audio file."""
        model_id = _create_asr_model(client, sample_asr_model_data)

        response = client.post(
            f"/api/asr/{model_id}/preview",
            files={"file": ("sample.txt", b"text-data", "text/plain")},
        )
        assert response.status_code == 400
        assert "Only audio files are supported" in response.text

    def test_preview_asr_model_dashscope(self, client, sample_asr_model_data, monkeypatch):
        """Test ASR preview endpoint with DashScope realtime helper."""
        from app.routers import asr as asr_router

        # Work on a copy so the shared fixture dict is left untouched.
        model_payload = {
            **sample_asr_model_data,
            "vendor": "DashScope",
            "base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
            "model_name": "qwen3-asr-flash-realtime",
        }
        model_id = _create_asr_model(client, model_payload)

        def fake_preview(**kwargs):
            # Stand-in for the real helper: checks the stored model settings
            # were passed through, then returns a canned transcription.
            assert kwargs["base_url"] == model_payload["base_url"]
            assert kwargs["model"] == model_payload["model_name"]
            return {
                "transcript": "你好,这是实时识别",
                "language": "zh",
                "confidence": None,
            }

        monkeypatch.setattr(asr_router, "_transcribe_dashscope_preview", fake_preview)

        response = client.post(
            f"/api/asr/{model_id}/preview",
            files={"file": ("sample.wav", _make_wav_bytes(), "audio/wav")},
        )
        assert response.status_code == 200

        payload = response.json()
        assert payload["success"] is True
        assert payload["transcript"] == "你好,这是实时识别"