Add opener audio functionality to Assistant model and related schemas, enabling audio generation and playback features. Update API routes and frontend components to support opener audio management, including status retrieval and generation controls.
This commit is contained in:
@@ -1,18 +1,33 @@
|
||||
import audioop
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import wave
|
||||
from pathlib import Path
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from ..db import get_db
|
||||
from ..models import Assistant, LLMModel, ASRModel, Voice
|
||||
from ..models import Assistant, AssistantOpenerAudio, LLMModel, ASRModel, Voice
|
||||
from ..schemas import (
|
||||
AssistantCreate, AssistantUpdate, AssistantOut, AssistantEngineConfigResponse
|
||||
AssistantCreate,
|
||||
AssistantUpdate,
|
||||
AssistantOut,
|
||||
AssistantEngineConfigResponse,
|
||||
AssistantOpenerAudioGenerateRequest,
|
||||
AssistantOpenerAudioOut,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/assistants", tags=["Assistants"])
|
||||
|
||||
OPENAI_COMPATIBLE_DEFAULT_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
|
||||
OPENAI_COMPATIBLE_DEFAULT_BASE_URL = "https://api.siliconflow.cn/v1"
|
||||
OPENER_AUDIO_DIR = Path(__file__).resolve().parents[2] / "data" / "opener_audio"
|
||||
OPENAI_COMPATIBLE_KNOWN_VOICES = {
|
||||
"alex",
|
||||
"anna",
|
||||
@@ -163,6 +178,19 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[s
|
||||
"kbId": assistant.knowledge_base_id,
|
||||
"nResults": 5,
|
||||
}
|
||||
opener_audio = assistant.opener_audio
|
||||
opener_audio_ready = bool(opener_audio and opener_audio.file_path and Path(opener_audio.file_path).exists())
|
||||
metadata["openerAudio"] = {
|
||||
"enabled": bool(opener_audio.enabled) if opener_audio else False,
|
||||
"ready": opener_audio_ready,
|
||||
"encoding": opener_audio.encoding if opener_audio else "pcm_s16le",
|
||||
"sampleRateHz": int(opener_audio.sample_rate_hz) if opener_audio else 16000,
|
||||
"channels": int(opener_audio.channels) if opener_audio else 1,
|
||||
"durationMs": int(opener_audio.duration_ms) if opener_audio else 0,
|
||||
"textHash": opener_audio.text_hash if opener_audio else None,
|
||||
"ttsFingerprint": opener_audio.tts_fingerprint if opener_audio else None,
|
||||
"pcmUrl": f"/api/assistants/{assistant.id}/opener-audio/pcm" if opener_audio_ready else None,
|
||||
}
|
||||
return metadata, warnings
|
||||
|
||||
|
||||
@@ -189,6 +217,8 @@ def _build_engine_assistant_config(db: Session, assistant: Assistant) -> Dict[st
|
||||
|
||||
|
||||
def assistant_to_dict(assistant: Assistant) -> dict:
|
||||
opener_audio = assistant.opener_audio
|
||||
opener_audio_ready = bool(opener_audio and opener_audio.file_path and Path(opener_audio.file_path).exists())
|
||||
return {
|
||||
"id": assistant.id,
|
||||
"name": assistant.name,
|
||||
@@ -196,6 +226,10 @@ def assistant_to_dict(assistant: Assistant) -> dict:
|
||||
"firstTurnMode": assistant.first_turn_mode or "bot_first",
|
||||
"opener": assistant.opener or "",
|
||||
"generatedOpenerEnabled": bool(assistant.generated_opener_enabled),
|
||||
"openerAudioEnabled": bool(opener_audio.enabled) if opener_audio else False,
|
||||
"openerAudioReady": opener_audio_ready,
|
||||
"openerAudioDurationMs": int(opener_audio.duration_ms) if opener_audio else 0,
|
||||
"openerAudioUpdatedAt": opener_audio.updated_at if opener_audio else None,
|
||||
"prompt": assistant.prompt or "",
|
||||
"knowledgeBaseId": assistant.knowledge_base_id,
|
||||
"language": assistant.language,
|
||||
@@ -238,6 +272,114 @@ def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None:
|
||||
setattr(assistant, field_map.get(field, field), value)
|
||||
|
||||
|
||||
def _ensure_assistant_opener_audio(db: Session, assistant: Assistant) -> AssistantOpenerAudio:
    """Return the assistant's opener-audio record, creating a disabled one on demand.

    A newly created record is only flushed (not committed) so it receives its
    DB defaults/identity while transaction control stays with the caller.
    """
    existing = assistant.opener_audio
    if existing is not None:
        return existing
    created = AssistantOpenerAudio(assistant_id=assistant.id, enabled=False)
    db.add(created)
    db.flush()
    return created
|
||||
|
||||
|
||||
def _resolve_tts_runtime_for_assistant(db: Session, assistant: Assistant) -> tuple[Dict[str, Any], Optional[Voice]]:
    """Extract the effective TTS config dict and the assistant's Voice row (if any).

    Falls back to empty dicts when the runtime metadata does not carry a
    well-formed ``services.tts`` section.
    """
    metadata, _warnings = _resolve_runtime_metadata(db, assistant)
    services_raw = metadata.get("services")
    services = services_raw if isinstance(services_raw, dict) else {}
    tts_raw = services.get("tts")
    tts_cfg = tts_raw if isinstance(tts_raw, dict) else {}
    voice_row = None
    if assistant.voice:
        voice_row = db.query(Voice).filter(Voice.id == assistant.voice).first()
    return tts_cfg, voice_row
|
||||
|
||||
|
||||
def _tts_fingerprint(tts_cfg: Dict[str, Any], opener_text: str) -> str:
|
||||
identity = {
|
||||
"provider": tts_cfg.get("provider"),
|
||||
"model": tts_cfg.get("model"),
|
||||
"voice": tts_cfg.get("voice"),
|
||||
"speed": tts_cfg.get("speed"),
|
||||
"text": opener_text,
|
||||
}
|
||||
return hashlib.sha256(str(identity).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _synthesize_openai_compatible_wav(
    *,
    text: str,
    model: str,
    voice_key: str,
    speed: float,
    api_key: str,
    base_url: str,
) -> bytes:
    """Call an OpenAI-compatible ``/audio/speech`` endpoint and return WAV bytes.

    Raises:
        HTTPException(502): with the vendor's error message when the request
            does not return HTTP 200.
    """
    request_body = {
        "model": model or OPENAI_COMPATIBLE_DEFAULT_MODEL,
        "input": text,
        "voice": voice_key,
        "response_format": "wav",
        "speed": speed,
    }
    endpoint = f"{base_url.rstrip('/')}/audio/speech"
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    with httpx.Client(timeout=45.0) as client:
        response = client.post(endpoint, headers=request_headers, json=request_body)
    if response.status_code == 200:
        return response.content
    # Prefer a structured vendor error message; fall back to the raw body text.
    detail = response.text
    try:
        body = response.json()
        detail = body.get("error", {}).get("message") or body.get("detail") or detail
    except Exception:
        pass
    raise HTTPException(status_code=502, detail=f"TTS vendor error: {detail}")
|
||||
|
||||
|
||||
def _wav_to_pcm16_mono_16k(wav_bytes: bytes) -> tuple[bytes, int]:
|
||||
with wave.open(io.BytesIO(wav_bytes), "rb") as wav_file:
|
||||
channels = wav_file.getnchannels()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_rate = wav_file.getframerate()
|
||||
frames = wav_file.getnframes()
|
||||
raw = wav_file.readframes(frames)
|
||||
|
||||
if sample_width != 2:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported WAV sample width: {sample_width * 8}bit")
|
||||
|
||||
if channels > 1:
|
||||
raw = audioop.tomono(raw, sample_width, 0.5, 0.5)
|
||||
|
||||
if sample_rate != 16000:
|
||||
raw, _ = audioop.ratecv(raw, sample_width, 1, sample_rate, 16000, None)
|
||||
|
||||
duration_ms = int((len(raw) / (16000 * 2)) * 1000)
|
||||
return raw, duration_ms
|
||||
|
||||
|
||||
def _persist_opener_audio_pcm(assistant_id: str, pcm_bytes: bytes) -> str:
    """Write the opener PCM payload to disk, returning the file path as a string.

    Files are keyed by assistant id, so regeneration overwrites the previous
    audio in place.
    """
    OPENER_AUDIO_DIR.mkdir(parents=True, exist_ok=True)
    target = OPENER_AUDIO_DIR / f"{assistant_id}.pcm"
    target.write_bytes(pcm_bytes)
    return str(target)
|
||||
|
||||
|
||||
def _opener_audio_out(record: Optional[AssistantOpenerAudio]) -> AssistantOpenerAudioOut:
    """Map an ORM opener-audio row to its API schema (schema defaults when absent)."""
    if record is None:
        return AssistantOpenerAudioOut()
    # "ready" means audio was generated AND the file still exists on disk.
    file_ok = bool(record.file_path) and Path(record.file_path).exists()
    return AssistantOpenerAudioOut(
        enabled=bool(record.enabled),
        ready=file_ok,
        encoding=record.encoding,
        sample_rate_hz=record.sample_rate_hz,
        channels=record.channels,
        duration_ms=record.duration_ms,
        updated_at=record.updated_at,
        text_hash=record.text_hash,
        tts_fingerprint=record.tts_fingerprint,
    )
|
||||
|
||||
|
||||
# ============ Assistants ============
|
||||
@router.get("")
|
||||
def list_assistants(
|
||||
@@ -316,9 +458,132 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
|
||||
db.add(assistant)
|
||||
db.commit()
|
||||
db.refresh(assistant)
|
||||
opener_audio = _ensure_assistant_opener_audio(db, assistant)
|
||||
opener_audio.enabled = bool(data.openerAudioEnabled)
|
||||
opener_audio.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(assistant)
|
||||
return assistant_to_dict(assistant)
|
||||
|
||||
|
||||
@router.get("/{id}/opener-audio", response_model=AssistantOpenerAudioOut)
def get_assistant_opener_audio(id: str, db: Session = Depends(get_db)):
    """Return opener-audio status/metadata for a single assistant."""
    assistant = db.query(Assistant).filter(Assistant.id == id).first()
    if assistant is None:
        raise HTTPException(status_code=404, detail="Assistant not found")
    return _opener_audio_out(assistant.opener_audio)
|
||||
|
||||
|
||||
@router.get("/{id}/opener-audio/pcm")
def get_assistant_opener_audio_pcm(id: str, db: Session = Depends(get_db)):
    """Serve the generated opener PCM file (raw s16le) for one assistant.

    404s distinguish three cases: unknown assistant, audio never generated,
    and a DB record whose backing file has gone missing from disk.
    """
    assistant = db.query(Assistant).filter(Assistant.id == id).first()
    if assistant is None:
        raise HTTPException(status_code=404, detail="Assistant not found")
    record = assistant.opener_audio
    if record is None or not record.file_path:
        raise HTTPException(status_code=404, detail="Opener audio not generated")
    pcm_path = Path(record.file_path)
    if not pcm_path.exists():
        raise HTTPException(status_code=404, detail="Opener audio file missing")
    # Raw PCM has no standard MIME type; serve it as a generic binary download.
    return FileResponse(
        str(pcm_path),
        media_type="application/octet-stream",
        filename=f"{assistant.id}.pcm",
    )
|
||||
|
||||
|
||||
@router.post("/{id}/opener-audio/generate", response_model=AssistantOpenerAudioOut)
def generate_assistant_opener_audio(
    id: str,
    data: AssistantOpenerAudioGenerateRequest,
    db: Session = Depends(get_db),
):
    """Synthesize the assistant's opener text via its TTS provider and persist it.

    The vendor WAV is normalized to mono 16 kHz s16le PCM, written to disk,
    and the AssistantOpenerAudio record is updated (and enabled) to match.

    Raises:
        HTTPException(404): unknown assistant.
        HTTPException(400): voice output disabled, empty opener text,
            unsupported TTS provider, or missing API key.
        HTTPException(502): vendor synthesis failure.
    """
    assistant = db.query(Assistant).filter(Assistant.id == id).first()
    if not assistant:
        raise HTTPException(status_code=404, detail="Assistant not found")
    if not assistant.voice_output_enabled:
        raise HTTPException(status_code=400, detail="Voice output is disabled")

    # Explicit request text overrides the stored opener; blank text is rejected.
    opener_text = (data.text if data.text is not None else assistant.opener or "").strip()
    if not opener_text:
        raise HTTPException(status_code=400, detail="Opener text is empty")

    # Only these two providers are supported for pre-generated opener audio.
    tts_cfg, voice = _resolve_tts_runtime_for_assistant(db, assistant)
    provider = str(tts_cfg.get("provider") or "").strip().lower()
    if provider not in {"openai_compatible", "dashscope"}:
        raise HTTPException(status_code=400, detail=f"Unsupported provider for preloaded opener audio: {provider or 'unknown'}")

    speed = float(tts_cfg.get("speed") or assistant.speed or 1.0)
    voice_key = str(tts_cfg.get("voice") or "").strip()
    model = str(tts_cfg.get("model") or "").strip() or OPENAI_COMPATIBLE_DEFAULT_MODEL
    api_key = str(tts_cfg.get("apiKey") or "").strip()
    base_url = str(tts_cfg.get("baseUrl") or "").strip()

    if provider == "openai_compatible":
        # Credential fallback chain: runtime config -> voice record -> environment.
        if not api_key:
            if voice and voice.api_key:
                api_key = voice.api_key.strip()
        if not api_key:
            api_key = (os.getenv("SILICONFLOW_API_KEY", "") or os.getenv("TTS_API_KEY", "")).strip()
        if not api_key:
            raise HTTPException(status_code=400, detail="TTS API key is missing")
        if not base_url:
            base_url = OPENAI_COMPATIBLE_DEFAULT_BASE_URL
        wav_bytes = _synthesize_openai_compatible_wav(
            text=opener_text,
            model=model,
            voice_key=voice_key,
            speed=speed,
            api_key=api_key,
            base_url=base_url,
        )
    else:
        # Deferred import — presumably avoids a circular import with the
        # voices router; TODO confirm.
        from .voices import _synthesize_dashscope_preview, DASHSCOPE_DEFAULT_BASE_URL, DASHSCOPE_DEFAULT_MODEL, DASHSCOPE_DEFAULT_VOICE_KEY
        # Same fallback chain, with DashScope-specific env var and defaults.
        if not api_key:
            if voice and voice.api_key:
                api_key = voice.api_key.strip()
        if not api_key:
            api_key = (os.getenv("DASHSCOPE_API_KEY", "") or os.getenv("TTS_API_KEY", "")).strip()
        if not api_key:
            raise HTTPException(status_code=400, detail="DashScope API key is missing")
        if not base_url:
            base_url = DASHSCOPE_DEFAULT_BASE_URL
        if not model:
            model = DASHSCOPE_DEFAULT_MODEL
        if not voice_key:
            voice_key = DASHSCOPE_DEFAULT_VOICE_KEY
        try:
            wav_bytes = _synthesize_dashscope_preview(
                text=opener_text,
                api_key=api_key,
                base_url=base_url,
                model=model,
                voice_key=voice_key,
                speed=speed,
            )
        except Exception as exc:
            raise HTTPException(status_code=502, detail=f"DashScope opener audio generation failed: {exc}") from exc

    # Normalize vendor WAV to the playback format and persist file + metadata.
    pcm_bytes, duration_ms = _wav_to_pcm16_mono_16k(wav_bytes)
    record = _ensure_assistant_opener_audio(db, assistant)
    record.enabled = True
    record.file_path = _persist_opener_audio_pcm(assistant.id, pcm_bytes)
    record.encoding = "pcm_s16le"
    record.sample_rate_hz = 16000
    record.channels = 1
    record.duration_ms = duration_ms
    # text_hash tracks the synthesized text; tts_fingerprint additionally
    # covers provider/model/voice/speed, for staleness detection.
    record.text_hash = hashlib.sha256(opener_text.encode("utf-8")).hexdigest()
    record.tts_fingerprint = _tts_fingerprint(tts_cfg, opener_text)
    # NOTE(review): datetime.utcnow() is naive and deprecated since 3.12 —
    # consider datetime.now(timezone.utc) if the column is tz-aware.
    now = datetime.utcnow()
    if not record.created_at:
        record.created_at = now
    record.updated_at = now
    assistant.updated_at = now
    db.commit()
    db.refresh(assistant)
    return _opener_audio_out(assistant.opener_audio)
|
||||
|
||||
|
||||
@router.put("/{id}")
|
||||
def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_db)):
|
||||
"""更新助手"""
|
||||
@@ -327,7 +592,12 @@ def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_d
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
opener_audio_enabled = update_data.pop("openerAudioEnabled", None)
|
||||
_apply_assistant_update(assistant, update_data)
|
||||
if opener_audio_enabled is not None:
|
||||
record = _ensure_assistant_opener_audio(db, assistant)
|
||||
record.enabled = bool(opener_audio_enabled)
|
||||
record.updated_at = datetime.utcnow()
|
||||
|
||||
assistant.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
Reference in New Issue
Block a user