Files
AI-VideoAssistant/api/tests/test_asr.py
2026-02-08 15:52:16 +08:00

290 lines
11 KiB
Python

"""Tests for ASR Model API endpoints"""
import pytest
from unittest.mock import patch, MagicMock
class TestASRModelAPI:
    """Test cases for the ASR model endpoints under ``/api/asr``.

    The ``client`` and ``sample_asr_model_data`` fixtures are defined outside
    this file (presumably in ``conftest.py`` -- confirm). ``client`` issues
    HTTP requests against the app; ``sample_asr_model_data`` is a dict payload
    accepted by ``POST /api/asr``. Tests that need variations of that payload
    build a copy instead of mutating the fixture in place.
    """

    @staticmethod
    def _mock_http_client(status_code=200, json_body=None, text=None, raise_exc=None):
        """Build a MagicMock standing in for an ``httpx.Client`` instance.

        The mock works as a context manager (``__enter__`` returns the mock
        itself) and returns one canned response from ``.get(...)``.

        Args:
            status_code: Value exposed as ``response.status_code``.
            json_body: If not None, the value returned by ``response.json()``.
            text: If not None, the value exposed as ``response.text``.
            raise_exc: If not None, ``response.raise_for_status()`` raises it.

        Returns:
            The configured client mock, suitable as ``patch(..., return_value=...)``.
        """
        mock_response = MagicMock()
        mock_response.status_code = status_code
        # Only override attributes a test asks for; anything else stays a
        # plain MagicMock attribute.
        if json_body is not None:
            mock_response.json.return_value = json_body
        if text is not None:
            mock_response.text = text
        mock_response.raise_for_status = MagicMock(side_effect=raise_exc)

        mock_client = MagicMock()
        mock_client.get.return_value = mock_response
        mock_client.__enter__ = MagicMock(return_value=mock_client)
        mock_client.__exit__ = MagicMock(return_value=False)
        return mock_client

    def test_get_asr_models_empty(self, client):
        """Test getting ASR models when database is empty"""
        response = client.get("/api/asr")
        assert response.status_code == 200
        data = response.json()
        assert "total" in data
        assert "list" in data
        assert data["total"] == 0
        assert data["list"] == []

    def test_create_asr_model(self, client, sample_asr_model_data):
        """Test creating a new ASR model"""
        response = client.post("/api/asr", json=sample_asr_model_data)
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == sample_asr_model_data["name"]
        assert data["vendor"] == sample_asr_model_data["vendor"]
        assert data["language"] == sample_asr_model_data["language"]
        assert "id" in data

    def test_create_asr_model_minimal(self, client):
        """Test creating an ASR model with minimal required data"""
        payload = {
            "name": "Minimal ASR",
            "vendor": "Test",
            "language": "zh",
            "base_url": "https://api.test.com",
            "api_key": "test-key",
        }
        response = client.post("/api/asr", json=payload)
        assert response.status_code == 200
        assert response.json()["name"] == "Minimal ASR"

    def test_get_asr_model_by_id(self, client, sample_asr_model_data):
        """Test getting a specific ASR model by ID"""
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        assert create_response.status_code == 200
        model_id = create_response.json()["id"]

        response = client.get(f"/api/asr/{model_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == model_id
        assert data["name"] == sample_asr_model_data["name"]

    def test_get_asr_model_not_found(self, client):
        """Test getting a non-existent ASR model"""
        response = client.get("/api/asr/non-existent-id")
        assert response.status_code == 404

    def test_update_asr_model(self, client, sample_asr_model_data):
        """Test updating an ASR model"""
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        assert create_response.status_code == 200
        model_id = create_response.json()["id"]

        update_data = {
            "name": "Updated ASR Model",
            "language": "en",
            "enable_punctuation": False,
        }
        response = client.put(f"/api/asr/{model_id}", json=update_data)
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "Updated ASR Model"
        assert data["language"] == "en"
        assert data["enable_punctuation"] is False

    def test_delete_asr_model(self, client, sample_asr_model_data):
        """Test deleting an ASR model"""
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        assert create_response.status_code == 200
        model_id = create_response.json()["id"]

        response = client.delete(f"/api/asr/{model_id}")
        assert response.status_code == 200

        # The deleted model must no longer be retrievable.
        get_response = client.get(f"/api/asr/{model_id}")
        assert get_response.status_code == 404

    def test_list_asr_models_with_pagination(self, client, sample_asr_model_data):
        """Test listing ASR models with pagination"""
        for i in range(3):
            payload = {
                **sample_asr_model_data,
                "id": f"test-asr-{i}",
                "name": f"ASR Model {i}",
            }
            assert client.post("/api/asr", json=payload).status_code == 200

        response = client.get("/api/asr?page=1&limit=2")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 3
        assert len(data["list"]) == 2

    def test_filter_asr_models_by_language(self, client, sample_asr_model_data):
        """Test filtering ASR models by language"""
        for lang in ["zh", "en", "Multi-lingual"]:
            payload = {
                **sample_asr_model_data,
                "id": f"test-asr-{lang}",
                "name": f"ASR {lang}",
                "language": lang,
            }
            assert client.post("/api/asr", json=payload).status_code == 200

        response = client.get("/api/asr?language=zh")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] >= 1
        for model in data["list"]:
            assert model["language"] == "zh"

    def test_filter_asr_models_by_enabled(self, client, sample_asr_model_data):
        """Test filtering ASR models by enabled status"""
        for model_id, name, enabled in (
            ("test-asr-enabled", "Enabled ASR", True),
            ("test-asr-disabled", "Disabled ASR", False),
        ):
            payload = {
                **sample_asr_model_data,
                "id": model_id,
                "name": name,
                "enabled": enabled,
            }
            assert client.post("/api/asr", json=payload).status_code == 200

        response = client.get("/api/asr?enabled=true")
        assert response.status_code == 200
        models = response.json()["list"]
        # Without this, the loop below would pass vacuously on an empty list.
        assert any(model["name"] == "Enabled ASR" for model in models)
        for model in models:
            assert model["enabled"] is True

    def test_create_asr_model_with_hotwords(self, client):
        """Test creating an ASR model with hotwords"""
        payload = {
            "id": "asr-hotwords",
            "name": "ASR with Hotwords",
            "vendor": "SiliconFlow",
            "language": "zh",
            "base_url": "https://api.siliconflow.cn/v1",
            "api_key": "test-key",
            "model_name": "paraformer-v2",
            "hotwords": ["你好", "查询", "帮助"],
            "enable_punctuation": True,
            "enable_normalization": True,
        }
        response = client.post("/api/asr", json=payload)
        assert response.status_code == 200
        result = response.json()
        assert result["hotwords"] == ["你好", "查询", "帮助"]

    def test_create_asr_model_with_all_fields(self, client):
        """Test creating an ASR model with all fields"""
        payload = {
            "id": "full-asr",
            "name": "Full ASR Model",
            "vendor": "SiliconFlow",
            "language": "zh",
            "base_url": "https://api.siliconflow.cn/v1",
            "api_key": "sk-test",
            "model_name": "paraformer-v2",
            "hotwords": ["测试"],
            "enable_punctuation": True,
            "enable_normalization": True,
            "enabled": True,
        }
        response = client.post("/api/asr", json=payload)
        assert response.status_code == 200
        result = response.json()
        assert result["name"] == "Full ASR Model"
        assert result["enable_punctuation"] is True
        assert result["enable_normalization"] is True

    def test_test_asr_model_siliconflow(self, client, sample_asr_model_data):
        """Test testing an ASR model with SiliconFlow vendor"""
        # Copy rather than mutate the shared fixture payload.
        payload = {**sample_asr_model_data, "vendor": "SiliconFlow"}
        create_response = client.post("/api/asr", json=payload)
        assert create_response.status_code == 200
        model_id = create_response.json()["id"]

        mock_client = self._mock_http_client(
            status_code=200,
            json_body={"results": [{"transcript": "测试文本", "language": "zh"}]},
        )
        # Only the outbound HTTP client used by the router is replaced, and
        # only while the request under test runs.
        with patch('app.routers.asr.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/asr/{model_id}/test")

        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True

    def test_test_asr_model_openai(self, client, sample_asr_model_data):
        """Test testing an ASR model with OpenAI vendor"""
        # Copy rather than mutate the shared fixture payload.
        payload = {**sample_asr_model_data, "vendor": "OpenAI", "id": "test-asr-openai"}
        create_response = client.post("/api/asr", json=payload)
        assert create_response.status_code == 200
        model_id = create_response.json()["id"]

        mock_client = self._mock_http_client(
            status_code=200,
            json_body={"text": "Test transcript"},
        )
        with patch('app.routers.asr.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/asr/{model_id}/test")

        # NOTE(review): only `.get` is stubbed on the mock client. If the
        # OpenAI branch in app.routers.asr uses another request method, it
        # receives bare MagicMocks -- confirm before asserting on the body.
        assert response.status_code == 200

    def test_test_asr_model_failure(self, client, sample_asr_model_data):
        """Test testing an ASR model with failed connection"""
        create_response = client.post("/api/asr", json=sample_asr_model_data)
        assert create_response.status_code == 200
        model_id = create_response.json()["id"]

        # Simulate an upstream 401: raise_for_status() raises.
        mock_client = self._mock_http_client(
            status_code=401,
            text="Unauthorized",
            raise_exc=Exception("401 Unauthorized"),
        )
        with patch('app.routers.asr.httpx.Client', return_value=mock_client):
            response = client.post(f"/api/asr/{model_id}/test")

        # The endpoint reports the failure in the body rather than via status.
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is False

    def test_different_asr_languages(self, client):
        """Test creating ASR models with different languages"""
        for lang in ["zh", "en", "Multi-lingual"]:
            payload = {
                "id": f"asr-lang-{lang}",
                "name": f"ASR {lang}",
                "vendor": "SiliconFlow",
                "language": lang,
                "base_url": "https://api.siliconflow.cn/v1",
                "api_key": "test-key",
            }
            response = client.post("/api/asr", json=payload)
            assert response.status_code == 200
            assert response.json()["language"] == lang

    def test_different_asr_vendors(self, client):
        """Test creating ASR models with different vendors"""
        for vendor in ["SiliconFlow", "OpenAI", "Azure"]:
            payload = {
                "id": f"asr-vendor-{vendor.lower()}",
                "name": f"ASR {vendor}",
                "vendor": vendor,
                "language": "zh",
                "base_url": f"https://api.{vendor.lower()}.com/v1",
                "api_key": "test-key",
            }
            response = client.post("/api/asr", json=payload)
            assert response.status_code == 200
            assert response.json()["vendor"] == vendor