"""Voice library API routes: CRUD endpoints and TTS preview for configured voices."""
import base64
|
|
import io
|
|
import json
|
|
import os
|
|
import threading
|
|
import wave
|
|
from typing import Any, Dict, Optional
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from sqlalchemy.orm import Session
|
|
|
|
from ..db import get_db
|
|
from ..id_generator import unique_short_id
|
|
from ..models import Voice
|
|
from ..schemas import VoiceCreate, VoiceOut, VoicePreviewRequest, VoicePreviewResponse, VoiceUpdate
|
|
|
|
# Sub-router mounted by the application; every endpoint lives under /voices.
router = APIRouter(prefix="/voices", tags=["Voices"])

# Default model for OpenAI-compatible TTS vendors (e.g. SiliconFlow).
OPENAI_COMPATIBLE_DEFAULT_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
# Defaults applied when a DashScope voice record omits model / voice / endpoint.
DASHSCOPE_DEFAULT_MODEL = "qwen3-tts-flash-realtime"
DASHSCOPE_DEFAULT_VOICE_KEY = "Cherry"
DASHSCOPE_DEFAULT_BASE_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"
|
|
|
|
# The DashScope realtime TTS SDK is an optional dependency: when it is absent
# this module must still import cleanly so the non-DashScope endpoints keep
# working. The names below are stubbed out and DASHSCOPE_SDK_AVAILABLE lets
# the preview path fail with a clear runtime error instead.
try:
    import dashscope
    from dashscope.audio.qwen_tts_realtime import AudioFormat, QwenTtsRealtime, QwenTtsRealtimeCallback

    DASHSCOPE_SDK_AVAILABLE = True
except ImportError:
    dashscope = None  # type: ignore[assignment]
    AudioFormat = None  # type: ignore[assignment]
    QwenTtsRealtime = None  # type: ignore[assignment]
    DASHSCOPE_SDK_AVAILABLE = False

    # Stand-in base class so _DashScopePreviewCallback can still be defined.
    class QwenTtsRealtimeCallback:  # type: ignore[no-redef]
        """Fallback callback base when DashScope SDK is unavailable."""

        pass
|
|
|
|
|
|
class _DashScopePreviewCallback(QwenTtsRealtimeCallback):
    """Accumulate DashScope realtime events: PCM audio, completion, errors.

    The SDK fires these callbacks from its own worker thread, so audio chunks
    are guarded by a lock while the open/done states are ``threading.Event``s
    the request thread can block on.
    """

    def __init__(self) -> None:
        super().__init__()
        self._open_event = threading.Event()
        self._done_event = threading.Event()
        self._lock = threading.Lock()
        self._audio_chunks: list[bytes] = []
        self._error_message: Optional[str] = None

    def on_open(self) -> None:
        # Handshake finished; synthesis requests may now be sent.
        self._open_event.set()

    def on_close(self, code: int, reason: str) -> None:
        # A close arriving before completion counts as a failure.
        if not self._done_event.is_set():
            self._error_message = f"DashScope websocket closed unexpectedly: {code} {reason}"
        self._done_event.set()

    def on_error(self, message: str) -> None:
        self._error_message = str(message)
        self._done_event.set()

    def on_event(self, response: Any) -> None:
        event = _coerce_dashscope_event(response)
        kind = str(event.get("type") or "").strip()
        if kind == "response.audio.delta":
            encoded = event.get("delta")
            if isinstance(encoded, str):
                try:
                    self._append_audio(base64.b64decode(encoded))
                except Exception:
                    # Malformed base64 chunk: drop it rather than abort.
                    return
        elif kind in {"response.done", "session.finished"}:
            self._done_event.set()
        elif kind == "error":
            self._error_message = _format_dashscope_error_event(event)
            self._done_event.set()

    def on_data(self, data: bytes) -> None:
        # Some SDK releases deliver raw PCM frames via on_data instead.
        if isinstance(data, (bytes, bytearray)):
            self._append_audio(bytes(data))

    def wait_for_open(self, timeout: float = 10.0) -> None:
        if not self._open_event.wait(timeout):
            raise TimeoutError("DashScope websocket open timeout")

    def wait_for_done(self, timeout: float = 45.0) -> None:
        if not self._done_event.wait(timeout):
            raise TimeoutError("DashScope synthesis timeout")

    def raise_if_error(self) -> None:
        if self._error_message:
            raise RuntimeError(self._error_message)

    def read_audio(self) -> bytes:
        # Join under the lock so a concurrent append cannot interleave.
        with self._lock:
            return b"".join(self._audio_chunks)

    def _append_audio(self, chunk: bytes) -> None:
        if not chunk:
            return
        with self._lock:
            self._audio_chunks.append(chunk)
