Files
AI-VideoAssistant/api/app/routers/assistants.py
2026-02-26 03:54:52 +08:00

347 lines
13 KiB
Python

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Any, Dict, List, Optional
import uuid
from datetime import datetime
from ..db import get_db
from ..models import Assistant, LLMModel, ASRModel, Voice
from ..schemas import (
AssistantCreate, AssistantUpdate, AssistantOut, AssistantEngineConfigResponse
)
router = APIRouter(prefix="/assistants", tags=["Assistants"])

# Fallback TTS model used when an OpenAI-compatible voice row has no model set.
OPENAI_COMPATIBLE_DEFAULT_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
# Built-in voice ids of the OpenAI-compatible TTS service. These names are
# matched case-insensitively and lowercased during voice-key normalization
# (see _normalize_openai_compatible_voice_key); custom voice ids keep their
# original casing.
OPENAI_COMPATIBLE_KNOWN_VOICES = {
    "alex",
    "anna",
    "bella",
    "benjamin",
    "charles",
    "claire",
    "david",
    "diana",
}
def _is_openai_compatible_vendor(vendor: Optional[str]) -> bool:
return (vendor or "").strip().lower() in {
"siliconflow",
"硅基流动",
"openai compatible",
"openai-compatible",
}
def _is_dashscope_vendor(vendor: Optional[str]) -> bool:
return (vendor or "").strip().lower() in {
"dashscope",
}
def _normalize_openai_compatible_voice_key(voice_value: str, model: str) -> str:
    """Normalize a stored voice value into the "<model>:<voice>" runtime form.

    Built-in voice names (OPENAI_COMPATIBLE_KNOWN_VOICES) are lowercased;
    any other voice id keeps its original casing. An empty voice value
    falls back to "<model>:anna", and an empty model falls back to
    OPENAI_COMPATIBLE_DEFAULT_MODEL.
    """
    effective_model = (model or "").strip() or OPENAI_COMPATIBLE_DEFAULT_MODEL
    candidate = (voice_value or "").strip()
    if not candidate:
        return f"{effective_model}:anna"
    if ":" in candidate:
        # Already "<model>:<voice>" — keep the embedded model unless blank.
        prefix, _, suffix = candidate.partition(":")
        prefix = prefix.strip() or effective_model
        suffix = suffix.strip()
        if suffix.lower() in OPENAI_COMPATIBLE_KNOWN_VOICES:
            suffix = suffix.lower()
        return f"{prefix}:{suffix}"
    lowered = candidate.lower()
    voice_id = lowered if lowered in OPENAI_COMPATIBLE_KNOWN_VOICES else candidate
    return f"{effective_model}:{voice_id}"
def _config_version_id(assistant: Assistant) -> str:
    """Build a config-version token from the assistant's last-modified time.

    Falls back to created_at, then to "now", when updated_at is unset.
    """
    stamp = assistant.updated_at or assistant.created_at or datetime.utcnow()
    return f"asst_{assistant.id}_{stamp:%Y%m%d%H%M%S}"
