Files
AI-VideoAssistant/api/tests/test_voices.py
2026-02-26 03:54:52 +08:00

332 lines
12 KiB
Python

"""Tests for Voice API endpoints"""
import base64
import pytest
def _create_voice(client, payload):
    """POST ``payload`` to ``/api/voices``, assert success and return the JSON body.

    Failing here yields a clear status-code assertion (with the response text)
    instead of a ``KeyError`` on ``["id"]`` later in the calling test.
    """
    response = client.post("/api/voices", json=payload)
    assert response.status_code == 200, response.text
    return response.json()


def _httpx_client_stub(audio_bytes, captured=None):
    """Return a class that stands in for ``httpx.Client`` inside the voice router.

    Every ``post`` call answers with status 200 and ``audio_bytes`` as content.
    When ``captured`` is a dict, each ``post`` records the ``Authorization``
    header under ``"auth"`` and the request URL under ``"url"``.
    """

    class _StubResponse:
        status_code = 200
        content = audio_bytes
        text = "ok"

        def json(self):
            return {}

    class _StubClient:
        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def post(self, *args, **kwargs):
            if captured is not None:
                headers = kwargs.get("headers") or {}
                captured["auth"] = headers.get("Authorization", "")
                captured["url"] = args[0] if args else kwargs.get("url", "")
            return _StubResponse()

    return _StubClient


def _assert_list_filter(client, base_payload, field, wanted, other):
    """Check that ``GET /api/voices?<field>=<wanted>`` returns only matching voices.

    A matching voice and a decoy (``field == other``) are both created. This way
    an ignored filter fails (the decoy is listed) and an over-eager filter fails
    too (the matching voice is missing). ``base_payload`` is never mutated.
    """
    match = _create_voice(
        client, {**base_payload, "name": f"Match {field}", field: wanted}
    )
    _create_voice(client, {**base_payload, "name": f"Decoy {field}", field: other})

    response = client.get(f"/api/voices?{field}={wanted}")
    assert response.status_code == 200
    voices = response.json()["list"]
    assert match["id"] in {voice["id"] for voice in voices}
    for voice in voices:
        assert voice[field] == wanted


