274 lines
9.1 KiB
Python
274 lines
9.1 KiB
Python
import base64
|
|
import os
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from sqlalchemy.orm import Session
|
|
|
|
from ..db import get_db
|
|
from ..models import VendorCredential, Voice
|
|
from ..schemas import (
|
|
VendorCredentialOut,
|
|
VendorCredentialUpsert,
|
|
VoiceCreate,
|
|
VoiceOut,
|
|
VoicePreviewRequest,
|
|
VoicePreviewResponse,
|
|
VoiceUpdate,
|
|
)
|
|
|
|
# Every endpoint in this module is mounted under /voices and grouped as "Voices" in the OpenAPI docs.
router = APIRouter(tags=["Voices"], prefix="/voices")

# Fallback TTS model used when a voice record does not specify one.
SILICONFLOW_DEFAULT_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
|
|
|
|
|
|
def _is_siliconflow_vendor(vendor: str) -> bool:
|
|
return vendor.strip().lower() in {"siliconflow", "硅基流动"}
|
|
|
|
|
|
def _canonical_vendor_key(vendor: str) -> str:
|
|
normalized = vendor.strip().lower()
|
|
alias_map = {
|
|
"硅基流动": "siliconflow",
|
|
"siliconflow": "siliconflow",
|
|
"ali": "ali",
|
|
"volcano": "volcano",
|
|
"minimax": "minimax",
|
|
}
|
|
return alias_map.get(normalized, normalized)
|
|
|
|
|
|
def _default_tts_base_url(vendor_key: str) -> Optional[str]:
|
|
defaults = {
|
|
"siliconflow": "https://api.siliconflow.cn/v1",
|
|
}
|
|
return defaults.get(vendor_key)
|
|
|
|
|
|
def _resolve_vendor_credential(db: Session, vendor: str) -> Optional[VendorCredential]:
    """Look up the stored credential for *vendor*, accepting any known alias.

    Returns the first ``VendorCredential`` row whose ``vendor_key`` equals the
    canonicalized vendor name, or ``None`` if none is configured.
    """
    matches = db.query(VendorCredential).filter(
        VendorCredential.vendor_key == _canonical_vendor_key(vendor)
    )
    return matches.first()
|
|
|
|
|
|
def _build_siliconflow_voice_key(voice: Voice, model: str) -> str:
    """Derive the ``voice`` parameter sent to the TTS API for *voice*.

    Precedence: an explicit ``voice.voice_key``. Otherwise ``voice.id`` when it
    already looks fully qualified (contains ``":"``). Otherwise ``"<model>:<id>"``.
    """
    explicit = voice.voice_key
    if explicit:
        return explicit
    voice_id = voice.id
    return voice_id if ":" in voice_id else f"{model}:{voice_id}"
