Files
AI-VideoAssistant/api/app/routers/voices.py
2026-02-26 03:54:52 +08:00

443 lines
15 KiB
Python

import base64
import io
import json
import os
import threading
import wave
from typing import Any, Dict, Optional

import httpx
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from ..db import get_db
from ..id_generator import unique_short_id
from ..models import Voice
from ..schemas import VoiceCreate, VoiceOut, VoicePreviewRequest, VoicePreviewResponse, VoiceUpdate

router = APIRouter(prefix="/voices", tags=["Voices"])

# Default model for OpenAI-compatible TTS vendors (e.g. SiliconFlow).
OPENAI_COMPATIBLE_DEFAULT_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
# Defaults for the Alibaba DashScope realtime TTS vendor.
DASHSCOPE_DEFAULT_MODEL = "qwen3-tts-flash-realtime"
DASHSCOPE_DEFAULT_VOICE_KEY = "Cherry"
DASHSCOPE_DEFAULT_BASE_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
# DashScope realtime TTS is an optional dependency: when the SDK is missing,
# the names are stubbed out and DASHSCOPE_SDK_AVAILABLE gates usage so the
# failure surfaces at preview time rather than at import time.
try:
    import dashscope
    from dashscope.audio.qwen_tts_realtime import AudioFormat, QwenTtsRealtime, QwenTtsRealtimeCallback

    DASHSCOPE_SDK_AVAILABLE = True
except ImportError:
    dashscope = None  # type: ignore[assignment]
    AudioFormat = None  # type: ignore[assignment]
    QwenTtsRealtime = None  # type: ignore[assignment]
    DASHSCOPE_SDK_AVAILABLE = False

    class QwenTtsRealtimeCallback:  # type: ignore[no-redef]
        """Fallback callback base when DashScope SDK is unavailable."""
        pass
class _DashScopePreviewCallback(QwenTtsRealtimeCallback):
    """Collect DashScope realtime callback events and PCM chunks.

    The SDK invokes the on_* hooks from its own worker thread; the caller
    blocks on wait_for_open()/wait_for_done() and then drains the audio
    with read_audio().
    """

    def __init__(self) -> None:
        super().__init__()
        # Set once the websocket is open.
        self._open_event = threading.Event()
        # Set when synthesis finishes, errors, or the socket closes.
        self._done_event = threading.Event()
        # Guards _audio_chunks against concurrent callback/reader access.
        self._lock = threading.Lock()
        self._audio_chunks: list[bytes] = []
        self._error_message: Optional[str] = None

    def on_open(self) -> None:
        # Websocket connected; unblock wait_for_open().
        self._open_event.set()

    def on_close(self, code: int, reason: str) -> None:
        # A close arriving before "done" means synthesis was cut short.
        if not self._done_event.is_set():
            self._error_message = f"DashScope websocket closed unexpectedly: {code} {reason}"
        self._done_event.set()

    def on_error(self, message: str) -> None:
        # Record the failure and release any waiter.
        self._error_message = str(message)
        self._done_event.set()

    def on_event(self, response: Any) -> None:
        """Dispatch one server event: audio delta, completion, or error."""
        payload = _coerce_dashscope_event(response)
        event_type = str(payload.get("type") or "").strip()
        if event_type == "response.audio.delta":
            delta = payload.get("delta")
            if isinstance(delta, str):
                try:
                    self._append_audio(base64.b64decode(delta))
                except Exception:
                    # Malformed base64 chunk: skip it rather than abort synthesis.
                    return
        elif event_type in {"response.done", "session.finished"}:
            self._done_event.set()
        elif event_type == "error":
            self._error_message = _format_dashscope_error_event(payload)
            self._done_event.set()

    def on_data(self, data: bytes) -> None:
        # Some SDK versions emit raw PCM frames via on_data.
        if isinstance(data, (bytes, bytearray)):
            self._append_audio(bytes(data))

    def wait_for_open(self, timeout: float = 10.0) -> None:
        """Block until the websocket opens; TimeoutError after *timeout* seconds."""
        if not self._open_event.wait(timeout):
            raise TimeoutError("DashScope websocket open timeout")

    def wait_for_done(self, timeout: float = 45.0) -> None:
        """Block until synthesis completes, errors, or the socket closes."""
        if not self._done_event.wait(timeout):
            raise TimeoutError("DashScope synthesis timeout")

    def raise_if_error(self) -> None:
        """Re-raise any error recorded by the callbacks as a RuntimeError."""
        if self._error_message:
            raise RuntimeError(self._error_message)

    def read_audio(self) -> bytes:
        """Return all PCM received so far, concatenated in arrival order."""
        with self._lock:
            return b"".join(self._audio_chunks)

    def _append_audio(self, chunk: bytes) -> None:
        # Thread-safe append; empty chunks are ignored.
        if not chunk:
            return
        with self._lock:
            self._audio_chunks.append(chunk)