def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[str, Any], List[str]]:
    """Resolve an assistant's stored configuration into engine session metadata.

    Looks up the referenced LLM/ASR/Voice rows and flattens everything the
    realtime engine needs into one metadata dict. Missing references never
    raise; they are reported in the returned warnings list instead.

    Returns:
        (metadata, warnings): *metadata* is the session-start payload,
        *warnings* holds human-readable notes about broken references.
    """
    metadata: Dict[str, Any] = {
        "systemPrompt": assistant.prompt or "",
        "firstTurnMode": assistant.first_turn_mode or "bot_first",
        "greeting": assistant.opener or "",
        "generatedOpenerEnabled": bool(assistant.generated_opener_enabled),
        # Audio output only when voice output is enabled on the assistant.
        "output": {"mode": "audio" if assistant.voice_output_enabled else "text"},
        "bargeIn": {
            "enabled": not bool(assistant.bot_cannot_be_interrupted),
            # 500 ms default when interruption_sensitivity is unset/0.
            "minDurationMs": int(assistant.interruption_sensitivity or 500),
        },
        "services": {},  # filled in below per configured service
        "tools": assistant.tools or [],
        "history": {
            "assistantId": assistant.id,
            "userId": int(assistant.user_id or 1),
            "source": "debug",
        },
    }
    warnings: List[str] = []
    config_mode = str(assistant.config_mode or "platform").strip().lower()
    if config_mode in {"dify", "fastgpt"}:
        # External orchestrators (Dify/FastGPT) are reached through an
        # OpenAI-style endpoint configured directly on the assistant row.
        metadata["services"]["llm"] = {
            "provider": "openai",
            "model": "",
            "apiKey": assistant.api_key,
            "baseUrl": assistant.api_url,
        }
        if not (assistant.api_url or "").strip():
            warnings.append(f"External LLM API URL is empty for mode: {assistant.config_mode}")
        if not (assistant.api_key or "").strip():
            warnings.append(f"External LLM API key is empty for mode: {assistant.config_mode}")
    elif assistant.llm_model_id:
        # Platform mode: resolve the referenced LLM model row.
        llm = db.query(LLMModel).filter(LLMModel.id == assistant.llm_model_id).first()
        if llm:
            metadata["services"]["llm"] = {
                "provider": "openai",
                "model": llm.model_name or llm.name,
                "apiKey": llm.api_key,
                "baseUrl": llm.base_url,
            }
        else:
            warnings.append(f"LLM model not found: {assistant.llm_model_id}")
    if assistant.asr_model_id:
        asr = db.query(ASRModel).filter(ASRModel.id == assistant.asr_model_id).first()
        if asr:
            # Credentials are forwarded only for OpenAI-compatible ASR vendors.
            asr_provider = "openai_compatible" if _is_openai_compatible_vendor(asr.vendor) else "buffered"
            metadata["services"]["asr"] = {
                "provider": asr_provider,
                "model": asr.model_name or asr.name,
                "apiKey": asr.api_key if asr_provider == "openai_compatible" else None,
                "baseUrl": asr.base_url if asr_provider == "openai_compatible" else None,
            }
        else:
            warnings.append(f"ASR model not found: {assistant.asr_model_id}")
    if not assistant.voice_output_enabled:
        metadata["services"]["tts"] = {"enabled": False}
    elif assistant.voice:
        voice = db.query(Voice).filter(Voice.id == assistant.voice).first()
        if voice:
            # Map the voice vendor to a TTS provider; "edge" is the default.
            if _is_dashscope_vendor(voice.vendor):
                tts_provider = "dashscope"
            elif _is_openai_compatible_vendor(voice.vendor):
                tts_provider = "openai_compatible"
            else:
                tts_provider = "edge"
            model = voice.model
            runtime_voice = voice.voice_key or voice.id
            if tts_provider == "openai_compatible":
                # OpenAI-compatible TTS requires "<model>:<voice>" voice keys.
                model = model or OPENAI_COMPATIBLE_DEFAULT_MODEL
                runtime_voice = _normalize_openai_compatible_voice_key(runtime_voice, model)
            metadata["services"]["tts"] = {
                "enabled": True,
                "provider": tts_provider,
                "model": model,
                # Only cloud providers need credentials; "edge" runs without.
                "apiKey": voice.api_key if tts_provider in {"openai_compatible", "dashscope"} else None,
                "baseUrl": voice.base_url if tts_provider in {"openai_compatible", "dashscope"} else None,
                "voice": runtime_voice,
                # Assistant-level speed overrides the voice default when set.
                "speed": assistant.speed or voice.speed,
            }
        else:
            # Keep assistant.voice as direct voice identifier fallback
            metadata["services"]["tts"] = {
                "enabled": True,
                "voice": assistant.voice,
                "speed": assistant.speed or 1.0,
            }
            warnings.append(f"Voice resource not found: {assistant.voice}")
    if assistant.knowledge_base_id:
        metadata["knowledgeBaseId"] = assistant.knowledge_base_id
        metadata["knowledge"] = {
            "enabled": True,
            "kbId": assistant.knowledge_base_id,
            "nResults": 5,
        }
    return metadata, warnings
def _build_engine_assistant_config(db: Session, assistant: Assistant) -> Dict[str, Any]:
    """Assemble the resolved engine-config payload for one assistant.

    Wraps the runtime metadata with version/source bookkeeping so the
    engine adapter can cache and trace the configuration it received.
    """
    runtime_metadata, warnings = _resolve_runtime_metadata(db, assistant)
    version_id = _config_version_id(assistant)
    enriched_cfg = dict(runtime_metadata)
    enriched_cfg["assistantId"] = assistant.id
    enriched_cfg["configVersionId"] = version_id
    return {
        "assistantId": assistant.id,
        "configVersionId": version_id,
        "assistant": enriched_cfg,
        "sessionStartMetadata": runtime_metadata,
        # Raw ids the config was resolved from, for traceability.
        "sources": {
            "llmModelId": assistant.llm_model_id,
            "asrModelId": assistant.asr_model_id,
            "voiceId": assistant.voice,
            "knowledgeBaseId": assistant.knowledge_base_id,
        },
        "warnings": warnings,
    }
