import base64
import io
import json
import os
import threading
import wave
from typing import Any, Dict, Optional

import httpx
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from ..db import get_db
from ..id_generator import unique_short_id
from ..models import Voice
from ..schemas import VoiceCreate, VoiceOut, VoicePreviewRequest, VoicePreviewResponse, VoiceUpdate

router = APIRouter(prefix="/voices", tags=["Voices"])

OPENAI_COMPATIBLE_DEFAULT_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
DASHSCOPE_DEFAULT_MODEL = "qwen3-tts-flash-realtime"
DASHSCOPE_DEFAULT_VOICE_KEY = "Cherry"
DASHSCOPE_DEFAULT_BASE_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"

# The DashScope SDK is optional: the module (and the OpenAI-compatible code
# path) must keep working when it is not installed.
try:
    import dashscope
    from dashscope.audio.qwen_tts_realtime import AudioFormat, QwenTtsRealtime, QwenTtsRealtimeCallback

    DASHSCOPE_SDK_AVAILABLE = True
except ImportError:
    dashscope = None  # type: ignore[assignment]
    AudioFormat = None  # type: ignore[assignment]
    QwenTtsRealtime = None  # type: ignore[assignment]
    DASHSCOPE_SDK_AVAILABLE = False

    class QwenTtsRealtimeCallback:  # type: ignore[no-redef]
        """Fallback callback base when DashScope SDK is unavailable."""

        pass


class _DashScopePreviewCallback(QwenTtsRealtimeCallback):
    """Collect DashScope realtime callback events and PCM chunks.

    SDK callbacks arrive on a websocket thread while the request handler
    blocks on ``wait_for_open``/``wait_for_done``; the chunk list is guarded
    by a lock and completion is signalled via events.
    """

    def __init__(self) -> None:
        super().__init__()
        self._open_event = threading.Event()
        self._done_event = threading.Event()
        self._lock = threading.Lock()
        self._audio_chunks: list[bytes] = []
        self._error_message: Optional[str] = None

    def on_open(self) -> None:
        self._open_event.set()

    def on_close(self, code: int, reason: str) -> None:
        # A close before response.done/session.finished means synthesis failed.
        if not self._done_event.is_set():
            self._error_message = f"DashScope websocket closed unexpectedly: {code} {reason}"
            self._done_event.set()

    def on_error(self, message: str) -> None:
        self._error_message = str(message)
        self._done_event.set()

    def on_event(self, response: Any) -> None:
        """Dispatch a realtime event: audio delta, completion, or error."""
        payload = _coerce_dashscope_event(response)
        event_type = str(payload.get("type") or "").strip()
        if event_type == "response.audio.delta":
            delta = payload.get("delta")
            if isinstance(delta, str):
                try:
                    self._append_audio(base64.b64decode(delta))
                except Exception:
                    # Best-effort: drop a malformed delta instead of aborting
                    # the whole synthesis.
                    return
        elif event_type in {"response.done", "session.finished"}:
            self._done_event.set()
        elif event_type == "error":
            self._error_message = _format_dashscope_error_event(payload)
            self._done_event.set()

    def on_data(self, data: bytes) -> None:
        # Some SDK versions emit raw PCM frames via on_data.
        if isinstance(data, (bytes, bytearray)):
            self._append_audio(bytes(data))

    def wait_for_open(self, timeout: float = 10.0) -> None:
        """Block until the websocket opens; raise TimeoutError otherwise."""
        if not self._open_event.wait(timeout):
            raise TimeoutError("DashScope websocket open timeout")

    def wait_for_done(self, timeout: float = 45.0) -> None:
        """Block until synthesis finishes (or errors); raise on timeout."""
        if not self._done_event.wait(timeout):
            raise TimeoutError("DashScope synthesis timeout")

    def raise_if_error(self) -> None:
        if self._error_message:
            raise RuntimeError(self._error_message)

    def read_audio(self) -> bytes:
        """Return all collected PCM audio as a single bytes object."""
        with self._lock:
            return b"".join(self._audio_chunks)

    def _append_audio(self, chunk: bytes) -> None:
        if not chunk:
            return
        with self._lock:
            self._audio_chunks.append(chunk)


