import base64
import os
import uuid
from datetime import datetime
from typing import Optional

import httpx
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from ..db import get_db
from ..models import VendorCredential, Voice
from ..schemas import (
    VendorCredentialOut,
    VendorCredentialUpsert,
    VoiceCreate,
    VoiceOut,
    VoicePreviewRequest,
    VoicePreviewResponse,
    VoiceUpdate,
)

router = APIRouter(prefix="/voices", tags=["Voices"])

SILICONFLOW_DEFAULT_MODEL = "FunAudioLLM/CosyVoice2-0.5B"

# Maps every accepted vendor spelling (already stripped and lower-cased) to
# the canonical key used for VendorCredential.vendor_key lookups.
_VENDOR_ALIASES = {
    "硅基流动": "siliconflow",
    "siliconflow": "siliconflow",
    "ali": "ali",
    "volcano": "volcano",
    "minimax": "minimax",
}

# Default TTS base URLs per canonical vendor key; vendors missing here must
# supply a base_url through their stored credential.
_DEFAULT_TTS_BASE_URLS = {
    "siliconflow": "https://api.siliconflow.cn/v1",
}


def _canonical_vendor_key(vendor: str) -> str:
    """Return the canonical key for *vendor*.

    The name is stripped and lower-cased, then looked up in the alias table;
    unknown names are returned in their normalized form. A ``None`` vendor is
    treated as an empty string instead of raising.
    """
    normalized = (vendor or "").strip().lower()
    return _VENDOR_ALIASES.get(normalized, normalized)


def _is_siliconflow_vendor(vendor: str) -> bool:
    """Return True when *vendor* is any accepted spelling of SiliconFlow."""
    # Derived from the canonical key so the two helpers cannot disagree.
    return _canonical_vendor_key(vendor) == "siliconflow"


def _default_tts_base_url(vendor_key: str) -> Optional[str]:
    """Return the built-in TTS base URL for *vendor_key*, or None if unknown."""
    return _DEFAULT_TTS_BASE_URLS.get(vendor_key)


def _resolve_vendor_credential(db: Session, vendor: str) -> Optional[VendorCredential]:
    """Return the stored credential matching *vendor*'s canonical key, if any."""
    vendor_key = _canonical_vendor_key(vendor)
    return (
        db.query(VendorCredential)
        .filter(VendorCredential.vendor_key == vendor_key)
        .first()
    )


def _build_siliconflow_voice_key(voice: Voice, model: str) -> str:
    """Return the voice identifier to send to the TTS endpoint.

    Preference order: the explicit ``voice.voice_key``; then ``voice.id`` when
    it already contains a ``:``; otherwise ``"<model>:<voice.id>"``.
    """
    if voice.voice_key:
        return voice.voice_key
    if ":" in voice.id:
        return voice.id
    return f"{model}:{voice.id}"


def _vendor_error_detail(response: httpx.Response):
    """Return the most specific error description found in *response*.

    Looks for ``error.message`` and then ``detail`` in a JSON body; falls back
    to the raw response text when the body is not JSON or not shaped that way.
    """
    detail = response.text
    try:
        body = response.json()
        detail = body.get("error", {}).get("message") or body.get("detail") or detail
    except (ValueError, AttributeError, TypeError):
        # Best effort only: non-JSON bodies (ValueError) or JSON that is not
        # a mapping of the expected shape keep the raw text.
        pass
    return detail


