"""Tests for ASR Model API endpoints""" import pytest from unittest.mock import patch, MagicMock class TestASRModelAPI: """Test cases for ASR Model endpoints""" def test_get_asr_models_empty(self, client): """Test getting ASR models when database is empty""" response = client.get("/api/asr") assert response.status_code == 200 data = response.json() assert "total" in data assert "list" in data assert data["total"] == 0 def test_create_asr_model(self, client, sample_asr_model_data): """Test creating a new ASR model""" response = client.post("/api/asr", json=sample_asr_model_data) assert response.status_code == 200 data = response.json() assert data["name"] == sample_asr_model_data["name"] assert data["vendor"] == sample_asr_model_data["vendor"] assert data["language"] == sample_asr_model_data["language"] assert "id" in data def test_create_asr_model_minimal(self, client): """Test creating an ASR model with minimal required data""" data = { "name": "Minimal ASR", "vendor": "Test", "language": "zh", "base_url": "https://api.test.com", "api_key": "test-key" } response = client.post("/api/asr", json=data) assert response.status_code == 200 assert response.json()["name"] == "Minimal ASR" def test_get_asr_model_by_id(self, client, sample_asr_model_data): """Test getting a specific ASR model by ID""" # Create first create_response = client.post("/api/asr", json=sample_asr_model_data) model_id = create_response.json()["id"] # Get by ID response = client.get(f"/api/asr/{model_id}") assert response.status_code == 200 data = response.json() assert data["id"] == model_id assert data["name"] == sample_asr_model_data["name"] def test_get_asr_model_not_found(self, client): """Test getting a non-existent ASR model""" response = client.get("/api/asr/non-existent-id") assert response.status_code == 404 def test_update_asr_model(self, client, sample_asr_model_data): """Test updating an ASR model""" # Create first create_response = client.post("/api/asr", json=sample_asr_model_data) model_id = create_response.json()["id"] # Update update_data = { "name": "Updated ASR Model", "language": "en", "enable_punctuation": False } response = client.put(f"/api/asr/{model_id}", json=update_data) assert response.status_code == 200 data = response.json() assert data["name"] == "Updated ASR Model" assert data["language"] == "en" assert data["enable_punctuation"] == False def test_delete_asr_model(self, client, sample_asr_model_data): """Test deleting an ASR model""" # Create first create_response = client.post("/api/asr", json=sample_asr_model_data) model_id = create_response.json()["id"] # Delete response = client.delete(f"/api/asr/{model_id}") assert response.status_code == 200 # Verify deleted get_response = client.get(f"/api/asr/{model_id}") assert get_response.status_code == 404 def test_list_asr_models_with_pagination(self, client, sample_asr_model_data): """Test listing ASR models with pagination""" # Create multiple models for i in range(3): data = sample_asr_model_data.copy() data["id"] = f"test-asr-{i}" data["name"] = f"ASR Model {i}" client.post("/api/asr", json=data) # Test pagination response = client.get("/api/asr?page=1&limit=2") assert response.status_code == 200 data = response.json() assert data["total"] == 3 assert len(data["list"]) == 2 def test_filter_asr_models_by_language(self, client, sample_asr_model_data): """Test filtering ASR models by language""" # Create models with different languages for i, lang in enumerate(["zh", "en", "Multi-lingual"]): data = sample_asr_model_data.copy() data["id"] = f"test-asr-{lang}" data["name"] = f"ASR {lang}" data["language"] = lang client.post("/api/asr", json=data) # Filter by language response = client.get("/api/asr?language=zh") assert response.status_code == 200 data = response.json() assert data["total"] >= 1 for model in data["list"]: assert model["language"] == "zh" def test_filter_asr_models_by_enabled(self, client, sample_asr_model_data): """Test filtering ASR models by enabled status""" # Create enabled and disabled models data = sample_asr_model_data.copy() data["id"] = "test-asr-enabled" data["name"] = "Enabled ASR" data["enabled"] = True client.post("/api/asr", json=data) data["id"] = "test-asr-disabled" data["name"] = "Disabled ASR" data["enabled"] = False client.post("/api/asr", json=data) # Filter by enabled response = client.get("/api/asr?enabled=true") assert response.status_code == 200 data = response.json() for model in data["list"]: assert model["enabled"] == True def test_create_asr_model_with_hotwords(self, client): """Test creating an ASR model with hotwords""" data = { "id": "asr-hotwords", "name": "ASR with Hotwords", "vendor": "SiliconFlow", "language": "zh", "base_url": "https://api.siliconflow.cn/v1", "api_key": "test-key", "model_name": "paraformer-v2", "hotwords": ["你好", "查询", "帮助"], "enable_punctuation": True, "enable_normalization": True } response = client.post("/api/asr", json=data) assert response.status_code == 200 result = response.json() assert result["hotwords"] == ["你好", "查询", "帮助"] def test_create_asr_model_with_all_fields(self, client): """Test creating an ASR model with all fields""" data = { "id": "full-asr", "name": "Full ASR Model", "vendor": "SiliconFlow", "language": "zh", "base_url": "https://api.siliconflow.cn/v1", "api_key": "sk-test", "model_name": "paraformer-v2", "hotwords": ["测试"], "enable_punctuation": True, "enable_normalization": True, "enabled": True } response = client.post("/api/asr", json=data) assert response.status_code == 200 result = response.json() assert result["name"] == "Full ASR Model" assert result["enable_punctuation"] == True assert result["enable_normalization"] == True @patch('httpx.Client') def test_test_asr_model_siliconflow(self, mock_client_class, client, sample_asr_model_data): """Test testing an ASR model with SiliconFlow vendor""" # Create model first sample_asr_model_data["vendor"] = "SiliconFlow" create_response = client.post("/api/asr", json=sample_asr_model_data) model_id = create_response.json()["id"] # Mock the HTTP response mock_client = MagicMock() mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = { "results": [{"transcript": "测试文本", "language": "zh"}] } mock_response.raise_for_status = MagicMock() mock_client.get.return_value = mock_response mock_client.__enter__ = MagicMock(return_value=mock_client) mock_client.__exit__ = MagicMock(return_value=False) with patch('app.routers.asr.httpx.Client', return_value=mock_client): response = client.post(f"/api/asr/{model_id}/test") assert response.status_code == 200 data = response.json() assert data["success"] == True @patch('httpx.Client') def test_test_asr_model_openai(self, mock_client_class, client, sample_asr_model_data): """Test testing an ASR model with OpenAI vendor""" # Create model with OpenAI vendor sample_asr_model_data["vendor"] = "OpenAI" sample_asr_model_data["id"] = "test-asr-openai" create_response = client.post("/api/asr", json=sample_asr_model_data) model_id = create_response.json()["id"] # Mock the HTTP response mock_client = MagicMock() mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = {"text": "Test transcript"} mock_response.raise_for_status = MagicMock() mock_client.get.return_value = mock_response mock_client.__enter__ = MagicMock(return_value=mock_client) mock_client.__exit__ = MagicMock(return_value=False) with patch('app.routers.asr.httpx.Client', return_value=mock_client): response = client.post(f"/api/asr/{model_id}/test") assert response.status_code == 200 @patch('httpx.Client') def test_test_asr_model_failure(self, mock_client_class, client, sample_asr_model_data): """Test testing an ASR model with failed connection""" # Create model first create_response = client.post("/api/asr", json=sample_asr_model_data) model_id = create_response.json()["id"] # Mock HTTP error mock_client = MagicMock() mock_response = MagicMock() mock_response.status_code = 401 mock_response.text = "Unauthorized" mock_response.raise_for_status = MagicMock(side_effect=Exception("401 Unauthorized")) mock_client.get.return_value = mock_response mock_client.__enter__ = MagicMock(return_value=mock_client) mock_client.__exit__ = MagicMock(return_value=False) with patch('app.routers.asr.httpx.Client', return_value=mock_client): response = client.post(f"/api/asr/{model_id}/test") assert response.status_code == 200 data = response.json() assert data["success"] == False def test_different_asr_languages(self, client): """Test creating ASR models with different languages""" for lang in ["zh", "en", "Multi-lingual"]: data = { "id": f"asr-lang-{lang}", "name": f"ASR {lang}", "vendor": "SiliconFlow", "language": lang, "base_url": "https://api.siliconflow.cn/v1", "api_key": "test-key" } response = client.post("/api/asr", json=data) assert response.status_code == 200 assert response.json()["language"] == lang def test_different_asr_vendors(self, client): """Test creating ASR models with different vendors""" vendors = ["SiliconFlow", "OpenAI", "Azure"] for vendor in vendors: data = { "id": f"asr-vendor-{vendor.lower()}", "name": f"ASR {vendor}", "vendor": vendor, "language": "zh", "base_url": f"https://api.{vendor.lower()}.com/v1", "api_key": "test-key" } response = client.post("/api/asr", json=data) assert response.status_code == 200 assert response.json()["vendor"] == vendor