def _coerce_dashscope_event(response: Any) -> Dict[str, Any]:
    """Normalize an SDK callback payload (dict or JSON string) into a dict."""
    if isinstance(response, dict):
        return response
    if isinstance(response, str):
        try:
            parsed = json.loads(response)
            if isinstance(parsed, dict):
                return parsed
        except json.JSONDecodeError:
            pass
    # Unknown payload shape: wrap it so callers always get a dict.
    return {"type": "raw", "message": str(response)}


def _format_dashscope_error_event(payload: Dict[str, Any]) -> str:
    """Render a DashScope ``error`` event into one human-readable line."""
    error = payload.get("error")
    if isinstance(error, dict):
        code = str(error.get("code") or "").strip()
        message = str(error.get("message") or "").strip()
        if code and message:
            return f"{code}: {message}"
        return message or str(error)
    return str(error or "DashScope realtime TTS error")


def _create_dashscope_realtime_client(*, model: str, callback: _DashScopePreviewCallback, url: str, api_key: str) -> Any:
    """Instantiate ``QwenTtsRealtime``, tolerating SDKs without an ``api_key`` kwarg.

    Raises RuntimeError when the SDK is not installed.
    """
    if QwenTtsRealtime is None:
        raise RuntimeError("DashScope SDK unavailable")
    init_kwargs = {
        "model": model,
        "callback": callback,
        "url": url,
    }
    try:
        return QwenTtsRealtime(api_key=api_key, **init_kwargs)  # type: ignore[misc]
    except TypeError as exc:
        # Older SDKs reject the api_key kwarg and read dashscope.api_key
        # (set by the caller) instead; re-raise unrelated TypeErrors.
        if "api_key" not in str(exc):
            raise
        return QwenTtsRealtime(**init_kwargs)  # type: ignore[misc]


def _pcm16_to_wav_bytes(pcm_bytes: bytes, sample_rate: int = 24000) -> bytes:
    """Wrap raw mono 16-bit PCM in a WAV container and return its bytes."""
    with io.BytesIO() as buffer:
        with wave.open(buffer, "wb") as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)  # 16-bit samples
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm_bytes)
        # Read after the wave writer closes so the header sizes are final.
        return buffer.getvalue()


def _synthesize_dashscope_preview(
    *,
    text: str,
    api_key: str,
    base_url: str,
    model: str,
    voice_key: str,
    speed: Optional[float],
) -> bytes:
    """Synthesize ``text`` via DashScope realtime TTS and return WAV bytes.

    Raises:
        RuntimeError: SDK missing, vendor error, or no audio produced.
        TimeoutError: websocket open or synthesis timed out.
    """
    if not DASHSCOPE_SDK_AVAILABLE:
        raise RuntimeError("dashscope package not installed; install with `pip install dashscope>=1.25.11`")
    if not AudioFormat:
        raise RuntimeError("DashScope SDK AudioFormat unavailable")
    callback = _DashScopePreviewCallback()
    if dashscope is not None:
        # Older SDKs read the key from this module-level attribute.
        dashscope.api_key = api_key
    client = _create_dashscope_realtime_client(
        model=model,
        callback=callback,
        url=base_url,
        api_key=api_key,
    )
    try:
        client.connect()
        callback.wait_for_open()
        session_kwargs: Dict[str, Any] = {
            "voice": voice_key,
            "response_format": AudioFormat.PCM_24000HZ_MONO_16BIT,
            "mode": "commit",
        }
        # speech_rate is supported by qwen3-* realtime models.
        normalized_model = str(model or "").strip().lower()
        if speed is not None and normalized_model.startswith("qwen3-"):
            session_kwargs["speech_rate"] = max(0.5, min(2.0, float(speed)))
        client.update_session(**session_kwargs)
        client.append_text(text)
        client.commit()
        callback.wait_for_done()
        callback.raise_if_error()
        pcm_audio = callback.read_audio()
        if not pcm_audio:
            raise RuntimeError("No audio chunk returned from DashScope realtime synthesis")
        return _pcm16_to_wav_bytes(pcm_audio, sample_rate=24000)
    finally:
        # Best-effort teardown: different SDK versions expose finish()/close().
        finish_fn = getattr(client, "finish", None)
        if callable(finish_fn):
            try:
                finish_fn()
            except Exception:
                pass
        close_fn = getattr(client, "close", None)
        if callable(close_fn):
            try:
                close_fn()
            except Exception:
                pass