@router.get("")
def list_voices(
    vendor: Optional[str] = None,
    language: Optional[str] = None,
    gender: Optional[str] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db),
):
    """List voices, newest first, with optional equality filters and paging.

    Returns a dict with ``total`` (count after filtering, before paging),
    the effective ``page`` and ``limit``, and ``list`` (the page of voices).
    """
    # Clamp paging inputs so a zero/negative page or a negative limit never
    # reaches the database as a negative OFFSET/LIMIT.
    page = max(page, 1)
    limit = max(limit, 0)

    query = db.query(Voice)
    if vendor:
        query = query.filter(Voice.vendor == vendor)
    if language:
        query = query.filter(Voice.language == language)
    if gender:
        query = query.filter(Voice.gender == gender)

    total = query.count()
    voices = (
        query.order_by(Voice.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": voices}


@router.post("", response_model=VoiceOut)
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
    """Create a voice and return the persisted row.

    For SiliconFlow vendors a missing model falls back to
    ``SILICONFLOW_DEFAULT_MODEL`` and a missing voice key is derived from the
    supplied id (or name), prefixed with the model unless it already has a
    ``:``. When no id is supplied, the first 8 characters of a UUID4 are used.
    """
    vendor = data.vendor.strip()
    model = data.model
    voice_key = data.voice_key

    if _is_siliconflow_vendor(vendor):
        model = model or SILICONFLOW_DEFAULT_MODEL
        if not voice_key:
            raw_id = (data.id or data.name).strip()
            voice_key = raw_id if ":" in raw_id else f"{model}:{raw_id}"

    voice = Voice(
        id=data.id or str(uuid.uuid4())[:8],
        user_id=1,
        name=data.name,
        vendor=vendor,
        gender=data.gender,
        language=data.language,
        description=data.description,
        model=model,
        voice_key=voice_key,
        speed=data.speed,
        gain=data.gain,
        pitch=data.pitch,
        enabled=data.enabled,
    )
    db.add(voice)
    db.commit()
    db.refresh(voice)
    return voice


@router.get("/{id}", response_model=VoiceOut)
def get_voice(id: str, db: Session = Depends(get_db)):
    """Return a single voice by id; 404 when it does not exist."""
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")
    return voice


@router.put("/{id}", response_model=VoiceOut)
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
    """Apply the fields explicitly set in *data* to a voice; 404 if missing.

    When the resulting vendor is SiliconFlow, ``model`` and ``voice_key`` are
    always written, falling back to the stored values and then to the
    SiliconFlow defaults.
    """
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")

    update_data = data.model_dump(exclude_unset=True)
    if "vendor" in update_data:
        if update_data["vendor"] is None:
            # An explicit null vendor used to reach .strip() on None and
            # fail with AttributeError; treat it as "leave unchanged".
            del update_data["vendor"]
        else:
            update_data["vendor"] = update_data["vendor"].strip()

    vendor_for_defaults = update_data.get("vendor", voice.vendor)
    if _is_siliconflow_vendor(vendor_for_defaults):
        model = update_data.get("model") or voice.model or SILICONFLOW_DEFAULT_MODEL
        voice_key = update_data.get("voice_key") or voice.voice_key
        update_data["model"] = model
        update_data["voice_key"] = voice_key or _build_siliconflow_voice_key(voice, model)

    for field, value in update_data.items():
        setattr(voice, field, value)

    db.commit()
    db.refresh(voice)
    return voice


@router.delete("/{id}")
def delete_voice(id: str, db: Session = Depends(get_db)):
    """Delete a voice by id; 404 when it does not exist."""
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")
    db.delete(voice)
    db.commit()
    return {"message": "Deleted successfully"}


@router.get("/vendors/credentials")
def list_vendor_credentials(db: Session = Depends(get_db)):
    """Return all vendor credentials, most recently updated first."""
    items = (
        db.query(VendorCredential)
        .order_by(VendorCredential.updated_at.desc())
        .all()
    )
    return {"list": items, "total": len(items)}


@router.get("/vendors/credentials/{vendor_key}", response_model=VendorCredentialOut)
def get_vendor_credential(vendor_key: str, db: Session = Depends(get_db)):
    """Return the credential for *vendor_key* (aliases accepted); 404 if absent."""
    key = _canonical_vendor_key(vendor_key)
    item = db.query(VendorCredential).filter(VendorCredential.vendor_key == key).first()
    if not item:
        raise HTTPException(status_code=404, detail="Vendor credential not found")
    return item


@router.put("/vendors/credentials/{vendor_key}", response_model=VendorCredentialOut)
def upsert_vendor_credential(
    vendor_key: str,
    data: VendorCredentialUpsert,
    db: Session = Depends(get_db),
):
    """Create or update the credential stored under *vendor_key*'s canonical key.

    On update, ``api_key`` and ``base_url`` are overwritten with the supplied
    values, while ``vendor_name`` is kept when the payload leaves it empty.
    On create, ``vendor_name`` defaults to the key as given in the URL.
    """
    key = _canonical_vendor_key(vendor_key)
    item = db.query(VendorCredential).filter(VendorCredential.vendor_key == key).first()
    if item:
        item.vendor_name = data.vendor_name or item.vendor_name
        item.api_key = data.api_key
        item.base_url = data.base_url
        item.updated_at = datetime.utcnow()
    else:
        item = VendorCredential(
            vendor_key=key,
            vendor_name=data.vendor_name or vendor_key,
            api_key=data.api_key,
            base_url=data.base_url,
        )
        db.add(item)
    db.commit()
    db.refresh(item)
    return item


@router.delete("/vendors/credentials/{vendor_key}")
def delete_vendor_credential(vendor_key: str, db: Session = Depends(get_db)):
    """Delete the credential for *vendor_key* (aliases accepted); 404 if absent."""
    key = _canonical_vendor_key(vendor_key)
    item = db.query(VendorCredential).filter(VendorCredential.vendor_key == key).first()
    if not item:
        raise HTTPException(status_code=404, detail="Vendor credential not found")
    db.delete(item)
    db.commit()
    return {"message": "Deleted successfully"}


@router.post("/{id}/preview", response_model=VoicePreviewResponse)
def preview_voice(id: str, data: VoicePreviewRequest, db: Session = Depends(get_db)):
    """Preview a voice through an OpenAI-compatible /audio/speech endpoint.

    The API key is taken from the request, then the stored vendor credential,
    then (SiliconFlow only) the ``SILICONFLOW_API_KEY`` environment variable.
    The base URL comes from the stored credential or the built-in default.

    Raises:
        HTTPException: 404 if the voice is missing; 400 for empty text or a
            missing API key / base URL; 502 if the request fails or the
            vendor answers with a non-200 status.

    Returns:
        VoicePreviewResponse whose ``audio_url`` is a base64 ``data:`` URL of
        the response body.
    """
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")

    text = data.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Preview text cannot be empty")

    credential = _resolve_vendor_credential(db, voice.vendor)

    api_key = (data.api_key or "").strip()
    if not api_key and credential:
        api_key = credential.api_key
    if not api_key:
        api_key = os.getenv("SILICONFLOW_API_KEY") if _is_siliconflow_vendor(voice.vendor) else ""
    if not api_key:
        raise HTTPException(status_code=400, detail=f"Vendor API key is required for {voice.vendor}")

    model = voice.model or SILICONFLOW_DEFAULT_MODEL
    vendor_key = _canonical_vendor_key(voice.vendor)
    stored_base_url = credential.base_url.strip() if credential and credential.base_url else ""
    base_url = stored_base_url or _default_tts_base_url(vendor_key)
    if not base_url:
        raise HTTPException(status_code=400, detail=f"Vendor base_url is required for {voice.vendor}")

    tts_api_url = f"{base_url.rstrip('/')}/audio/speech"
    payload = {
        "model": model,
        "input": text,
        # The helper already prefers voice.voice_key when it is set.
        "voice": _build_siliconflow_voice_key(voice, model),
        "response_format": "mp3",
        "speed": data.speed if data.speed is not None else voice.speed,
    }

    try:
        with httpx.Client(timeout=45.0) as client:
            response = client.post(
                tts_api_url,
                headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
                json=payload,
            )
    except Exception as exc:
        # Boundary handler: any failure talking to the vendor is reported to
        # the caller as a gateway error, with the cause chained.
        raise HTTPException(status_code=502, detail=f"TTS request failed: {exc}") from exc

    if response.status_code != 200:
        detail = _vendor_error_detail(response)
        raise HTTPException(status_code=502, detail=f"TTS vendor error: {detail}")

    audio_base64 = base64.b64encode(response.content).decode("utf-8")
    return VoicePreviewResponse(success=True, audio_url=f"data:audio/mpeg;base64,{audio_base64}")