|
|
|
|
|
|
def _coerce_dashscope_event(response: Any) -> Dict[str, Any]:
|
|
if isinstance(response, dict):
|
|
return response
|
|
if isinstance(response, str):
|
|
try:
|
|
parsed = json.loads(response)
|
|
if isinstance(parsed, dict):
|
|
return parsed
|
|
except json.JSONDecodeError:
|
|
pass
|
|
return {"type": "raw", "message": str(response)}
|
|
|
|
|
|
def _format_dashscope_error_event(payload: Dict[str, Any]) -> str:
|
|
error = payload.get("error")
|
|
if isinstance(error, dict):
|
|
code = str(error.get("code") or "").strip()
|
|
message = str(error.get("message") or "").strip()
|
|
if code and message:
|
|
return f"{code}: {message}"
|
|
return message or str(error)
|
|
return str(error or "DashScope realtime TTS error")
|
|
|
|
|
|
def _create_dashscope_realtime_client(*, model: str, callback: _DashScopePreviewCallback, url: str, api_key: str) -> Any:
    """Instantiate QwenTtsRealtime, tolerating SDKs without an ``api_key`` kwarg.

    Older DashScope releases reject ``api_key`` in the constructor; in that
    case retry without it (the key is then read from ``dashscope.api_key``).
    """
    if QwenTtsRealtime is None:
        raise RuntimeError("DashScope SDK unavailable")

    common_kwargs = {
        "model": model,
        "callback": callback,
        "url": url,
    }
    try:
        return QwenTtsRealtime(api_key=api_key, **common_kwargs)  # type: ignore[misc]
    except TypeError as exc:
        # Only swallow the TypeError caused by the unsupported keyword.
        if "api_key" not in str(exc):
            raise
        return QwenTtsRealtime(**common_kwargs)  # type: ignore[misc]