def _is_openai_compatible_vendor(vendor: str) -> bool:
    """True when ``vendor`` names the OpenAI-compatible TTS code path."""
    normalized = (vendor or "").strip().lower()
    return normalized in {
        "openai compatible",
        "openai-compatible",
        "siliconflow",  # backward compatibility
        "硅基流动",  # backward compatibility
    }


def _is_dashscope_vendor(vendor: str) -> bool:
    """True when ``vendor`` names the DashScope realtime code path."""
    normalized = (vendor or "").strip().lower()
    return normalized in {
        "dashscope",
    }


def _default_base_url(vendor: str) -> Optional[str]:
    """Vendor-specific default endpoint, or None for unknown vendors."""
    if _is_openai_compatible_vendor(vendor):
        return "https://api.siliconflow.cn/v1"
    if _is_dashscope_vendor(vendor):
        return DASHSCOPE_DEFAULT_BASE_URL
    return None


def _build_openai_compatible_voice_key(voice: Voice, model: str) -> str:
    """Derive the vendor voice identifier when the record has no voice_key.

    SiliconFlow-style voice keys are ``model:voice``; a record id that
    already contains ``:`` is assumed to be a full key.
    """
    if voice.voice_key:
        return voice.voice_key
    if ":" in voice.id:
        return voice.id
    return f"{model}:{voice.id}"


@router.get("")
def list_voices(
    vendor: Optional[str] = None,
    language: Optional[str] = None,
    gender: Optional[str] = None,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db),
):
    """List voices, optionally filtered by vendor/language/gender, paginated."""
    query = db.query(Voice)
    if vendor:
        query = query.filter(Voice.vendor == vendor)
    if language:
        query = query.filter(Voice.language == language)
    if gender:
        query = query.filter(Voice.gender == gender)
    total = query.count()
    voices = query.order_by(Voice.created_at.desc()) \
        .offset((page - 1) * limit).limit(limit).all()
    return {"total": total, "page": page, "limit": limit, "list": voices}


@router.post("", response_model=VoiceOut)
def create_voice(data: VoiceCreate, db: Session = Depends(get_db)):
    """Create a voice, filling vendor-specific model/voice_key defaults."""
    vendor = data.vendor.strip()
    model = data.model
    voice_key = data.voice_key
    if _is_openai_compatible_vendor(vendor):
        model = model or OPENAI_COMPATIBLE_DEFAULT_MODEL
        if not voice_key:
            # Accept either a full "model:voice" id or a bare voice name.
            raw_id = (data.id or data.name).strip()
            voice_key = raw_id if ":" in raw_id else f"{model}:{raw_id}"
    elif _is_dashscope_vendor(vendor):
        model = (model or "").strip() or DASHSCOPE_DEFAULT_MODEL
        voice_key = (voice_key or "").strip() or DASHSCOPE_DEFAULT_VOICE_KEY
    voice = Voice(
        id=unique_short_id("tts", db, Voice),
        user_id=1,
        name=data.name,
        vendor=vendor,
        gender=data.gender,
        language=data.language,
        description=data.description,
        model=model,
        voice_key=voice_key,
        api_key=data.api_key,
        base_url=data.base_url,
        speed=data.speed,
        gain=data.gain,
        pitch=data.pitch,
        enabled=data.enabled,
    )
    db.add(voice)
    db.commit()
    db.refresh(voice)
    return voice


@router.get("/{id}", response_model=VoiceOut)
def get_voice(id: str, db: Session = Depends(get_db)):
    """Return a single voice by id; 404 when it does not exist."""
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")
    return voice


@router.put("/{id}", response_model=VoiceOut)
def update_voice(id: str, data: VoiceUpdate, db: Session = Depends(get_db)):
    """Partially update a voice, re-applying vendor defaults as needed."""
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")
    update_data = data.model_dump(exclude_unset=True)
    if "vendor" in update_data and update_data["vendor"] is not None:
        update_data["vendor"] = update_data["vendor"].strip()
    # Defaults depend on the vendor the record will have AFTER this update.
    vendor_for_defaults = update_data.get("vendor", voice.vendor)
    if _is_openai_compatible_vendor(vendor_for_defaults):
        model = update_data.get("model") or voice.model or OPENAI_COMPATIBLE_DEFAULT_MODEL
        voice_key = update_data.get("voice_key") or voice.voice_key
        update_data["model"] = model
        update_data["voice_key"] = voice_key or _build_openai_compatible_voice_key(voice, model)
    elif _is_dashscope_vendor(vendor_for_defaults):
        model = update_data.get("model") or voice.model or DASHSCOPE_DEFAULT_MODEL
        voice_key = update_data.get("voice_key") or voice.voice_key or DASHSCOPE_DEFAULT_VOICE_KEY
        update_data["model"] = model
        update_data["voice_key"] = voice_key
    for field, value in update_data.items():
        setattr(voice, field, value)
    db.commit()
    db.refresh(voice)
    return voice