def _coerce_dashscope_event(response: Any) -> Dict[str, Any]:
if isinstance(response, dict):
return response
if isinstance(response, str):
try:
parsed = json.loads(response)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
return {"type": "raw", "message": str(response)}
def _format_dashscope_error_event(payload: Dict[str, Any]) -> str:
error = payload.get("error")
if isinstance(error, dict):
code = str(error.get("code") or "").strip()
message = str(error.get("message") or "").strip()
if code and message:
return f"{code}: {message}"
return message or str(error)
return str(error or "DashScope realtime TTS error")
def _create_dashscope_realtime_client(*, model: str, callback: _DashScopePreviewCallback, url: str, api_key: str) -> Any:
    """Instantiate ``QwenTtsRealtime``, tolerating SDKs without an ``api_key`` kwarg.

    Older DashScope releases reject ``api_key`` in the constructor; when that
    specific TypeError surfaces we retry without it (the key is then expected
    to come from ``dashscope.api_key``).
    """
    if QwenTtsRealtime is None:
        raise RuntimeError("DashScope SDK unavailable")
    common_kwargs = dict(model=model, callback=callback, url=url)
    try:
        return QwenTtsRealtime(api_key=api_key, **common_kwargs)  # type: ignore[misc]
    except TypeError as exc:
        # Only swallow the "unexpected keyword 'api_key'" case.
        if "api_key" not in str(exc):
            raise
        return QwenTtsRealtime(**common_kwargs)  # type: ignore[misc]
def _pcm16_to_wav_bytes(pcm_bytes: bytes, sample_rate: int = 24000) -> bytes:
with io.BytesIO() as buffer:
with wave.open(buffer, "wb") as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(sample_rate)
wav_file.writeframes(pcm_bytes)
return buffer.getvalue()
def _synthesize_dashscope_preview(
    *,
    text: str,
    api_key: str,
    base_url: str,
    model: str,
    voice_key: str,
    speed: Optional[float],
) -> bytes:
    """Run one blocking DashScope realtime synthesis and return WAV bytes.

    Connects over websocket, configures the session, commits *text*, waits
    for completion, and wraps the collected 24 kHz mono PCM in a WAV header.

    Raises:
        RuntimeError: SDK missing, vendor-reported error, or no audio returned.
        TimeoutError: websocket open or synthesis exceeded its deadline.
    """
    if not DASHSCOPE_SDK_AVAILABLE:
        raise RuntimeError("dashscope package not installed; install with `pip install dashscope>=1.25.11`")
    if not AudioFormat:
        raise RuntimeError("DashScope SDK AudioFormat unavailable")
    callback = _DashScopePreviewCallback()
    if dashscope is not None:
        # Some SDK code paths read the key from the module-level attribute.
        dashscope.api_key = api_key
    client = _create_dashscope_realtime_client(
        model=model,
        callback=callback,
        url=base_url,
        api_key=api_key,
    )
    try:
        client.connect()
        callback.wait_for_open()
        session_kwargs: Dict[str, Any] = {
            "voice": voice_key,
            "response_format": AudioFormat.PCM_24000HZ_MONO_16BIT,
            "mode": "commit",
        }
        # speech_rate is supported by qwen3-* realtime models.
        normalized_model = str(model or "").strip().lower()
        if speed is not None and normalized_model.startswith("qwen3-"):
            # Clamp to the range the vendor accepts.
            session_kwargs["speech_rate"] = max(0.5, min(2.0, float(speed)))
        client.update_session(**session_kwargs)
        client.append_text(text)
        client.commit()
        callback.wait_for_done()
        callback.raise_if_error()
        pcm_audio = callback.read_audio()
        if not pcm_audio:
            raise RuntimeError("No audio chunk returned from DashScope realtime synthesis")
        # PCM_24000HZ_MONO_16BIT was requested above, hence 24000 here.
        return _pcm16_to_wav_bytes(pcm_audio, sample_rate=24000)
    finally:
        # Best-effort teardown: SDK versions differ on finish()/close() availability.
        finish_fn = getattr(client, "finish", None)
        if callable(finish_fn):
            try:
                finish_fn()
            except Exception:
                pass
        close_fn = getattr(client, "close", None)
        if callable(close_fn):
            try:
                close_fn()
            except Exception:
                pass