|
|
|
|
|
|
def _pcm16_to_wav_bytes(pcm_bytes: bytes, sample_rate: int = 24000) -> bytes:
|
|
with io.BytesIO() as buffer:
|
|
with wave.open(buffer, "wb") as wav_file:
|
|
wav_file.setnchannels(1)
|
|
wav_file.setsampwidth(2)
|
|
wav_file.setframerate(sample_rate)
|
|
wav_file.writeframes(pcm_bytes)
|
|
return buffer.getvalue()
|
|
|
|
|
|
def _synthesize_dashscope_preview(
    *,
    text: str,
    api_key: str,
    base_url: str,
    model: str,
    voice_key: str,
    speed: Optional[float],
) -> bytes:
    """Synthesize *text* via DashScope realtime TTS and return WAV bytes.

    Opens a websocket session, streams the text in ``commit`` mode, collects
    the returned 24 kHz mono 16-bit PCM chunks and wraps them in a WAV
    container.

    Raises:
        RuntimeError: SDK missing, vendor-side error, or no audio returned.
        TimeoutError: websocket open or synthesis did not finish in time
            (propagated from the callback's wait helpers).
    """
    if not DASHSCOPE_SDK_AVAILABLE:
        raise RuntimeError("dashscope package not installed; install with `pip install dashscope>=1.25.11`")
    if not AudioFormat:
        raise RuntimeError("DashScope SDK AudioFormat unavailable")

    callback = _DashScopePreviewCallback()
    if dashscope is not None:
        # Some SDK versions read the key from this module-level attribute
        # rather than the constructor argument, so set both.
        dashscope.api_key = api_key
    client = _create_dashscope_realtime_client(
        model=model,
        callback=callback,
        url=base_url,
        api_key=api_key,
    )

    try:
        # Protocol order matters: connect -> wait open -> configure session
        # -> send text -> commit -> wait for completion.
        client.connect()
        callback.wait_for_open()
        session_kwargs: Dict[str, Any] = {
            "voice": voice_key,
            "response_format": AudioFormat.PCM_24000HZ_MONO_16BIT,
            "mode": "commit",
        }
        # speech_rate is supported by qwen3-* realtime models.
        normalized_model = str(model or "").strip().lower()
        if speed is not None and normalized_model.startswith("qwen3-"):
            # Clamp to the accepted speech_rate range [0.5, 2.0].
            session_kwargs["speech_rate"] = max(0.5, min(2.0, float(speed)))
        client.update_session(**session_kwargs)
        client.append_text(text)
        client.commit()
        callback.wait_for_done()
        callback.raise_if_error()
        pcm_audio = callback.read_audio()
        if not pcm_audio:
            raise RuntimeError("No audio chunk returned from DashScope realtime synthesis")
        return _pcm16_to_wav_bytes(pcm_audio, sample_rate=24000)
    finally:
        # Best-effort teardown: SDK versions differ in which of finish()/close()
        # they expose, and either may raise once the socket is already gone.
        finish_fn = getattr(client, "finish", None)
        if callable(finish_fn):
            try:
                finish_fn()
            except Exception:
                pass
        close_fn = getattr(client, "close", None)
        if callable(close_fn):
            try:
                close_fn()
            except Exception:
                pass
|
|
|
|
|
|
def _is_openai_compatible_vendor(vendor: str) -> bool:
|
|
normalized = (vendor or "").strip().lower()
|
|
return normalized in {
|
|
"openai compatible",
|
|
"openai-compatible",
|
|
"siliconflow", # backward compatibility
|
|
"硅基流动", # backward compatibility
|
|
}
|
|
|
|
|
|
def _is_dashscope_vendor(vendor: str) -> bool:
|
|
normalized = (vendor or "").strip().lower()
|
|
return normalized in {
|
|
"dashscope",
|
|
}
|
|
|
|
|
|
def _default_base_url(vendor: str) -> Optional[str]:
    """Return the vendor's default API endpoint, or None for unknown vendors."""
    if _is_dashscope_vendor(vendor):
        return DASHSCOPE_DEFAULT_BASE_URL
    if _is_openai_compatible_vendor(vendor):
        return "https://api.siliconflow.cn/v1"
    return None
|
|
|
|
|
|
def _build_openai_compatible_voice_key(voice: Voice, model: str) -> str:
    """Derive the vendor voice identifier for an OpenAI-compatible voice.

    Prefers the explicitly stored voice_key; otherwise falls back to the
    record id, prefixing it with the model name unless it is already
    namespaced as ``model:id``.
    """
    explicit_key = voice.voice_key
    if explicit_key:
        return explicit_key
    record_id = voice.id
    return record_id if ":" in record_id else f"{model}:{record_id}"
