# Files
# AI-VideoAssistant/api/app/routers/voices.py
# 2026-02-12 19:05:50 +08:00
#
# 200 lines
# 6.7 KiB
# Python

import base64
import os
from typing import Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from ..db import get_db
from ..id_generator import unique_short_id
from ..models import Voice
from ..schemas import VoiceCreate, VoiceOut, VoicePreviewRequest, VoicePreviewResponse, VoiceUpdate
# Router for the voice-library endpoints; every route in this module is
# registered under the "/voices" prefix.
router = APIRouter(prefix="/voices", tags=["Voices"])
# Fallback model name used by the handlers below when a voice has no model set.
OPENAI_COMPATIBLE_DEFAULT_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
def _is_openai_compatible_vendor(vendor: str) -> bool:
    """Tell whether *vendor* names an OpenAI-compatible TTS provider.

    The comparison ignores surrounding whitespace and letter case. A falsy
    vendor (``None`` or an empty string) is never considered compatible.
    """
    if not vendor:
        return False
    recognized_names = (
        "openai compatible",
        "openai-compatible",
        "siliconflow",  # backward compatibility
        "硅基流动",  # backward compatibility
    )
    return vendor.strip().lower() in recognized_names
def _default_base_url(vendor: str) -> Optional[str]:
    """Return the built-in API base URL for *vendor*, if one is known.

    Only OpenAI-compatible vendors have a default; any other vendor yields
    ``None`` so the caller can require an explicit base URL.
    """
    return "https://api.siliconflow.cn/v1" if _is_openai_compatible_vendor(vendor) else None
def _build_openai_compatible_voice_key(voice: Voice, model: str) -> str:
    """Work out the voice identifier to send to an OpenAI-compatible API.

    Precedence:
      1. ``voice.voice_key`` when it is set (non-empty).
      2. ``voice.id`` unchanged when it already contains a ``:``.
      3. ``"<model>:<voice.id>"`` otherwise.
    """
    explicit_key = voice.voice_key
    if explicit_key:
        return explicit_key
    identifier = voice.id
    return identifier if ":" in identifier else f"{model}:{identifier}"