@router.delete("/{id}")
def delete_voice(id: str, db: Session = Depends(get_db)):
    """Delete a voice by id; 404 when it does not exist."""
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")
    db.delete(voice)
    db.commit()
    return {"message": "Deleted successfully"}


@router.post("/{id}/preview", response_model=VoicePreviewResponse)
def preview_voice(id: str, data: VoicePreviewRequest, db: Session = Depends(get_db)):
    """Preview a voice; supports OpenAI-compatible and DashScope Realtime vendors.

    Returns the synthesized audio as a base64 data URL (WAV for DashScope,
    MP3 for OpenAI-compatible vendors). Vendor failures map to HTTP 502.
    """
    voice = db.query(Voice).filter(Voice.id == id).first()
    if not voice:
        raise HTTPException(status_code=404, detail="Voice not found")
    text = data.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Preview text cannot be empty")

    if _is_dashscope_vendor(voice.vendor):
        # Key resolution order: request override -> stored key -> environment.
        api_key = (data.api_key or "").strip() or (voice.api_key or "").strip()
        if not api_key:
            api_key = os.getenv("DASHSCOPE_API_KEY", "").strip() or os.getenv("TTS_API_KEY", "").strip()
        if not api_key:
            raise HTTPException(status_code=400, detail=f"API key is required for voice: {voice.name}")
        base_url = (voice.base_url or "").strip() or DASHSCOPE_DEFAULT_BASE_URL
        model = (voice.model or "").strip() or DASHSCOPE_DEFAULT_MODEL
        voice_key = (voice.voice_key or "").strip() or DASHSCOPE_DEFAULT_VOICE_KEY
        effective_speed = data.speed if data.speed is not None else voice.speed
        try:
            wav_bytes = _synthesize_dashscope_preview(
                text=text,
                api_key=api_key,
                base_url=base_url,
                model=model,
                voice_key=voice_key,
                speed=effective_speed,
            )
        except Exception as exc:
            raise HTTPException(status_code=502, detail=f"DashScope preview failed: {exc}") from exc
        audio_base64 = base64.b64encode(wav_bytes).decode("utf-8")
        return VoicePreviewResponse(success=True, audio_url=f"data:audio/wav;base64,{audio_base64}")

    api_key = (data.api_key or "").strip() or (voice.api_key or "").strip()
    if not api_key and _is_openai_compatible_vendor(voice.vendor):
        api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
    if not api_key:
        raise HTTPException(status_code=400, detail=f"API key is required for voice: {voice.name}")
    base_url = (voice.base_url or "").strip() or (_default_base_url(voice.vendor) or "")
    if not base_url:
        raise HTTPException(status_code=400, detail=f"Base URL is required for voice: {voice.name}")
    model = voice.model or OPENAI_COMPATIBLE_DEFAULT_MODEL
    payload = {
        "model": model,
        "input": text,
        "voice": voice.voice_key or _build_openai_compatible_voice_key(voice, model),
        "response_format": "mp3",
    }
    # Fix: only send "speed" when one is actually set; a JSON null here is
    # rejected by some OpenAI-compatible /audio/speech implementations.
    effective_speed = data.speed if data.speed is not None else voice.speed
    if effective_speed is not None:
        payload["speed"] = effective_speed
    try:
        with httpx.Client(timeout=45.0) as client:
            response = client.post(
                f"{base_url.rstrip('/')}/audio/speech",
                headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
                json=payload,
            )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"TTS request failed: {exc}") from exc
    if response.status_code != 200:
        # Prefer a structured vendor error message; fall back to raw text.
        detail = response.text
        try:
            detail_json = response.json()
            detail = detail_json.get("error", {}).get("message") or detail_json.get("detail") or detail
        except Exception:
            pass
        raise HTTPException(status_code=502, detail=f"TTS vendor error: {detail}")
    audio_base64 = base64.b64encode(response.content).decode("utf-8")
    return VoicePreviewResponse(success=True, audio_url=f"data:audio/mpeg;base64,{audio_base64}")