def _is_openai_compatible_vendor(vendor: str) -> bool:
normalized = (vendor or "").strip().lower()
return normalized in {
"openai compatible",
"openai-compatible",
"siliconflow", # backward compatibility
"硅基流动", # backward compatibility
}
def _is_dashscope_vendor(vendor: str) -> bool:
normalized = (vendor or "").strip().lower()
return normalized in {
"dashscope",
}
def _default_base_url(vendor: str) -> Optional[str]:
    """Return the conventional API endpoint for a recognized vendor, else None."""
    if _is_dashscope_vendor(vendor):
        return DASHSCOPE_DEFAULT_BASE_URL
    if _is_openai_compatible_vendor(vendor):
        return "https://api.siliconflow.cn/v1"
    return None
def _build_openai_compatible_voice_key(voice: Voice, model: str) -> str:
    """Derive the vendor voice key: explicit key > colon-qualified id > ``model:id``."""
    explicit_key = voice.voice_key
    if explicit_key:
        return explicit_key
    # An id that already contains ":" is assumed to be fully qualified.
    return voice.id if ":" in voice.id else f"{model}:{voice.id}"
@router.get("")
def list_voices(
    vendor: Optional[str] = None,
    language: Optional[str] = None,
    gender: Optional[str] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    """List voices with optional vendor/language/gender filters, paginated.

    Returns ``{"total", "page", "limit", "list"}`` with newest voices first.
    """
    # Clamp paging inputs: page=0 or negative values would produce a negative
    # SQL OFFSET/LIMIT, which most databases reject (surfacing as a 500).
    page = max(page, 1)
    limit = max(limit, 1)
    query = db.query(Voice)
    if vendor:
        query = query.filter(Voice.vendor == vendor)
    if language:
        query = query.filter(Voice.language == language)
    if gender:
        query = query.filter(Voice.gender == gender)
    total = query.count()
    voices = (
        query.order_by(Voice.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": voices}
@router.post("", response_model=VoiceOut)
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
    """Create a voice record, filling vendor-specific model/voice_key defaults."""
    vendor = data.vendor.strip()
    model, voice_key = data.model, data.voice_key
    if _is_openai_compatible_vendor(vendor):
        model = model or OPENAI_COMPATIBLE_DEFAULT_MODEL
        if not voice_key:
            # Build a "<model>:<id>" key unless the id is already colon-qualified.
            base_id = (data.id or data.name).strip()
            voice_key = base_id if ":" in base_id else f"{model}:{base_id}"
    elif _is_dashscope_vendor(vendor):
        model = (model or "").strip() or DASHSCOPE_DEFAULT_MODEL
        voice_key = (voice_key or "").strip() or DASHSCOPE_DEFAULT_VOICE_KEY
    record = Voice(
        id=unique_short_id("tts", db, Voice),
        user_id=1,  # fixed single-user id; no auth context in this app yet
        name=data.name,
        vendor=vendor,
        gender=data.gender,
        language=data.language,
        description=data.description,
        model=model,
        voice_key=voice_key,
        api_key=data.api_key,
        base_url=data.base_url,
        speed=data.speed,
        gain=data.gain,
        pitch=data.pitch,
        enabled=data.enabled,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
@router.get("/{id}", response_model=VoiceOut)
def get_voice(id: str, db: Session = Depends(get_db)):
    """Fetch a single voice by id; 404 when it does not exist."""
    record = db.query(Voice).filter(Voice.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Voice not found")
    return record
@router.put("/{id}", response_model=VoiceOut)
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
    """Update a voice, re-applying vendor-specific model/voice_key defaults.

    Raises 404 when the voice does not exist.
    """
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")
    # Only fields the client explicitly sent are applied (partial update).
    update_data = data.model_dump(exclude_unset=True)
    if "vendor" in update_data and update_data["vendor"] is not None:
        update_data["vendor"] = update_data["vendor"].strip()
    # Defaults are chosen by the vendor the record will have AFTER this update.
    vendor_for_defaults = update_data.get("vendor", voice.vendor)
    if _is_openai_compatible_vendor(vendor_for_defaults):
        # Prefer incoming values, then existing ones, then the vendor default.
        model = update_data.get("model") or voice.model or OPENAI_COMPATIBLE_DEFAULT_MODEL
        voice_key = update_data.get("voice_key") or voice.voice_key
        update_data["model"] = model
        update_data["voice_key"] = voice_key or _build_openai_compatible_voice_key(voice, model)
    elif _is_dashscope_vendor(vendor_for_defaults):
        model = update_data.get("model") or voice.model or DASHSCOPE_DEFAULT_MODEL
        voice_key = update_data.get("voice_key") or voice.voice_key or DASHSCOPE_DEFAULT_VOICE_KEY
        update_data["model"] = model
        update_data["voice_key"] = voice_key
    for field, value in update_data.items():
        setattr(voice, field, value)
    db.commit()
    db.refresh(voice)
    return voice
@router.delete("/{id}")
def delete_voice(id: str, db: Session = Depends(get_db)):
    """Delete a voice by id; 404 when it does not exist."""
    record = db.query(Voice).filter(Voice.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Voice not found")
    db.delete(record)
    db.commit()
    return {"message": "Deleted successfully"}
def _preview_via_dashscope(voice: Voice, text: str, data: VoicePreviewRequest) -> VoicePreviewResponse:
    """Synthesize a DashScope realtime preview and return it as a WAV data URL."""
    # Key resolution order: request override > stored key > environment.
    api_key = (data.api_key or "").strip() or (voice.api_key or "").strip()
    if not api_key:
        api_key = os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("TTS_API_KEY", "").strip()
    if not api_key:
        raise HTTPException(status_code=400, detail=f"API key is required for voice: {voice.name}")
    base_url = (voice.base_url or "").strip() or DASHSCOPE_DEFAULT_BASE_URL
    model = (voice.model or "").strip() or DASHSCOPE_DEFAULT_MODEL
    voice_key = (voice.voice_key or "").strip() or DASHSCOPE_DEFAULT_VOICE_KEY
    effective_speed = data.speed if data.speed is not None else voice.speed
    try:
        wav_bytes = _synthesize_dashscope_preview(
            text=text,
            api_key=api_key,
            base_url=base_url,
            model=model,
            voice_key=voice_key,
            speed=effective_speed,
        )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"DashScope preview failed: {exc}") from exc
    audio_base64 = base64.b64encode(wav_bytes).decode("utf-8")
    return VoicePreviewResponse(success=True, audio_url=f"data:audio/wav;base64,{audio_base64}")


def _preview_via_openai_compatible(voice: Voice, text: str, data: VoicePreviewRequest) -> VoicePreviewResponse:
    """Synthesize a preview via an OpenAI-compatible /audio/speech endpoint."""
    api_key = (data.api_key or "").strip() or (voice.api_key or "").strip()
    if not api_key and _is_openai_compatible_vendor(voice.vendor):
        api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
    if not api_key:
        raise HTTPException(status_code=400, detail=f"API key is required for voice: {voice.name}")
    base_url = (voice.base_url or "").strip() or (_default_base_url(voice.vendor) or "")
    if not base_url:
        raise HTTPException(status_code=400, detail=f"Base URL is required for voice: {voice.name}")
    model = voice.model or OPENAI_COMPATIBLE_DEFAULT_MODEL
    payload: Dict[str, Any] = {
        "model": model,
        "input": text,
        "voice": voice.voice_key or _build_openai_compatible_voice_key(voice, model),
        "response_format": "mp3",
    }
    # Only send "speed" when actually configured: a JSON null speed is
    # rejected by several OpenAI-compatible vendors.
    effective_speed = data.speed if data.speed is not None else voice.speed
    if effective_speed is not None:
        payload["speed"] = effective_speed
    try:
        with httpx.Client(timeout=45.0) as client:
            response = client.post(
                f"{base_url.rstrip('/')}/audio/speech",
                headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
                json=payload,
            )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"TTS request failed: {exc}") from exc
    if response.status_code != 200:
        detail = response.text
        try:
            # Prefer the structured vendor error message when the body is JSON.
            detail_json = response.json()
            detail = detail_json.get("error", {}).get("message") or detail_json.get("detail") or detail
        except Exception:
            pass  # non-JSON error body: keep the raw text
        raise HTTPException(status_code=502, detail=f"TTS vendor error: {detail}")
    audio_base64 = base64.b64encode(response.content).decode("utf-8")
    return VoicePreviewResponse(success=True, audio_url=f"data:audio/mpeg;base64,{audio_base64}")


@router.post("/{id}/preview", response_model=VoicePreviewResponse)
def preview_voice(id: str, data: VoicePreviewRequest, db: Session = Depends(get_db)):
    """Preview a voice; supports OpenAI-compatible and DashScope Realtime vendors.

    Raises 404 for unknown ids, 400 for empty text or missing credentials,
    and 502 when the upstream TTS vendor fails.
    """
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")
    text = data.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Preview text cannot be empty")
    if _is_dashscope_vendor(voice.vendor):
        return _preview_via_dashscope(voice, text, data)
    return _preview_via_openai_compatible(voice, text, data)