@router.get("")
def list_voices(
    vendor: Optional[str] = None,
    language: Optional[str] = None,
    gender: Optional[str] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db),
):
    """List voices in the library, newest first, with optional filters.

    Args:
        vendor: Exact vendor name to filter by; ignored when empty.
        language: Exact language value to filter by; ignored when empty.
        gender: Exact gender value to filter by; ignored when empty.
        page: 1-based page number. Values below 1 are treated as 1.
        limit: Page size. Negative values are treated as 0.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        A dict with ``total`` (matching rows before pagination), the
        effective ``page`` and ``limit``, and ``list`` (the page of voices).
    """
    # Clamp paging inputs: a non-positive page would yield a negative OFFSET
    # and a negative limit a negative LIMIT, which database backends either
    # reject or interpret inconsistently.
    page = max(page, 1)
    limit = max(limit, 0)

    query = db.query(Voice)
    requested_filters = (
        (Voice.vendor, vendor),
        (Voice.language, language),
        (Voice.gender, gender),
    )
    for column, wanted in requested_filters:
        if wanted:
            query = query.filter(column == wanted)

    total = query.count()
    voices = (
        # Voice.id is a tie-breaker so rows sharing a created_at timestamp
        # keep a stable position across pages.
        query.order_by(Voice.created_at.desc(), Voice.id.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": voices}
@router.post("", response_model=VoiceOut)
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
    """Create a voice record and return it.

    The vendor name is stored with surrounding whitespace removed. For
    OpenAI-compatible vendors a missing model falls back to
    ``OPENAI_COMPATIBLE_DEFAULT_MODEL``, and a missing voice key is derived
    from ``data.id`` (or ``data.name``), prefixed with the model unless the
    value already contains a ``:``.
    """
    vendor_name = data.vendor.strip()
    resolved_model = data.model
    resolved_key = data.voice_key

    if _is_openai_compatible_vendor(vendor_name):
        resolved_model = resolved_model or OPENAI_COMPATIBLE_DEFAULT_MODEL
        if not resolved_key:
            seed = (data.id or data.name).strip()
            if ":" in seed:
                resolved_key = seed
            else:
                resolved_key = f"{resolved_model}:{seed}"

    attributes = {
        "id": unique_short_id("tts", db, Voice),
        "user_id": 1,  # owner is fixed to user 1 here
        "name": data.name,
        "vendor": vendor_name,
        "gender": data.gender,
        "language": data.language,
        "description": data.description,
        "model": resolved_model,
        "voice_key": resolved_key,
        "api_key": data.api_key,
        "base_url": data.base_url,
        "speed": data.speed,
        "gain": data.gain,
        "pitch": data.pitch,
        "enabled": data.enabled,
    }
    record = Voice(**attributes)

    db.add(record)
    db.commit()
    # Reload so database-populated columns are present on the returned object.
    db.refresh(record)
    return record
@router.get("/{id}", response_model=VoiceOut)
def get_voice(id: str, db: Session = Depends(get_db)):
    """Return the details of a single voice.

    Raises:
        HTTPException: 404 when no voice has the given id.
    """
    found = db.query(Voice).filter(Voice.id == id).first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Voice not found")
@router.put("/{id}", response_model=VoiceOut)
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
    """Apply a partial update to a voice and return the updated record.

    Only fields explicitly present in the request are written. When the
    effective vendor (the new one if supplied, else the stored one) is
    OpenAI-compatible, ``model`` and ``voice_key`` are always filled in so the
    record never ends up without them.

    Raises:
        HTTPException: 404 when no voice has the given id.
    """
    record = db.query(Voice).filter(Voice.id == id).first()
    if not record:
        raise HTTPException(status_code=404, detail="Voice not found")

    changes = data.model_dump(exclude_unset=True)
    if changes.get("vendor") is not None:
        changes["vendor"] = changes["vendor"].strip()

    effective_vendor = changes.get("vendor", record.vendor)
    if _is_openai_compatible_vendor(effective_vendor):
        # Prefer the incoming value, then the stored one, then a default.
        resolved_model = changes.get("model") or record.model or OPENAI_COMPATIBLE_DEFAULT_MODEL
        resolved_key = changes.get("voice_key") or record.voice_key
        if not resolved_key:
            resolved_key = _build_openai_compatible_voice_key(record, resolved_model)
        changes.update(model=resolved_model, voice_key=resolved_key)

    for attribute, new_value in changes.items():
        setattr(record, attribute, new_value)

    db.commit()
    db.refresh(record)
    return record
@router.delete("/{id}")
def delete_voice(id: str, db: Session = Depends(get_db)):
    """Delete a voice and return a confirmation message.

    Raises:
        HTTPException: 404 when no voice has the given id.
    """
    target = db.query(Voice).filter(Voice.id == id).first()
    if target:
        db.delete(target)
        db.commit()
        return {"message": "Deleted successfully"}
    raise HTTPException(status_code=404, detail="Voice not found")
def _extract_vendor_error_detail(response: httpx.Response) -> str:
    """Return the most specific error message found in a failed TTS response.

    Looks, in order, at ``error.message`` (or ``error`` itself when it is a
    plain string), ``detail`` and a top-level ``message`` in a JSON object
    body. Falls back to the raw response text when the body is not a JSON
    object or none of those fields carries a value.
    """
    raw_text = response.text
    try:
        body = response.json()
    except ValueError:
        # Body is not valid JSON; the raw text is all there is to report.
        return raw_text
    if not isinstance(body, dict):
        return raw_text

    error_field = body.get("error")
    if isinstance(error_field, dict):
        error_message = error_field.get("message")
    elif isinstance(error_field, str):
        error_message = error_field
    else:
        error_message = None

    for candidate in (error_message, body.get("detail"), body.get("message")):
        if candidate:
            return str(candidate)
    return raw_text


@router.post("/{id}/preview", response_model=VoicePreviewResponse)
def preview_voice(id: str, data: VoicePreviewRequest, db: Session = Depends(get_db)):
    """Synthesize a preview of a voice via an OpenAI-compatible /audio/speech API.

    The API key is taken from the request, then from the stored voice, then
    (for OpenAI-compatible vendors only) from the ``SILICONFLOW_API_KEY``
    environment variable. The base URL is the stored one or the vendor default.

    Returns:
        A ``VoicePreviewResponse`` whose ``audio_url`` is a base64 ``data:`` URL
        holding the MP3 audio.

    Raises:
        HTTPException: 404 when the voice does not exist; 400 when the text is
            empty or no API key / base URL can be resolved; 502 when the
            request to the vendor fails, the vendor answers with a non-200
            status, or the vendor returns no audio data.
    """
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")

    text = data.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Preview text cannot be empty")

    api_key = (data.api_key or "").strip() or (voice.api_key or "").strip()
    if not api_key and _is_openai_compatible_vendor(voice.vendor):
        api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
    if not api_key:
        raise HTTPException(status_code=400, detail=f"API key is required for voice: {voice.name}")

    base_url = (voice.base_url or "").strip() or (_default_base_url(voice.vendor) or "")
    if not base_url:
        raise HTTPException(status_code=400, detail=f"Base URL is required for voice: {voice.name}")

    model = voice.model or OPENAI_COMPATIBLE_DEFAULT_MODEL
    payload = {
        "model": model,
        "input": text,
        "voice": voice.voice_key or _build_openai_compatible_voice_key(voice, model),
        "response_format": "mp3",
    }
    # Only send "speed" when a value exists; otherwise an explicit JSON null
    # would be posted instead of letting the vendor apply its own default.
    speed = data.speed if data.speed is not None else voice.speed
    if speed is not None:
        payload["speed"] = speed

    try:
        with httpx.Client(timeout=45.0) as client:
            response = client.post(
                f"{base_url.rstrip('/')}/audio/speech",
                headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
                json=payload,
            )
    except Exception as exc:
        # Deliberately broad: any failure while building or sending the
        # request is reported to the caller as a bad-gateway error.
        raise HTTPException(status_code=502, detail=f"TTS request failed: {exc}") from exc

    if response.status_code != 200:
        detail = _extract_vendor_error_detail(response)
        raise HTTPException(status_code=502, detail=f"TTS vendor error: {detail}")

    audio_bytes = response.content
    if not audio_bytes:
        # A 200 with no body would otherwise be returned as a "successful"
        # preview containing an empty data URL.
        raise HTTPException(status_code=502, detail="TTS vendor error: empty audio response")

    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    return VoicePreviewResponse(success=True, audio_url=f"data:audio/mpeg;base64,{audio_base64}")