# File: AI-VideoAssistant/api/app/routers/voices.py
# Last modified: 2026-02-08 23:16:21 +08:00
# Size: 194 lines, 6.4 KiB
# Language: Python

import base64
import os
import uuid
from typing import Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from ..db import get_db
from ..models import Voice
from ..schemas import VoiceCreate, VoiceOut, VoicePreviewRequest, VoicePreviewResponse, VoiceUpdate
# Router collecting the voice endpoints defined in this module; all routes
# registered on it share the "/voices" prefix and the "Voices" tag.
router = APIRouter(prefix="/voices", tags=["Voices"])
# Fallback TTS model name used below whenever a voice has no model set.
SILICONFLOW_DEFAULT_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
def _is_siliconflow_vendor(vendor: Optional[str]) -> bool:
    """Return True when *vendor* names the SiliconFlow provider.

    The comparison ignores surrounding whitespace and letter case, and also
    accepts the provider's Chinese name.

    Args:
        vendor: Vendor name as stored on a voice or sent by a client. May be
            None or empty, in which case the result is False.

    Returns:
        True for "siliconflow" / "硅基流动" (after normalisation), else False.
    """
    # A missing vendor is simply "not SiliconFlow"; calling .strip() on None
    # would otherwise surface as an unhandled AttributeError in the endpoints.
    if not vendor:
        return False
    return vendor.strip().lower() in {"siliconflow", "硅基流动"}
def _default_base_url(vendor: str) -> Optional[str]:
    """Return the built-in API base URL for *vendor*, or None if there is none.

    SiliconFlow is the only vendor with a default endpoint; every other
    vendor yields None.
    """
    return "https://api.siliconflow.cn/v1" if _is_siliconflow_vendor(vendor) else None
def _build_siliconflow_voice_key(voice: Voice, model: str) -> str:
    """Work out the voice identifier to send to the TTS API for *voice*.

    Resolution order:
      1. the voice's own ``voice_key`` when it is set;
      2. the voice ``id`` unchanged when it already contains a ":";
      3. otherwise ``"<model>:<id>"``.
    """
    explicit_key = voice.voice_key
    if explicit_key:
        return explicit_key

    identifier = voice.id
    return identifier if ":" in identifier else f"{model}:{identifier}"
