Add opener audio functionality to Assistant model and related schemas, enabling audio generation and playback features. Update API routes and frontend components to support opener audio management, including status retrieval and generation controls.

This commit is contained in:
Xin Wang
2026-02-26 14:31:50 +08:00
parent 833cb0d4c4
commit fb95e2abe2
9 changed files with 551 additions and 4 deletions

View File

@@ -1,18 +1,33 @@
import audioop
import hashlib
import io
import os
import wave
from pathlib import Path
import httpx
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from typing import Any, Dict, List, Optional
import uuid
from datetime import datetime
from ..db import get_db
from ..models import Assistant, LLMModel, ASRModel, Voice
from ..models import Assistant, AssistantOpenerAudio, LLMModel, ASRModel, Voice
from ..schemas import (
AssistantCreate, AssistantUpdate, AssistantOut, AssistantEngineConfigResponse
AssistantCreate,
AssistantUpdate,
AssistantOut,
AssistantEngineConfigResponse,
AssistantOpenerAudioGenerateRequest,
AssistantOpenerAudioOut,
)
router = APIRouter(prefix="/assistants", tags=["Assistants"])
OPENAI_COMPATIBLE_DEFAULT_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
OPENAI_COMPATIBLE_DEFAULT_BASE_URL = "https://api.siliconflow.cn/v1"
OPENER_AUDIO_DIR = Path(__file__).resolve().parents[2] / "data" / "opener_audio"
OPENAI_COMPATIBLE_KNOWN_VOICES = {
"alex",
"anna",
@@ -163,6 +178,19 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[s
"kbId": assistant.knowledge_base_id,
"nResults": 5,
}
opener_audio = assistant.opener_audio
opener_audio_ready = bool(opener_audio and opener_audio.file_path and Path(opener_audio.file_path).exists())
metadata["openerAudio"] = {
"enabled": bool(opener_audio.enabled) if opener_audio else False,
"ready": opener_audio_ready,
"encoding": opener_audio.encoding if opener_audio else "pcm_s16le",
"sampleRateHz": int(opener_audio.sample_rate_hz) if opener_audio else 16000,
"channels": int(opener_audio.channels) if opener_audio else 1,
"durationMs": int(opener_audio.duration_ms) if opener_audio else 0,
"textHash": opener_audio.text_hash if opener_audio else None,
"ttsFingerprint": opener_audio.tts_fingerprint if opener_audio else None,
"pcmUrl": f"/api/assistants/{assistant.id}/opener-audio/pcm" if opener_audio_ready else None,
}
return metadata, warnings
@@ -189,6 +217,8 @@ def _build_engine_assistant_config(db: Session, assistant: Assistant) -> Dict[st
def assistant_to_dict(assistant: Assistant) -> dict:
opener_audio = assistant.opener_audio
opener_audio_ready = bool(opener_audio and opener_audio.file_path and Path(opener_audio.file_path).exists())
return {
"id": assistant.id,
"name": assistant.name,
@@ -196,6 +226,10 @@ def assistant_to_dict(assistant: Assistant) -> dict:
"firstTurnMode": assistant.first_turn_mode or "bot_first",
"opener": assistant.opener or "",
"generatedOpenerEnabled": bool(assistant.generated_opener_enabled),
"openerAudioEnabled": bool(opener_audio.enabled) if opener_audio else False,
"openerAudioReady": opener_audio_ready,
"openerAudioDurationMs": int(opener_audio.duration_ms) if opener_audio else 0,
"openerAudioUpdatedAt": opener_audio.updated_at if opener_audio else None,
"prompt": assistant.prompt or "",
"knowledgeBaseId": assistant.knowledge_base_id,
"language": assistant.language,
@@ -238,6 +272,114 @@ def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None:
setattr(assistant, field_map.get(field, field), value)
def _ensure_assistant_opener_audio(db: Session, assistant: Assistant) -> AssistantOpenerAudio:
record = assistant.opener_audio
if record:
return record
record = AssistantOpenerAudio(assistant_id=assistant.id, enabled=False)
db.add(record)
db.flush()
return record
def _resolve_tts_runtime_for_assistant(db: Session, assistant: Assistant) -> tuple[Dict[str, Any], Optional[Voice]]:
metadata, _ = _resolve_runtime_metadata(db, assistant)
services = metadata.get("services") if isinstance(metadata.get("services"), dict) else {}
tts = services.get("tts") if isinstance(services, dict) and isinstance(services.get("tts"), dict) else {}
voice = db.query(Voice).filter(Voice.id == assistant.voice).first() if assistant.voice else None
return tts, voice
def _tts_fingerprint(tts_cfg: Dict[str, Any], opener_text: str) -> str:
identity = {
"provider": tts_cfg.get("provider"),
"model": tts_cfg.get("model"),
"voice": tts_cfg.get("voice"),
"speed": tts_cfg.get("speed"),
"text": opener_text,
}
return hashlib.sha256(str(identity).encode("utf-8")).hexdigest()
def _synthesize_openai_compatible_wav(
*,
text: str,
model: str,
voice_key: str,
speed: float,
api_key: str,
base_url: str,
) -> bytes:
payload = {
"model": model or OPENAI_COMPATIBLE_DEFAULT_MODEL,
"input": text,
"voice": voice_key,
"response_format": "wav",
"speed": speed,
}
with httpx.Client(timeout=45.0) as client:
response = client.post(
f"{base_url.rstrip('/')}/audio/speech",
headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
json=payload,
)
if response.status_code != 200:
detail = response.text
try:
detail_json = response.json()
detail = detail_json.get("error", {}).get("message") or detail_json.get("detail") or detail
except Exception:
pass
raise HTTPException(status_code=502, detail=f"TTS vendor error: {detail}")
return response.content
def _wav_to_pcm16_mono_16k(wav_bytes: bytes) -> tuple[bytes, int]:
with wave.open(io.BytesIO(wav_bytes), "rb") as wav_file:
channels = wav_file.getnchannels()
sample_width = wav_file.getsampwidth()
sample_rate = wav_file.getframerate()
frames = wav_file.getnframes()
raw = wav_file.readframes(frames)
if sample_width != 2:
raise HTTPException(status_code=400, detail=f"Unsupported WAV sample width: {sample_width * 8}bit")
if channels > 1:
raw = audioop.tomono(raw, sample_width, 0.5, 0.5)
if sample_rate != 16000:
raw, _ = audioop.ratecv(raw, sample_width, 1, sample_rate, 16000, None)
duration_ms = int((len(raw) / (16000 * 2)) * 1000)
return raw, duration_ms
def _persist_opener_audio_pcm(assistant_id: str, pcm_bytes: bytes) -> str:
OPENER_AUDIO_DIR.mkdir(parents=True, exist_ok=True)
file_path = OPENER_AUDIO_DIR / f"{assistant_id}.pcm"
with open(file_path, "wb") as f:
f.write(pcm_bytes)
return str(file_path)
def _opener_audio_out(record: Optional[AssistantOpenerAudio]) -> AssistantOpenerAudioOut:
if not record:
return AssistantOpenerAudioOut()
ready = bool(record.file_path and Path(record.file_path).exists())
return AssistantOpenerAudioOut(
enabled=bool(record.enabled),
ready=ready,
encoding=record.encoding,
sample_rate_hz=record.sample_rate_hz,
channels=record.channels,
duration_ms=record.duration_ms,
updated_at=record.updated_at,
text_hash=record.text_hash,
tts_fingerprint=record.tts_fingerprint,
)
# ============ Assistants ============
@router.get("")
def list_assistants(
@@ -316,9 +458,132 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
db.add(assistant)
db.commit()
db.refresh(assistant)
opener_audio = _ensure_assistant_opener_audio(db, assistant)
opener_audio.enabled = bool(data.openerAudioEnabled)
opener_audio.updated_at = datetime.utcnow()
db.commit()
db.refresh(assistant)
return assistant_to_dict(assistant)
@router.get("/{id}/opener-audio", response_model=AssistantOpenerAudioOut)
def get_assistant_opener_audio(id: str, db: Session = Depends(get_db)):
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
return _opener_audio_out(assistant.opener_audio)
@router.get("/{id}/opener-audio/pcm")
def get_assistant_opener_audio_pcm(id: str, db: Session = Depends(get_db)):
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
record = assistant.opener_audio
if not record or not record.file_path:
raise HTTPException(status_code=404, detail="Opener audio not generated")
file_path = Path(record.file_path)
if not file_path.exists():
raise HTTPException(status_code=404, detail="Opener audio file missing")
return FileResponse(
str(file_path),
media_type="application/octet-stream",
filename=f"{assistant.id}.pcm",
)
@router.post("/{id}/opener-audio/generate", response_model=AssistantOpenerAudioOut)
def generate_assistant_opener_audio(
id: str,
data: AssistantOpenerAudioGenerateRequest,
db: Session = Depends(get_db),
):
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
if not assistant.voice_output_enabled:
raise HTTPException(status_code=400, detail="Voice output is disabled")
opener_text = (data.text if data.text is not None else assistant.opener or "").strip()
if not opener_text:
raise HTTPException(status_code=400, detail="Opener text is empty")
tts_cfg, voice = _resolve_tts_runtime_for_assistant(db, assistant)
provider = str(tts_cfg.get("provider") or "").strip().lower()
if provider not in {"openai_compatible", "dashscope"}:
raise HTTPException(status_code=400, detail=f"Unsupported provider for preloaded opener audio: {provider or 'unknown'}")
speed = float(tts_cfg.get("speed") or assistant.speed or 1.0)
voice_key = str(tts_cfg.get("voice") or "").strip()
model = str(tts_cfg.get("model") or "").strip() or OPENAI_COMPATIBLE_DEFAULT_MODEL
api_key = str(tts_cfg.get("apiKey") or "").strip()
base_url = str(tts_cfg.get("baseUrl") or "").strip()
if provider == "openai_compatible":
if not api_key:
if voice and voice.api_key:
api_key = voice.api_key.strip()
if not api_key:
api_key = (os.getenv("SILICONFLOW_API_KEY", "") or os.getenv("TTS_API_KEY", "")).strip()
if not api_key:
raise HTTPException(status_code=400, detail="TTS API key is missing")
if not base_url:
base_url = OPENAI_COMPATIBLE_DEFAULT_BASE_URL
wav_bytes = _synthesize_openai_compatible_wav(
text=opener_text,
model=model,
voice_key=voice_key,
speed=speed,
api_key=api_key,
base_url=base_url,
)
else:
from .voices import _synthesize_dashscope_preview, DASHSCOPE_DEFAULT_BASE_URL, DASHSCOPE_DEFAULT_MODEL, DASHSCOPE_DEFAULT_VOICE_KEY
if not api_key:
if voice and voice.api_key:
api_key = voice.api_key.strip()
if not api_key:
api_key = (os.getenv("DASHSCOPE_API_KEY", "") or os.getenv("TTS_API_KEY", "")).strip()
if not api_key:
raise HTTPException(status_code=400, detail="DashScope API key is missing")
if not base_url:
base_url = DASHSCOPE_DEFAULT_BASE_URL
if not model:
model = DASHSCOPE_DEFAULT_MODEL
if not voice_key:
voice_key = DASHSCOPE_DEFAULT_VOICE_KEY
try:
wav_bytes = _synthesize_dashscope_preview(
text=opener_text,
api_key=api_key,
base_url=base_url,
model=model,
voice_key=voice_key,
speed=speed,
)
except Exception as exc:
raise HTTPException(status_code=502, detail=f"DashScope opener audio generation failed: {exc}") from exc
pcm_bytes, duration_ms = _wav_to_pcm16_mono_16k(wav_bytes)
record = _ensure_assistant_opener_audio(db, assistant)
record.enabled = True
record.file_path = _persist_opener_audio_pcm(assistant.id, pcm_bytes)
record.encoding = "pcm_s16le"
record.sample_rate_hz = 16000
record.channels = 1
record.duration_ms = duration_ms
record.text_hash = hashlib.sha256(opener_text.encode("utf-8")).hexdigest()
record.tts_fingerprint = _tts_fingerprint(tts_cfg, opener_text)
now = datetime.utcnow()
if not record.created_at:
record.created_at = now
record.updated_at = now
assistant.updated_at = now
db.commit()
db.refresh(assistant)
return _opener_audio_out(assistant.opener_audio)
@router.put("/{id}")
def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_db)):
"""更新助手"""
@@ -327,7 +592,12 @@ def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_d
raise HTTPException(status_code=404, detail="Assistant not found")
update_data = data.model_dump(exclude_unset=True)
opener_audio_enabled = update_data.pop("openerAudioEnabled", None)
_apply_assistant_update(assistant, update_data)
if opener_audio_enabled is not None:
record = _ensure_assistant_opener_audio(db, assistant)
record.enabled = bool(opener_audio_enabled)
record.updated_at = datetime.utcnow()
assistant.updated_at = datetime.utcnow()
db.commit()