class TestVoiceAPI:
    """Test cases for Voice endpoints"""

    def test_get_voices_empty(self, client):
        """Test getting voices when database is empty"""
        response = client.get("/api/voices")
        assert response.status_code == 200
        data = response.json()
        assert "total" in data
        assert "list" in data
        # Each test starts from an empty database (the pagination test relies
        # on this as well), so the listing itself must be empty.
        assert data["total"] == 0
        assert data["list"] == []

    def test_create_voice(self, client, sample_voice_data):
        """Test creating a new voice"""
        data = _create_voice(client, sample_voice_data)
        for field in ("name", "vendor", "gender", "language"):
            assert data[field] == sample_voice_data[field]
        assert "id" in data

    def test_create_voice_minimal(self, client):
        """Test creating a voice with minimal data"""
        payload = {
            "name": "Minimal Voice",
            "vendor": "Test",
            "gender": "Male",
            "language": "en",
            "description": "",
        }
        data = _create_voice(client, payload)
        assert data["name"] == "Minimal Voice"
        assert "id" in data

    def test_get_voice_by_id(self, client, sample_voice_data):
        """Test getting a specific voice by ID"""
        voice_id = _create_voice(client, sample_voice_data)["id"]

        response = client.get(f"/api/voices/{voice_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == voice_id
        assert data["name"] == sample_voice_data["name"]

    def test_get_voice_not_found(self, client):
        """Test getting a non-existent voice"""
        response = client.get("/api/voices/non-existent-id")
        assert response.status_code == 404

    def test_update_voice(self, client, sample_voice_data):
        """Test updating a voice"""
        voice_id = _create_voice(client, sample_voice_data)["id"]

        update_data = {"name": "Updated Voice", "speed": 1.5}
        response = client.put(f"/api/voices/{voice_id}", json=update_data)
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "Updated Voice"
        assert data["speed"] == 1.5

    def test_delete_voice(self, client, sample_voice_data):
        """Test deleting a voice"""
        voice_id = _create_voice(client, sample_voice_data)["id"]

        response = client.delete(f"/api/voices/{voice_id}")
        assert response.status_code == 200

        # The voice must be gone afterwards.
        get_response = client.get(f"/api/voices/{voice_id}")
        assert get_response.status_code == 404

    def test_list_voices_with_pagination(self, client, sample_voice_data):
        """Test listing voices with pagination"""
        # Distinct names per voice; the fixture dict itself is not mutated.
        for i in range(3):
            _create_voice(client, {**sample_voice_data, "name": f"Voice {i}"})

        first = client.get("/api/voices?page=1&limit=2")
        assert first.status_code == 200
        first_page = first.json()
        assert first_page["total"] == 3
        assert len(first_page["list"]) == 2

        # The remaining voice has to show up on the next page.
        second = client.get("/api/voices?page=2&limit=2")
        assert second.status_code == 200
        second_page = second.json()
        assert second_page["total"] == 3
        assert len(second_page["list"]) == 1

    def test_filter_voices_by_vendor(self, client, sample_voice_data):
        """Test filtering voices by vendor"""
        _assert_list_filter(
            client, sample_voice_data, "vendor", "FilterTestVendor", "Test"
        )

    def test_filter_voices_by_language(self, client, sample_voice_data):
        """Test filtering voices by language"""
        _assert_list_filter(client, sample_voice_data, "language", "en", "zh")

    def test_filter_voices_by_gender(self, client, sample_voice_data):
        """Test filtering voices by gender"""
        _assert_list_filter(client, sample_voice_data, "gender", "Female", "Male")

    def test_preview_voice_success(self, client, monkeypatch):
        """Test preview voice endpoint returns audio data URL"""
        from app.routers import voices as voice_router

        monkeypatch.setenv("SILICONFLOW_API_KEY", "test-key")
        monkeypatch.setattr(
            voice_router.httpx, "Client", _httpx_client_stub(b"fake-mp3-bytes")
        )

        created = _create_voice(client, {
            "id": "anna",
            "name": "Anna",
            "vendor": "SiliconFlow",
            "gender": "Female",
            "language": "zh",
            "description": "system voice",
            "model": "FunAudioLLM/CosyVoice2-0.5B",
            "voice_key": "FunAudioLLM/CosyVoice2-0.5B:anna",
        })
        voice_id = created["id"]

        preview_resp = client.post(
            f"/api/voices/{voice_id}/preview", json={"text": "你好"}
        )
        assert preview_resp.status_code == 200
        payload = preview_resp.json()
        assert payload["success"] is True
        assert payload["audio_url"].startswith("data:audio/mpeg;base64,")
        encoded = payload["audio_url"].split(",", 1)[1]
        assert base64.b64decode(encoded) == b"fake-mp3-bytes"

    def test_voice_credential_persist_and_preview_use_voice_key(self, client, monkeypatch):
        """Test per-voice api_key/base_url persisted and used by preview endpoint"""
        from app.routers import voices as voice_router

        captured = {"auth": "", "url": ""}
        # Without the env key, the only credential available is the voice's own.
        monkeypatch.delenv("SILICONFLOW_API_KEY", raising=False)
        monkeypatch.setattr(
            voice_router.httpx, "Client", _httpx_client_stub(b"fake-mp3", captured)
        )

        created = _create_voice(client, {
            "id": "anna2",
            "name": "Anna 2",
            "vendor": "SiliconFlow",
            "gender": "Female",
            "language": "zh",
            "description": "voice",
            "model": "FunAudioLLM/CosyVoice2-0.5B",
            "voice_key": "FunAudioLLM/CosyVoice2-0.5B:anna",
            "api_key": "voice-key-123",
            "base_url": "https://api.siliconflow.cn/v1",
        })
        voice_id = created["id"]

        preview_resp = client.post(
            f"/api/voices/{voice_id}/preview", json={"text": "hello"}
        )
        assert preview_resp.status_code == 200
        assert captured["auth"] == "Bearer voice-key-123"
        assert captured["url"] == "https://api.siliconflow.cn/v1/audio/speech"

    def test_create_voice_dashscope_defaults(self, client):
        """Test creating DashScope voice applies model/voice defaults."""
        payload = _create_voice(client, {
            "name": "DashScope Voice",
            "vendor": "DashScope",
            "gender": "Female",
            "language": "zh",
            "description": "dashscope",
        })
        assert payload["vendor"] == "DashScope"
        assert payload["model"] == "qwen3-tts-flash-realtime"
        assert payload["voice_key"] == "Cherry"

    def test_preview_voice_dashscope_success(self, client, monkeypatch):
        """DashScope voice preview should return playable wav data url."""
        from app.routers import voices as voice_router

        captured = {
            "api_key": "",
            "model": "",
            "url": "",
            "session": {},
            "text": "",
        }
        # Four 16-bit PCM mono samples streamed back by the fake realtime session.
        raw_pcm = b"\x00\x00\x01\x00\x02\x00\x03\x00"

        class DummyAudioFormat:
            PCM_24000HZ_MONO_16BIT = "pcm24k16mono"

        class DummyDashScopeModule:
            api_key = ""

        class DummyRealtime:
            def __init__(self, *args, **kwargs):
                captured["api_key"] = kwargs.get("api_key", "")
                captured["model"] = kwargs.get("model", "")
                captured["url"] = kwargs.get("url", "")
                self.callback = kwargs["callback"]

            def connect(self):
                self.callback.on_open()

            def update_session(self, **kwargs):
                captured["session"] = kwargs

            def append_text(self, text):
                captured["text"] = text

            def commit(self):
                # Emit one audio delta, then signal completion of the response.
                self.callback.on_event({
                    "type": "response.audio.delta",
                    "delta": base64.b64encode(raw_pcm).decode("utf-8"),
                })
                self.callback.on_event({"type": "response.done"})

            def finish(self):
                return None

            def close(self):
                return None

        monkeypatch.setattr(voice_router, "DASHSCOPE_SDK_AVAILABLE", True)
        monkeypatch.setattr(voice_router, "AudioFormat", DummyAudioFormat)
        monkeypatch.setattr(voice_router, "QwenTtsRealtime", DummyRealtime)
        monkeypatch.setattr(voice_router, "dashscope", DummyDashScopeModule())

        created = _create_voice(client, {
            "name": "DashScope Voice",
            "vendor": "DashScope",
            "gender": "Female",
            "language": "zh",
            "description": "dashscope",
            "api_key": "dashscope-key",
            "base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
        })
        voice_id = created["id"]

        preview_resp = client.post(
            f"/api/voices/{voice_id}/preview", json={"text": "你好"}
        )
        assert preview_resp.status_code == 200
        payload = preview_resp.json()
        assert payload["success"] is True
        assert payload["audio_url"].startswith("data:audio/wav;base64,")
        encoded = payload["audio_url"].split(",", 1)[1]
        wav_bytes = base64.b64decode(encoded)
        # A RIFF container whose form type is WAVE, carrying the streamed samples.
        assert wav_bytes.startswith(b"RIFF")
        assert wav_bytes[8:12] == b"WAVE"
        assert raw_pcm in wav_bytes
        assert captured["model"] == "qwen3-tts-flash-realtime"
        assert captured["url"] == "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
        assert captured["text"] == "你好"
        assert captured["session"]["voice"] == "Cherry"