def assistant_to_dict(assistant: Assistant) -> dict:
    """Serialize an Assistant ORM row into its camelCase API representation."""
    basics = {
        "id": assistant.id,
        "name": assistant.name,
        "callCount": assistant.call_count,
        "firstTurnMode": assistant.first_turn_mode or "bot_first",
        "opener": assistant.opener or "",
        "generatedOpenerEnabled": bool(assistant.generated_opener_enabled),
        "prompt": assistant.prompt or "",
        "knowledgeBaseId": assistant.knowledge_base_id,
        "language": assistant.language,
    }
    speech = {
        "voiceOutputEnabled": assistant.voice_output_enabled,
        "voice": assistant.voice,
        "speed": assistant.speed,
        "hotwords": assistant.hotwords or [],
        "tools": assistant.tools or [],
        "botCannotBeInterrupted": bool(assistant.bot_cannot_be_interrupted),
        "interruptionSensitivity": assistant.interruption_sensitivity,
    }
    wiring = {
        "configMode": assistant.config_mode,
        "apiUrl": assistant.api_url,
        "apiKey": assistant.api_key,
        "llmModelId": assistant.llm_model_id,
        "asrModelId": assistant.asr_model_id,
        "embeddingModelId": assistant.embedding_model_id,
        "rerankModelId": assistant.rerank_model_id,
        "created_at": assistant.created_at,
        "updated_at": assistant.updated_at,
    }
    # Merge order preserves the original key order of the API payload.
    return {**basics, **speech, **wiring}
def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None:
    """Copy a partial camelCase API payload onto the ORM row in place.

    Keys listed in the translation table are mapped to their snake_case
    column names; any other key is assumed to already match a column
    attribute and is set verbatim.
    """
    camel_to_column = {
        "knowledgeBaseId": "knowledge_base_id",
        "firstTurnMode": "first_turn_mode",
        "interruptionSensitivity": "interruption_sensitivity",
        "botCannotBeInterrupted": "bot_cannot_be_interrupted",
        "configMode": "config_mode",
        "voiceOutputEnabled": "voice_output_enabled",
        "generatedOpenerEnabled": "generated_opener_enabled",
        "apiUrl": "api_url",
        "apiKey": "api_key",
        "llmModelId": "llm_model_id",
        "asrModelId": "asr_model_id",
        "embeddingModelId": "embedding_model_id",
        "rerankModelId": "rerank_model_id",
    }
    for api_field, new_value in update_data.items():
        column_name = camel_to_column.get(api_field, api_field)
        setattr(assistant, column_name, new_value)
# ============ Assistants ============
@router.get("")
def list_assistants(
page: int = 1,
limit: int = 50,
db: Session = Depends(get_db)
):
"""获取助手列表"""
query = db.query(Assistant)
total = query.count()
assistants = query.order_by(Assistant.created_at.desc()) \
.offset((page-1)*limit).limit(limit).all()
return {
"total": total,
"page": page,
"limit": limit,
"list": [assistant_to_dict(a) for a in assistants]
}
@router.get("/{id}", response_model=AssistantOut)
def get_assistant(id: str, db: Session = Depends(get_db)):
"""获取单个助手详情"""
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
return assistant_to_dict(assistant)
@router.get("/{id}/config", response_model=AssistantEngineConfigResponse)
def get_assistant_config(id: str, db: Session = Depends(get_db)):
"""Canonical engine config endpoint consumed by engine backend adapter."""
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
return _build_engine_assistant_config(db, assistant)
@router.get("/{id}/runtime-config", response_model=AssistantEngineConfigResponse)
def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)):
"""Legacy alias for resolved engine runtime config."""
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
return _build_engine_assistant_config(db, assistant)
@router.post("", response_model=AssistantOut)
def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
"""创建新助手"""
assistant = Assistant(
id=str(uuid.uuid4())[:8],
user_id=1, # 默认用户,后续添加认证
name=data.name,
first_turn_mode=data.firstTurnMode,
opener=data.opener,
generated_opener_enabled=data.generatedOpenerEnabled,
prompt=data.prompt,
knowledge_base_id=data.knowledgeBaseId,
language=data.language,
voice_output_enabled=data.voiceOutputEnabled,
voice=data.voice,
speed=data.speed,
hotwords=data.hotwords,
tools=data.tools,
bot_cannot_be_interrupted=data.botCannotBeInterrupted,
interruption_sensitivity=data.interruptionSensitivity,
config_mode=data.configMode,
api_url=data.apiUrl,
api_key=data.apiKey,
llm_model_id=data.llmModelId,
asr_model_id=data.asrModelId,
embedding_model_id=data.embeddingModelId,
rerank_model_id=data.rerankModelId,
)
db.add(assistant)
db.commit()
db.refresh(assistant)
return assistant_to_dict(assistant)
@router.put("/{id}")
def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_db)):
"""更新助手"""
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
update_data = data.model_dump(exclude_unset=True)
_apply_assistant_update(assistant, update_data)
assistant.updated_at = datetime.utcnow()
db.commit()
db.refresh(assistant)
return assistant_to_dict(assistant)
@router.delete("/{id}")
def delete_assistant(id: str, db: Session = Depends(get_db)):
"""删除助手"""
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
db.delete(assistant)
db.commit()
return {"message": "Deleted successfully"}