|
|
|
|
|
|
@router.get("")
def list_voices(
    vendor: Optional[str] = None,
    language: Optional[str] = None,
    gender: Optional[str] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    """List voices with optional vendor/language/gender filters, paginated.

    Returns ``{"total", "page", "limit", "list"}`` where ``list`` holds the
    requested page of voices, newest first.
    """
    # Clamp pagination inputs: page<1 or limit<1 would produce a negative
    # OFFSET/LIMIT, which database backends reject with an error.
    page = max(page, 1)
    limit = max(limit, 1)

    query = db.query(Voice)
    if vendor:
        query = query.filter(Voice.vendor == vendor)
    if language:
        query = query.filter(Voice.language == language)
    if gender:
        query = query.filter(Voice.gender == gender)

    total = query.count()
    voices = (
        query.order_by(Voice.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": voices}
|
|
|
|
|
|
@router.post("", response_model=VoiceOut)
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
    """Create a voice record, filling vendor-specific model/voice_key defaults."""
    vendor = data.vendor.strip()
    model = data.model
    voice_key = data.voice_key

    if _is_openai_compatible_vendor(vendor):
        if not model:
            model = OPENAI_COMPATIBLE_DEFAULT_MODEL
        if not voice_key:
            raw_id = (data.id or data.name).strip()
            # Namespace the id with the model unless it is already "model:id".
            voice_key = raw_id if ":" in raw_id else f"{model}:{raw_id}"
    elif _is_dashscope_vendor(vendor):
        model = (model or "").strip() or DASHSCOPE_DEFAULT_MODEL
        voice_key = (voice_key or "").strip() or DASHSCOPE_DEFAULT_VOICE_KEY

    record = Voice(
        id=unique_short_id("tts", db, Voice),
        user_id=1,  # single-user deployment: fixed owner id
        name=data.name,
        vendor=vendor,
        gender=data.gender,
        language=data.language,
        description=data.description,
        model=model,
        voice_key=voice_key,
        api_key=data.api_key,
        base_url=data.base_url,
        speed=data.speed,
        gain=data.gain,
        pitch=data.pitch,
        enabled=data.enabled,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
|
|
|
|
|
@router.get("/{id}", response_model=VoiceOut)
def get_voice(id: str, db: Session = Depends(get_db)):
    """Fetch a single voice by id; responds 404 when it does not exist."""
    voice = db.query(Voice).filter(Voice.id == id).first()
    if voice is None:
        raise HTTPException(status_code=404, detail="Voice not found")
    return voice
|
|
|
|
|
|
@router.put("/{id}", response_model=VoiceOut)
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
    """Partially update a voice, re-deriving vendor defaults when needed."""
    voice = db.query(Voice).filter(Voice.id == id).first()
    if voice is None:
        raise HTTPException(status_code=404, detail="Voice not found")

    changes = data.model_dump(exclude_unset=True)
    if changes.get("vendor") is not None:
        changes["vendor"] = changes["vendor"].strip()

    # Defaults depend on the vendor the record will have after the update.
    effective_vendor = changes.get("vendor", voice.vendor)
    if _is_openai_compatible_vendor(effective_vendor):
        model = changes.get("model") or voice.model or OPENAI_COMPATIBLE_DEFAULT_MODEL
        changes["model"] = model
        voice_key = changes.get("voice_key") or voice.voice_key
        changes["voice_key"] = voice_key or _build_openai_compatible_voice_key(voice, model)
    elif _is_dashscope_vendor(effective_vendor):
        changes["model"] = changes.get("model") or voice.model or DASHSCOPE_DEFAULT_MODEL
        changes["voice_key"] = (
            changes.get("voice_key") or voice.voice_key or DASHSCOPE_DEFAULT_VOICE_KEY
        )

    for field_name, new_value in changes.items():
        setattr(voice, field_name, new_value)

    db.commit()
    db.refresh(voice)
    return voice
|
|
|
|
|
|
@router.delete("/{id}")
def delete_voice(id: str, db: Session = Depends(get_db)):
    """Delete a voice by id; responds 404 when it does not exist."""
    voice = db.query(Voice).filter(Voice.id == id).first()
    if voice is None:
        raise HTTPException(status_code=404, detail="Voice not found")
    db.delete(voice)
    db.commit()
    return {"message": "Deleted successfully"}
|
|
|
|
|
|
@router.post("/{id}/preview", response_model=VoicePreviewResponse)
def preview_voice(id: str, data: VoicePreviewRequest, db: Session = Depends(get_db)):
    """Preview (synthesize) sample audio for a voice.

    Supports DashScope realtime voices and OpenAI-compatible HTTP TTS vendors.
    Returns the audio as a base64 ``data:`` URL.
    """
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")

    text = data.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Preview text cannot be empty")

    if _is_dashscope_vendor(voice.vendor):
        return _preview_via_dashscope(voice, data, text)
    return _preview_via_openai_compatible(voice, data, text)


def _preview_via_dashscope(voice: Voice, data: VoicePreviewRequest, text: str) -> VoicePreviewResponse:
    """Synthesize preview audio through the DashScope realtime websocket API."""
    # Key precedence: request override > stored on voice > environment.
    api_key = (data.api_key or "").strip() or (voice.api_key or "").strip()
    if not api_key:
        api_key = os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("TTS_API_KEY", "").strip()
    if not api_key:
        raise HTTPException(status_code=400, detail=f"API key is required for voice: {voice.name}")

    base_url = (voice.base_url or "").strip() or DASHSCOPE_DEFAULT_BASE_URL
    model = (voice.model or "").strip() or DASHSCOPE_DEFAULT_MODEL
    voice_key = (voice.voice_key or "").strip() or DASHSCOPE_DEFAULT_VOICE_KEY
    effective_speed = data.speed if data.speed is not None else voice.speed
    try:
        wav_bytes = _synthesize_dashscope_preview(
            text=text,
            api_key=api_key,
            base_url=base_url,
            model=model,
            voice_key=voice_key,
            speed=effective_speed,
        )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"DashScope preview failed: {exc}") from exc
    audio_base64 = base64.b64encode(wav_bytes).decode("utf-8")
    return VoicePreviewResponse(success=True, audio_url=f"data:audio/wav;base64,{audio_base64}")


def _preview_via_openai_compatible(voice: Voice, data: VoicePreviewRequest, text: str) -> VoicePreviewResponse:
    """Synthesize preview audio via an OpenAI-compatible /audio/speech endpoint."""
    api_key = (data.api_key or "").strip() or (voice.api_key or "").strip()
    if not api_key and _is_openai_compatible_vendor(voice.vendor):
        api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
    if not api_key:
        raise HTTPException(status_code=400, detail=f"API key is required for voice: {voice.name}")

    base_url = (voice.base_url or "").strip() or (_default_base_url(voice.vendor) or "")
    if not base_url:
        raise HTTPException(status_code=400, detail=f"Base URL is required for voice: {voice.name}")

    model = voice.model or OPENAI_COMPATIBLE_DEFAULT_MODEL
    payload = {
        "model": model,
        "input": text,
        "voice": voice.voice_key or _build_openai_compatible_voice_key(voice, model),
        "response_format": "mp3",
    }
    # Fix: only send "speed" when a value exists — a JSON null speed is
    # rejected by OpenAI-compatible endpoints (speed must be a number).
    speed = data.speed if data.speed is not None else voice.speed
    if speed is not None:
        payload["speed"] = speed

    try:
        with httpx.Client(timeout=45.0) as client:
            response = client.post(
                f"{base_url.rstrip('/')}/audio/speech",
                headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
                json=payload,
            )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"TTS request failed: {exc}") from exc

    if response.status_code != 200:
        # Prefer the vendor's structured error message when one is present.
        detail = response.text
        try:
            detail_json = response.json()
            detail = detail_json.get("error", {}).get("message") or detail_json.get("detail") or detail
        except Exception:
            pass
        raise HTTPException(status_code=502, detail=f"TTS vendor error: {detail}")

    audio_base64 = base64.b64encode(response.content).decode("utf-8")
    return VoicePreviewResponse(success=True, audio_url=f"data:audio/mpeg;base64,{audio_base64}")
|