@router.get("")
def list_voices(
    vendor: Optional[str] = None,
    language: Optional[str] = None,
    gender: Optional[str] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db),
):
    """List voices from the voice library, newest first, one page at a time.

    Args:
        vendor: When given, keep only voices whose vendor equals this value.
        language: When given, keep only voices with this exact language.
        gender: When given, keep only voices with this exact gender.
        page: 1-based page number; values below 1 are treated as 1.
        limit: Page size; negative values are treated as 0.
        db: Database session injected by FastAPI.

    Returns:
        A dict with ``total`` (number of rows matching the filters), the
        effective ``page`` and ``limit``, and ``list`` (the voices on that page).
    """
    # Clamp the pagination inputs: page <= 0 would produce a negative OFFSET
    # and a negative limit a negative LIMIT, which database backends either
    # reject outright or interpret as "no limit".
    page = max(page, 1)
    limit = max(limit, 0)

    query = db.query(Voice)
    if vendor:
        query = query.filter(Voice.vendor == vendor)
    if language:
        query = query.filter(Voice.language == language)
    if gender:
        query = query.filter(Voice.gender == gender)

    # Count before paginating so "total" reflects every matching row.
    total = query.count()
    voices = (
        query.order_by(Voice.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": voices}
@router.post("", response_model=VoiceOut)
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
    """Create a voice record and return the persisted row.

    For SiliconFlow voices, a missing model falls back to
    ``SILICONFLOW_DEFAULT_MODEL`` and a missing ``voice_key`` is derived from
    the supplied id (or, failing that, the name): used as-is when it already
    contains a ":", otherwise prefixed with ``"<model>:"``.
    When no id is supplied, the first 8 characters of a random UUID are used.
    """
    vendor_name = data.vendor.strip()
    model_name = data.model
    key = data.voice_key

    if _is_siliconflow_vendor(vendor_name):
        if not model_name:
            model_name = SILICONFLOW_DEFAULT_MODEL
        if not key:
            base_id = (data.id or data.name).strip()
            key = base_id if ":" in base_id else f"{model_name}:{base_id}"

    record_id = data.id or str(uuid.uuid4())[:8]
    fields = {
        "id": record_id,
        "user_id": 1,
        "name": data.name,
        "vendor": vendor_name,
        "gender": data.gender,
        "language": data.language,
        "description": data.description,
        "model": model_name,
        "voice_key": key,
        "api_key": data.api_key,
        "base_url": data.base_url,
        "speed": data.speed,
        "gain": data.gain,
        "pitch": data.pitch,
        "enabled": data.enabled,
    }
    record = Voice(**fields)

    db.add(record)
    db.commit()
    # Reload so the returned object carries any values the database filled in.
    db.refresh(record)
    return record
@router.get("/{id}", response_model=VoiceOut)
def get_voice(id: str, db: Session = Depends(get_db)):
    """Return the voice whose id matches the path parameter.

    Raises:
        HTTPException: 404 when no voice has that id.
    """
    match = db.query(Voice).filter(Voice.id == id).first()
    if match:
        return match
    raise HTTPException(status_code=404, detail="Voice not found")
@router.put("/{id}", response_model=VoiceOut)
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
    """Apply a partial update to a voice and return the updated row.

    Only fields the client actually sent are written. When the resulting
    vendor is SiliconFlow, ``model`` and ``voice_key`` are always filled in:
    the model falls back to the stored one and then to
    ``SILICONFLOW_DEFAULT_MODEL``; the voice key falls back to the stored one
    and then to a key derived from the voice id.

    Args:
        id: Id of the voice to update (path parameter).
        data: Fields to change.
        db: Database session injected by FastAPI.

    Raises:
        HTTPException: 404 when no voice has that id.
    """
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")

    update_data = data.model_dump(exclude_unset=True)

    # exclude_unset still keeps fields the client explicitly sent as null.
    # A null vendor cannot be normalised or classified (it previously crashed
    # the vendor check below), so treat it as "vendor not provided" and keep
    # the stored value.
    if update_data.get("vendor") is None:
        update_data.pop("vendor", None)
    else:
        update_data["vendor"] = update_data["vendor"].strip()

    # Decide on defaults using the vendor the voice will have after the update.
    vendor_for_defaults = update_data.get("vendor", voice.vendor)
    if vendor_for_defaults and _is_siliconflow_vendor(vendor_for_defaults):
        model = update_data.get("model") or voice.model or SILICONFLOW_DEFAULT_MODEL
        voice_key = update_data.get("voice_key") or voice.voice_key
        update_data["model"] = model
        update_data["voice_key"] = voice_key or _build_siliconflow_voice_key(voice, model)

    for field, value in update_data.items():
        setattr(voice, field, value)

    db.commit()
    db.refresh(voice)
    return voice
@router.delete("/{id}")
def delete_voice(id: str, db: Session = Depends(get_db)):
    """Delete the voice with the given id and confirm with a message.

    Raises:
        HTTPException: 404 when no voice has that id.
    """
    lookup = db.query(Voice).filter(Voice.id == id)
    doomed = lookup.first()
    if not doomed:
        raise HTTPException(status_code=404, detail="Voice not found")

    db.delete(doomed)
    db.commit()
    return {"message": "Deleted successfully"}
def _extract_vendor_error(response: "httpx.Response") -> str:
    """Return the most useful error description from a failed vendor response.

    Looks for ``error.message``, a string ``error``, ``detail`` and a
    top-level ``message`` in a JSON object body, in that order, and falls
    back to the raw body text when the body is not JSON or none is present.
    """
    try:
        body = response.json()
    except ValueError:
        # Body is not valid JSON (for example an HTML gateway error page).
        return response.text

    if isinstance(body, dict):
        error = body.get("error")
        if isinstance(error, dict):
            error = error.get("message")
        for candidate in (error, body.get("detail"), body.get("message")):
            if candidate:
                return str(candidate)
    return response.text


@router.post("/{id}/preview", response_model=VoicePreviewResponse)
def preview_voice(id: str, data: VoicePreviewRequest, db: Session = Depends(get_db)):
    """Synthesize a short preview of a voice via an OpenAI-compatible
    ``/audio/speech`` endpoint and return it as a base64 MP3 data URL.

    The API key is taken from the request, then the voice, then (SiliconFlow
    only) the ``SILICONFLOW_API_KEY`` environment variable. The base URL is
    taken from the voice, then the vendor's built-in default.

    Args:
        id: Id of the voice to preview (path parameter).
        data: Preview text plus optional API key and speed overrides.
        db: Database session injected by FastAPI.

    Returns:
        A ``VoicePreviewResponse`` whose ``audio_url`` is a
        ``data:audio/mpeg;base64,...`` URL.

    Raises:
        HTTPException: 404 if the voice does not exist; 400 if the text is
            blank or no API key / base URL can be resolved; 502 if the vendor
            request fails or returns a non-200 status.
    """
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")

    text = data.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Preview text cannot be empty")

    # Normalise a null vendor so the vendor helpers never receive None.
    vendor = voice.vendor or ""

    api_key = (data.api_key or "").strip() or (voice.api_key or "").strip()
    if not api_key and _is_siliconflow_vendor(vendor):
        api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
    if not api_key:
        raise HTTPException(status_code=400, detail=f"API key is required for voice: {voice.name}")

    base_url = (voice.base_url or "").strip() or (_default_base_url(vendor) or "")
    if not base_url:
        raise HTTPException(status_code=400, detail=f"Base URL is required for voice: {voice.name}")

    model = voice.model or SILICONFLOW_DEFAULT_MODEL
    payload = {
        "model": model,
        "input": text,
        "voice": voice.voice_key or _build_siliconflow_voice_key(voice, model),
        "response_format": "mp3",
    }
    # Only send a speed when one is actually configured; an explicit JSON null
    # may be rejected by the vendor, whereas an absent field uses its default.
    speed = data.speed if data.speed is not None else voice.speed
    if speed is not None:
        payload["speed"] = speed

    try:
        with httpx.Client(timeout=45.0) as client:
            response = client.post(
                f"{base_url.rstrip('/')}/audio/speech",
                headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
                json=payload,
            )
    except Exception as exc:
        # Boundary with the external service: any failure to complete the
        # request (network, timeout, malformed URL) is reported as 502.
        raise HTTPException(status_code=502, detail=f"TTS request failed: {exc}") from exc

    if response.status_code != 200:
        detail = _extract_vendor_error(response)
        raise HTTPException(status_code=502, detail=f"TTS vendor error: {detail}")

    audio_base64 = base64.b64encode(response.content).decode("utf-8")
    return VoicePreviewResponse(success=True, audio_url=f"data:audio/mpeg;base64,{audio_base64}")