|
|
|
|
|
|
@router.get("")
def list_voices(
    vendor: Optional[str] = None,
    language: Optional[str] = None,
    gender: Optional[str] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db),
):
    """List voices in the voice library, newest first, with optional filters.

    Args:
        vendor: Exact-match filter on ``Voice.vendor``.
        language: Exact-match filter on ``Voice.language``.
        gender: Exact-match filter on ``Voice.gender``.
        page: 1-based page number.
        limit: Page size (0 returns only the total).

    Returns:
        ``{"total", "page", "limit", "list"}``. ``total`` counts all matching
        voices; ``list`` holds the requested page.

    Raises:
        HTTPException(400): if ``page`` < 1 or ``limit`` < 0. Such values would
            produce a negative OFFSET/LIMIT, which some databases reject and
            others interpret as "no limit".
    """
    if page < 1:
        raise HTTPException(status_code=400, detail="page must be >= 1")
    if limit < 0:
        raise HTTPException(status_code=400, detail="limit must be >= 0")

    query = db.query(Voice)
    if vendor:
        query = query.filter(Voice.vendor == vendor)
    if language:
        query = query.filter(Voice.language == language)
    if gender:
        query = query.filter(Voice.gender == gender)

    # Count before pagination so the client can compute the number of pages.
    total = query.count()
    voices = (
        query.order_by(Voice.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": voices}
|
|
|
|
|
|
@router.post("", response_model=VoiceOut)
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
    """Create a voice in the voice library.

    ``vendor`` and the client-supplied ``id`` are trimmed. When no id is given,
    a random 8-character id is generated.

    For SiliconFlow vendors:
    - A missing ``model`` defaults to ``SILICONFLOW_DEFAULT_MODEL``.
    - A missing ``voice_key`` is derived from the id (or, without one, the
      name). The value is used as-is if it already contains ``":"``;
      otherwise it is prefixed as ``"<model>:<id>"``.

    Raises:
        HTTPException(409): if a voice with the requested id already exists.
    """
    vendor = data.vendor.strip()
    model = data.model
    voice_key = data.voice_key
    # Trim once so the stored id and any voice_key derived from it agree;
    # a whitespace-only id is treated as "not supplied".
    requested_id = (data.id or "").strip()

    if _is_siliconflow_vendor(vendor):
        model = model or SILICONFLOW_DEFAULT_MODEL
        if not voice_key:
            raw_id = requested_id or data.name.strip()
            voice_key = raw_id if ":" in raw_id else f"{model}:{raw_id}"

    # The other endpoints address voices by id, so ids must be unique.
    # Report a clash as a client error instead of relying on a DB constraint (if any).
    if requested_id and db.query(Voice).filter(Voice.id == requested_id).first():
        raise HTTPException(status_code=409, detail=f"Voice already exists: {requested_id}")

    voice = Voice(
        id=requested_id or str(uuid.uuid4())[:8],
        user_id=1,  # NOTE(review): hard-coded owner, presumably a placeholder until auth exists
        name=data.name,
        vendor=vendor,
        gender=data.gender,
        language=data.language,
        description=data.description,
        model=model,
        voice_key=voice_key,
        speed=data.speed,
        gain=data.gain,
        pitch=data.pitch,
        enabled=data.enabled,
    )
    db.add(voice)
    db.commit()
    db.refresh(voice)
    return voice
|
|
|
|
|
|
@router.get("/{id}", response_model=VoiceOut)
def get_voice(id: str, db: Session = Depends(get_db)):
    """Fetch a single voice by its id.

    Raises:
        HTTPException(404): if no voice has this id.
    """
    found = db.query(Voice).filter(Voice.id == id).first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Voice not found")
|
|
|
|
|
|
@router.put("/{id}", response_model=VoiceOut)
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
    """Partially update a voice; only fields present in the request body change.

    ``vendor`` is trimmed. When the resulting vendor is SiliconFlow:
    - ``model`` falls back to the stored model, then ``SILICONFLOW_DEFAULT_MODEL``.
    - ``voice_key`` falls back to the stored key, then one derived from the voice id.

    So a SiliconFlow voice leaves this endpoint with both fields set.

    Raises:
        HTTPException(404): if no voice has this id.
        HTTPException(400): if the body explicitly sets ``vendor`` to null.
    """
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")

    # exclude_unset keeps only the fields the client actually sent, including explicit nulls.
    update_data = data.model_dump(exclude_unset=True)
    if "vendor" in update_data:
        if update_data["vendor"] is None:
            # create_voice always requires a vendor; do not let an update strip it.
            raise HTTPException(status_code=400, detail="vendor cannot be null")
        update_data["vendor"] = update_data["vendor"].strip()

    vendor_for_defaults = update_data.get("vendor", voice.vendor)
    if vendor_for_defaults and _is_siliconflow_vendor(vendor_for_defaults):
        model = update_data.get("model") or voice.model or SILICONFLOW_DEFAULT_MODEL
        voice_key = update_data.get("voice_key") or voice.voice_key
        update_data["model"] = model
        update_data["voice_key"] = voice_key or _build_siliconflow_voice_key(voice, model)

    for field, value in update_data.items():
        setattr(voice, field, value)

    db.commit()
    db.refresh(voice)
    return voice
|
|
|
|
|
|
@router.delete("/{id}")
def delete_voice(id: str, db: Session = Depends(get_db)):
    """Delete a voice by its id.

    Raises:
        HTTPException(404): if no voice has this id.
    """
    target = db.query(Voice).filter(Voice.id == id).first()
    if target:
        db.delete(target)
        db.commit()
        return {"message": "Deleted successfully"}
    raise HTTPException(status_code=404, detail="Voice not found")
|
|
|
|
|
|
@router.get("/vendors/credentials")
def list_vendor_credentials(db: Session = Depends(get_db)):
    """List all vendor credentials, most recently updated first.

    Each row is serialized through ``VendorCredentialOut``, the same schema the
    single-item endpoint uses. Returning raw ORM rows would serialize every
    column, including ``api_key``.
    NOTE(review): confirm VendorCredentialOut masks or omits ``api_key``.
    """
    rows = db.query(VendorCredential).order_by(VendorCredential.updated_at.desc()).all()
    # from_attributes lets the schema read ORM attributes even if it is not
    # configured for ORM mode itself (FastAPI does the same for response_model).
    items = [VendorCredentialOut.model_validate(row, from_attributes=True) for row in rows]
    return {"list": items, "total": len(items)}
|
|
|
|
|
|
@router.get("/vendors/credentials/{vendor_key}", response_model=VendorCredentialOut)
def get_vendor_credential(vendor_key: str, db: Session = Depends(get_db)):
    """Return the credential stored for *vendor_key*; vendor aliases are accepted.

    Raises:
        HTTPException(404): if no credential exists for the vendor.
    """
    credential = _resolve_vendor_credential(db, vendor_key)
    if credential:
        return credential
    raise HTTPException(status_code=404, detail="Vendor credential not found")
|
|
|
|
|
|
@router.put("/vendors/credentials/{vendor_key}", response_model=VendorCredentialOut)
def upsert_vendor_credential(vendor_key: str, data: VendorCredentialUpsert, db: Session = Depends(get_db)):
    """Create or replace the credential for a vendor.

    The path value is canonicalized (aliases folded) to form the stored key.
    The raw path value serves as the display name when no ``vendor_name`` is
    given. On update, ``api_key`` and ``base_url`` are replaced wholesale.

    Both ``api_key`` and ``base_url`` are trimmed first. Pasted secrets and URLs
    often carry stray whitespace or newlines, which would otherwise end up in
    the Authorization header or request URL.
    """
    key = _canonical_vendor_key(vendor_key)
    # Trim only real strings; None/empty values are stored exactly as received.
    api_key = data.api_key.strip() if data.api_key else data.api_key
    base_url = data.base_url.strip() if data.base_url else data.base_url

    item = db.query(VendorCredential).filter(VendorCredential.vendor_key == key).first()
    if item:
        item.vendor_name = data.vendor_name or item.vendor_name
        item.api_key = api_key
        item.base_url = base_url
        # NOTE(review): naive UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
        item.updated_at = datetime.utcnow()
    else:
        item = VendorCredential(
            vendor_key=key,
            vendor_name=data.vendor_name or vendor_key,
            api_key=api_key,
            base_url=base_url,
        )
        db.add(item)

    db.commit()
    db.refresh(item)
    return item
|
|
|
|
|
|
@router.delete("/vendors/credentials/{vendor_key}")
def delete_vendor_credential(vendor_key: str, db: Session = Depends(get_db)):
    """Remove the credential stored for *vendor_key*; vendor aliases are accepted.

    Raises:
        HTTPException(404): if no credential exists for the vendor.
    """
    credential = _resolve_vendor_credential(db, vendor_key)
    if not credential:
        raise HTTPException(status_code=404, detail="Vendor credential not found")
    db.delete(credential)
    db.commit()
    return {"message": "Deleted successfully"}
|
|
|
|
|
|
def _extract_tts_error_detail(response: httpx.Response) -> str:
    """Pull a human-readable message out of a failed TTS response.

    Recognized JSON shapes, in order:
    - ``{"error": {"message": ...}}`` (OpenAI style)
    - ``{"error": "..."}``
    - a top-level ``"message"`` or ``"detail"`` key

    Falls back to the raw body text when the body is not a JSON object or none
    of those keys hold a value.
    """
    try:
        body = response.json()
    except ValueError:  # includes JSONDecodeError and UnicodeDecodeError
        return response.text
    if not isinstance(body, dict):
        return response.text

    error = body.get("error")
    if isinstance(error, dict) and error.get("message"):
        return str(error["message"])
    if isinstance(error, str) and error:
        return error
    for key in ("message", "detail"):
        if body.get(key):
            return str(body[key])
    return response.text


@router.post("/{id}/preview", response_model=VoicePreviewResponse)
def preview_voice(id: str, data: VoicePreviewRequest, db: Session = Depends(get_db)):
    """Synthesize a preview clip for a voice via an OpenAI-compatible ``/audio/speech`` API.

    The API key is taken from the first available source:
    1. ``data.api_key`` in the request.
    2. The stored vendor credential.
    3. For SiliconFlow only, the ``SILICONFLOW_API_KEY`` environment variable.

    The base URL comes from the stored credential or the vendor's built-in default.

    Returns:
        ``VoicePreviewResponse`` whose ``audio_url`` is a base64
        ``data:audio/mpeg`` URL of the returned audio.

    Raises:
        HTTPException(404): unknown voice id.
        HTTPException(400): empty text, or no API key / base URL can be resolved.
        HTTPException(502): the vendor call failed, returned a non-200 status,
            or returned an empty body.
    """
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")

    text = data.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Preview text cannot be empty")

    credential = _resolve_vendor_credential(db, voice.vendor)
    api_key = (data.api_key or "").strip()
    if not api_key and credential:
        # Trim stored keys too: a trailing newline from a pasted key would
        # otherwise produce an invalid Authorization header.
        api_key = (credential.api_key or "").strip()
    if not api_key and _is_siliconflow_vendor(voice.vendor):
        api_key = (os.getenv("SILICONFLOW_API_KEY") or "").strip()
    if not api_key:
        raise HTTPException(status_code=400, detail=f"Vendor API key is required for {voice.vendor}")

    model = voice.model or SILICONFLOW_DEFAULT_MODEL
    vendor_key = _canonical_vendor_key(voice.vendor)
    stored_base_url = credential.base_url.strip() if credential and credential.base_url else ""
    base_url = stored_base_url or _default_tts_base_url(vendor_key)
    if not base_url:
        raise HTTPException(status_code=400, detail=f"Vendor base_url is required for {voice.vendor}")
    tts_api_url = f"{base_url.rstrip('/')}/audio/speech"

    payload = {
        "model": model,
        "input": text,
        "voice": voice.voice_key or _build_siliconflow_voice_key(voice, model),
        "response_format": "mp3",
    }
    # Omit speed rather than sending an explicit null when neither the request
    # nor the voice defines one; the vendor then applies its own default.
    speed = data.speed if data.speed is not None else voice.speed
    if speed is not None:
        payload["speed"] = speed

    try:
        with httpx.Client(timeout=45.0) as client:
            response = client.post(
                tts_api_url,
                headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
                json=payload,
            )
    except Exception as exc:
        # API boundary: any failure to complete the upstream call is reported as a bad gateway.
        raise HTTPException(status_code=502, detail=f"TTS request failed: {exc}") from exc

    if response.status_code != 200:
        raise HTTPException(status_code=502, detail=f"TTS vendor error: {_extract_tts_error_detail(response)}")
    if not response.content:
        raise HTTPException(status_code=502, detail="TTS vendor returned empty audio")

    audio_base64 = base64.b64encode(response.content).decode("utf-8")
    return VoicePreviewResponse(success=True, audio_url=f"data:audio/mpeg;base64,{audio_base64}")
|