241 lines
8.5 KiB
Python
241 lines
8.5 KiB
Python
"""Tests for Voice API endpoints"""
|
|
import base64
|
|
import pytest
|
|
|
|
|
|
class TestVoiceAPI:
    """Test cases for the Voice endpoints under ``/api/voices``.

    Every test receives ``client`` (and, where needed, ``sample_voice_data``
    and pytest's ``monkeypatch``) as fixtures. ``client`` and
    ``sample_voice_data`` are not defined in this file -- presumably they
    live in a shared ``conftest.py``; confirm there for their exact shape.
    """

    # ------------------------------------------------------------------
    # Private helpers (underscore-prefixed, so pytest does not collect them)
    # ------------------------------------------------------------------

    @staticmethod
    def _create_voice(client, payload):
        """POST ``payload`` to ``/api/voices`` and return the decoded JSON body.

        The status code is asserted here so that a failed setup step is
        reported as a failed create, not as a ``KeyError`` on a later
        ``["id"]`` lookup or as a misleading downstream assertion.
        """
        response = client.post("/api/voices", json=payload)
        assert response.status_code == 200
        return response.json()

    @staticmethod
    def _stub_httpx_client(monkeypatch, audio_bytes):
        """Swap ``httpx.Client`` in the voices router for an in-memory fake.

        The fake supports the context-manager protocol and a ``post`` method
        that always answers with status 200 and ``audio_bytes`` as content.

        Args:
            monkeypatch: pytest's ``monkeypatch`` fixture; the patch is
                undone automatically at the end of the test.
            audio_bytes: Bytes exposed as the fake response's ``content``.

        Returns:
            A dict whose ``"authorization"`` entry holds the ``Authorization``
            header of the most recent ``post`` call (``""`` if none was sent).
        """
        # Imported lazily, as the original tests did, so the router module is
        # only touched once the test is actually running.
        from app.routers import voices as voice_router

        captured = {"authorization": ""}

        class FakeResponse:
            status_code = 200
            content = audio_bytes
            text = "ok"

            def json(self):
                return {}

        class FakeClient:
            def __init__(self, *args, **kwargs):
                pass

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                # Returning False lets exceptions raised inside the ``with``
                # body propagate to the caller.
                return False

            def post(self, *args, **kwargs):
                headers = kwargs.get("headers") or {}
                captured["authorization"] = headers.get("Authorization", "")
                return FakeResponse()

        monkeypatch.setattr(voice_router.httpx, "Client", FakeClient)
        return captured

    # ------------------------------------------------------------------
    # CRUD
    # ------------------------------------------------------------------

    def test_get_voices_empty(self, client):
        """Listing voices returns the ``total``/``list`` envelope."""
        response = client.get("/api/voices")
        assert response.status_code == 200
        data = response.json()
        assert "total" in data
        assert "list" in data

    def test_create_voice(self, client, sample_voice_data):
        """Creating a voice echoes the submitted fields and assigns an id."""
        response = client.post("/api/voices", json=sample_voice_data)
        assert response.status_code == 200
        data = response.json()
        for field in ("name", "vendor", "gender", "language"):
            assert data[field] == sample_voice_data[field]
        assert "id" in data

    def test_create_voice_minimal(self, client):
        """Creating a voice with only the basic fields succeeds."""
        payload = {
            "name": "Minimal Voice",
            "vendor": "Test",
            "gender": "Male",
            "language": "en",
            "description": "",
        }
        response = client.post("/api/voices", json=payload)
        assert response.status_code == 200

    def test_get_voice_by_id(self, client, sample_voice_data):
        """A created voice can be fetched back by its id."""
        voice_id = self._create_voice(client, sample_voice_data)["id"]

        response = client.get(f"/api/voices/{voice_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == voice_id
        assert data["name"] == sample_voice_data["name"]

    def test_get_voice_not_found(self, client):
        """Fetching an unknown id yields 404."""
        response = client.get("/api/voices/non-existent-id")
        assert response.status_code == 404

    def test_update_voice(self, client, sample_voice_data):
        """A PUT with partial data updates the given fields."""
        voice_id = self._create_voice(client, sample_voice_data)["id"]

        update_data = {"name": "Updated Voice", "speed": 1.5}
        response = client.put(f"/api/voices/{voice_id}", json=update_data)
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "Updated Voice"
        assert data["speed"] == 1.5

    def test_delete_voice(self, client, sample_voice_data):
        """A deleted voice is no longer retrievable."""
        voice_id = self._create_voice(client, sample_voice_data)["id"]

        response = client.delete(f"/api/voices/{voice_id}")
        assert response.status_code == 200

        # Verify the voice is really gone.
        get_response = client.get(f"/api/voices/{voice_id}")
        assert get_response.status_code == 404

    # ------------------------------------------------------------------
    # Listing: pagination and filters
    # ------------------------------------------------------------------

    def test_list_voices_with_pagination(self, client, sample_voice_data):
        """``page``/``limit`` cap the returned list while ``total`` counts all."""
        for i in range(3):
            self._create_voice(client, {**sample_voice_data, "name": f"Voice {i}"})

        response = client.get("/api/voices?page=1&limit=2")
        assert response.status_code == 200
        data = response.json()
        # NOTE(review): ``total == 3`` relies on the database starting empty
        # for each test -- confirm against the ``client`` fixture.
        assert data["total"] == 3
        assert len(data["list"]) == 2

    def test_filter_voices_by_vendor(self, client, sample_voice_data):
        """Filtering by ``vendor`` returns only voices of that vendor."""
        # Build a new payload instead of mutating the shared fixture dict.
        self._create_voice(client, {**sample_voice_data, "vendor": "FilterTestVendor"})

        response = client.get("/api/voices?vendor=FilterTestVendor")
        assert response.status_code == 200
        voices = response.json()["list"]
        # Without this check the loop below passes vacuously on an empty list.
        assert voices, "vendor filter returned no voices"
        for voice in voices:
            assert voice["vendor"] == "FilterTestVendor"

    def test_filter_voices_by_language(self, client, sample_voice_data):
        """Filtering by ``language`` returns only voices in that language."""
        self._create_voice(client, {**sample_voice_data, "language": "en"})

        response = client.get("/api/voices?language=en")
        assert response.status_code == 200
        voices = response.json()["list"]
        assert voices, "language filter returned no voices"
        for voice in voices:
            assert voice["language"] == "en"

    def test_filter_voices_by_gender(self, client, sample_voice_data):
        """Filtering by ``gender`` returns only voices of that gender."""
        self._create_voice(client, {**sample_voice_data, "gender": "Female"})

        response = client.get("/api/voices?gender=Female")
        assert response.status_code == 200
        voices = response.json()["list"]
        assert voices, "gender filter returned no voices"
        for voice in voices:
            assert voice["gender"] == "Female"

    # ------------------------------------------------------------------
    # Preview (the outbound httpx call is stubbed)
    # ------------------------------------------------------------------

    def test_preview_voice_success(self, client, monkeypatch):
        """The preview endpoint returns the audio as a base64 data URL."""
        self._stub_httpx_client(monkeypatch, b"fake-mp3-bytes")
        monkeypatch.setenv("SILICONFLOW_API_KEY", "test-key")

        self._create_voice(client, {
            "id": "anna",
            "name": "Anna",
            "vendor": "SiliconFlow",
            "gender": "Female",
            "language": "zh",
            "description": "system voice",
            "model": "FunAudioLLM/CosyVoice2-0.5B",
            "voice_key": "FunAudioLLM/CosyVoice2-0.5B:anna",
        })

        preview_resp = client.post("/api/voices/anna/preview", json={"text": "你好"})
        assert preview_resp.status_code == 200
        payload = preview_resp.json()
        assert payload["success"] is True
        assert payload["audio_url"].startswith("data:audio/mpeg;base64,")
        # Everything after the first comma is the base64 payload.
        encoded = payload["audio_url"].split(",", 1)[1]
        assert base64.b64decode(encoded) == b"fake-mp3-bytes"

    def test_vendor_credential_persist_and_preview_use_db_key(self, client, monkeypatch):
        """A vendor credential saved through the API is used by the preview call."""
        captured = self._stub_httpx_client(monkeypatch, b"fake-mp3")
        # Remove the env key so the only key available is the one saved below.
        monkeypatch.delenv("SILICONFLOW_API_KEY", raising=False)

        save_cred = client.put(
            "/api/voices/vendors/credentials/siliconflow",
            json={
                "vendor_name": "SiliconFlow",
                "api_key": "db-key-123",
                "base_url": "https://api.siliconflow.cn/v1",
            },
        )
        assert save_cred.status_code == 200
        assert save_cred.json()["vendor_key"] == "siliconflow"

        self._create_voice(client, {
            "id": "anna2",
            "name": "Anna 2",
            "vendor": "SiliconFlow",
            "gender": "Female",
            "language": "zh",
            "description": "voice",
            "model": "FunAudioLLM/CosyVoice2-0.5B",
            "voice_key": "FunAudioLLM/CosyVoice2-0.5B:anna",
        })

        preview_resp = client.post("/api/voices/anna2/preview", json={"text": "hello"})
        assert preview_resp.status_code == 200
        assert captured["authorization"] == "Bearer db-key-123"
|