- Introduced `asr_interim_enabled` field in the Assistant model to control interim ASR results. - Updated AssistantBase and AssistantUpdate schemas to include the new field. - Modified the database schema to add the `asr_interim_enabled` column. - Enhanced runtime metadata to reflect interim ASR settings. - Updated API endpoints and tests to validate the new functionality. - Adjusted documentation to include details about interim ASR results configuration.
842 lines
32 KiB
Python
842 lines
32 KiB
Python
import audioop
import hashlib
import io
import json
import os
import uuid
import wave
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import httpx
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
from sqlalchemy import inspect, text
from sqlalchemy.orm import Session

from ..db import get_db
from ..models import Assistant, AssistantOpenerAudio, LLMModel, ASRModel, Voice, ToolResource
from ..schemas import (
    AssistantCreate,
    AssistantUpdate,
    AssistantOut,
    AssistantEngineConfigResponse,
    AssistantOpenerAudioGenerateRequest,
    AssistantOpenerAudioOut,
)
from .tools import (
    TOOL_REGISTRY,
    TOOL_CATEGORY_MAP,
    TOOL_PARAMETER_DEFAULTS,
    TOOL_WAIT_FOR_RESPONSE_DEFAULTS,
    normalize_tool_id,
    _ensure_tool_resource_schema,
)
|
|
|
|
router = APIRouter(prefix="/assistants", tags=["Assistants"])

# Defaults for OpenAI-compatible TTS vendors (base URL points at SiliconFlow).
OPENAI_COMPATIBLE_DEFAULT_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
OPENAI_COMPATIBLE_DEFAULT_BASE_URL = "https://api.siliconflow.cn/v1"
# On-disk directory where pre-generated opener PCM files are persisted.
OPENER_AUDIO_DIR = Path(__file__).resolve().parents[2] / "data" / "opener_audio"
# Voice ids recognized by the OpenAI-compatible vendor; ids matching this set
# (case-insensitively) are normalized to lowercase by
# _normalize_openai_compatible_voice_key; all other ids pass through unchanged.
OPENAI_COMPATIBLE_KNOWN_VOICES = {
    "alex",
    "anna",
    "bella",
    "benjamin",
    "charles",
    "claire",
    "david",
    "diana",
}
|
|
|
|
|
|
def _is_openai_compatible_vendor(vendor: Optional[str]) -> bool:
|
|
return (vendor or "").strip().lower() in {
|
|
"siliconflow",
|
|
"硅基流动",
|
|
"openai compatible",
|
|
"openai-compatible",
|
|
}
|
|
|
|
|
|
def _is_dashscope_vendor(vendor: Optional[str]) -> bool:
|
|
return (vendor or "").strip().lower() in {
|
|
"dashscope",
|
|
}
|
|
|
|
|
|
def _normalize_openai_compatible_voice_key(voice_value: str, model: str) -> str:
    """Normalize a voice identifier into the vendor's ``model:voice`` form.

    Known built-in voice ids are lowercased; unknown ids pass through as-is.
    An empty voice falls back to ``anna``; an empty model falls back to
    OPENAI_COMPATIBLE_DEFAULT_MODEL.
    """
    fallback_model = (model or "").strip() or OPENAI_COMPATIBLE_DEFAULT_MODEL
    candidate = (voice_value or "").strip()
    if not candidate:
        return f"{fallback_model}:anna"

    def _canonical(voice_id: str) -> str:
        # Built-in voices are matched case-insensitively and emitted lowercase.
        lowered = voice_id.lower()
        return lowered if lowered in OPENAI_COMPATIBLE_KNOWN_VOICES else voice_id

    if ":" not in candidate:
        return f"{fallback_model}:{_canonical(candidate)}"

    prefix, _, suffix = candidate.partition(":")
    voice_model = prefix.strip() or fallback_model
    return f"{voice_model}:{_canonical(suffix.strip())}"
|
|
|
|
|
|
def _config_version_id(assistant: Assistant) -> str:
    """Derive a config-version token from the assistant id and its last change time."""
    # Prefer updated_at, then created_at; fall back to "now" for brand-new rows.
    stamp = assistant.updated_at or assistant.created_at or datetime.utcnow()
    return "asst_{}_{}".format(assistant.id, stamp.strftime("%Y%m%d%H%M%S"))
|
|
|
|
|
|
def _normalize_runtime_tool_schema(tool_id: str, raw_schema: Any) -> Dict[str, Any]:
|
|
schema = dict(raw_schema) if isinstance(raw_schema, dict) else {}
|
|
if not schema:
|
|
fallback = TOOL_REGISTRY.get(tool_id, {}).get("parameters")
|
|
if isinstance(fallback, dict):
|
|
schema = dict(fallback)
|
|
schema.setdefault("type", "object")
|
|
if not isinstance(schema.get("properties"), dict):
|
|
schema["properties"] = {}
|
|
required = schema.get("required")
|
|
if required is None or not isinstance(required, list):
|
|
schema["required"] = []
|
|
return schema
|
|
|
|
|
|
def _compose_runtime_system_prompt(base_prompt: Optional[str]) -> str:
|
|
raw = str(base_prompt or "").strip()
|
|
tool_policy = (
|
|
"Tool usage policy:\n"
|
|
"- Tool function names/IDs are internal and must never be shown to users.\n"
|
|
"- When users ask which tools are available, describe capabilities in natural language.\n"
|
|
"- Do not expose raw tool call payloads, IDs, or executor details."
|
|
)
|
|
return f"{raw}\n\n{tool_policy}" if raw else tool_policy
|
|
|
|
|
|
def _ensure_assistant_schema(db: Session) -> None:
    """Apply lightweight SQLite migrations for newly added assistants columns."""
    inspector = inspect(db.get_bind())
    try:
        existing = {column["name"] for column in inspector.get_columns("assistants")}
    except Exception:
        # Table may not exist yet (fresh database); nothing to migrate.
        return

    statements: List[str] = []
    if "manual_opener_tool_calls" not in existing:
        statements.append("ALTER TABLE assistants ADD COLUMN manual_opener_tool_calls JSON")
    if "asr_interim_enabled" not in existing:
        statements.append("ALTER TABLE assistants ADD COLUMN asr_interim_enabled BOOLEAN DEFAULT 0")

    for statement in statements:
        db.execute(text(statement))
    if statements:
        db.commit()
|
|
|
|
|
|
def _normalize_manual_opener_tool_calls(raw: Any, warnings: Optional[List[str]] = None) -> List[Dict[str, Any]]:
|
|
normalized: List[Dict[str, Any]] = []
|
|
if not isinstance(raw, list):
|
|
return normalized
|
|
|
|
for idx, item in enumerate(raw):
|
|
if not isinstance(item, dict):
|
|
if warnings is not None:
|
|
warnings.append(f"Ignored invalid manual opener tool call at index {idx}: not an object")
|
|
continue
|
|
|
|
tool_name = normalize_tool_id(str(
|
|
item.get("toolName")
|
|
or item.get("tool_name")
|
|
or item.get("name")
|
|
or ""
|
|
).strip())
|
|
if not tool_name:
|
|
if warnings is not None:
|
|
warnings.append(f"Ignored invalid manual opener tool call at index {idx}: missing toolName")
|
|
continue
|
|
|
|
args_raw = item.get("arguments")
|
|
args: Dict[str, Any] = {}
|
|
if isinstance(args_raw, dict):
|
|
args = dict(args_raw)
|
|
elif isinstance(args_raw, str):
|
|
text_value = args_raw.strip()
|
|
if text_value:
|
|
try:
|
|
parsed = json.loads(text_value)
|
|
if isinstance(parsed, dict):
|
|
args = parsed
|
|
else:
|
|
if warnings is not None:
|
|
warnings.append(
|
|
f"Ignored non-object arguments for manual opener tool call '{tool_name}' at index {idx}"
|
|
)
|
|
except Exception:
|
|
if warnings is not None:
|
|
warnings.append(f"Ignored invalid JSON arguments for manual opener tool call '{tool_name}' at index {idx}")
|
|
elif args_raw is not None and warnings is not None:
|
|
warnings.append(f"Ignored unsupported arguments type for manual opener tool call '{tool_name}' at index {idx}")
|
|
|
|
normalized.append({"toolName": tool_name, "arguments": args})
|
|
|
|
# Keep opener sequence intentionally short to avoid long pre-dialog delays.
|
|
return normalized[:8]
|
|
|
|
|
|
def _normalize_assistant_tool_ids(raw: Any) -> List[str]:
|
|
if not isinstance(raw, list):
|
|
return []
|
|
normalized: List[str] = []
|
|
seen: set[str] = set()
|
|
for item in raw:
|
|
tool_id = normalize_tool_id(item)
|
|
if not tool_id or tool_id in seen:
|
|
continue
|
|
seen.add(tool_id)
|
|
normalized.append(tool_id)
|
|
return normalized
|
|
|
|
|
|
def _resolve_runtime_tools(db: Session, selected_tool_ids: List[str], warnings: List[str]) -> List[Dict[str, Any]]:
    """Resolve the assistant's selected tool ids into runtime tool descriptors.

    For each id, prefers the ToolResource row from the database and falls back
    to the static TOOL_* registries. Disabled resources are skipped; missing
    ids (no row AND not in TOOL_REGISTRY) are still emitted but produce a
    warning. *warnings* is appended to in place.
    """
    _ensure_tool_resource_schema(db)
    ids = _normalize_assistant_tool_ids(selected_tool_ids)
    if not ids:
        return []

    # Fetch all referenced resources in one query and index them by id.
    resources = (
        db.query(ToolResource)
        .filter(ToolResource.id.in_(ids))
        .all()
    )
    by_id = {str(item.id): item for item in resources}

    runtime_tools: List[Dict[str, Any]] = []
    for tool_id in ids:
        resource = by_id.get(tool_id)
        # Explicitly-disabled resources are excluded from the runtime config.
        if resource and resource.enabled is False:
            warnings.append(f"Tool is disabled and skipped in runtime config: {tool_id}")
            continue

        # DB row wins; static registry maps are the fallback for each field.
        category = str(resource.category if resource else TOOL_CATEGORY_MAP.get(tool_id, "query"))
        display_name = (
            str(resource.name or tool_id).strip()
            if resource
            else str(TOOL_REGISTRY.get(tool_id, {}).get("name") or tool_id).strip()
        )
        description = (
            str(resource.description or resource.name or "").strip()
            if resource
            else str(TOOL_REGISTRY.get(tool_id, {}).get("description") or "").strip()
        )
        schema = _normalize_runtime_tool_schema(
            tool_id,
            resource.parameter_schema if resource else TOOL_REGISTRY.get(tool_id, {}).get("parameters"),
        )
        defaults_raw = resource.parameter_defaults if resource else TOOL_PARAMETER_DEFAULTS.get(tool_id)
        defaults = dict(defaults_raw) if isinstance(defaults_raw, dict) else {}
        wait_for_response = (
            bool(resource.wait_for_response)
            if resource
            else bool(TOOL_WAIT_FOR_RESPONSE_DEFAULTS.get(tool_id, False))
        )

        # Unknown ids still produce a descriptor, but flag them for operators.
        if not resource and tool_id not in TOOL_REGISTRY:
            warnings.append(f"Tool resource not found: {tool_id}")

        runtime_tool: Dict[str, Any] = {
            "type": "function",
            # "system" tools run on the client; everything else executes server-side.
            "executor": "client" if category == "system" else "server",
            "function": {
                "name": tool_id,
                "description": (
                    f"Display name: {display_name}. {description}".strip()
                    if display_name
                    else (description or tool_id)
                ),
                "parameters": schema,
            },
            "displayName": display_name or tool_id,
            "toolId": tool_id,
            "waitForResponse": wait_for_response,
        }
        if defaults:
            runtime_tool["defaultArgs"] = defaults
        runtime_tools.append(runtime_tool)

    return runtime_tools
|
|
|
|
|
|
def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[str, Any], List[str]]:
    """Build the session-start metadata blob for the engine from an Assistant row.

    Resolves the LLM/ASR/TTS service configs (looking up the referenced model
    and voice rows), tool descriptors, knowledge-base settings and opener-audio
    state. Returns ``(metadata, warnings)`` where *warnings* collects
    human-readable notes about missing or misconfigured resources.
    """
    warnings: List[str] = []
    generated_opener_enabled = bool(assistant.generated_opener_enabled)
    manual_opener_tool_calls = _normalize_manual_opener_tool_calls(
        assistant.manual_opener_tool_calls,
        warnings=warnings,
    )
    metadata: Dict[str, Any] = {
        "systemPrompt": _compose_runtime_system_prompt(assistant.prompt),
        "firstTurnMode": assistant.first_turn_mode or "bot_first",
        # Generated opener should rely on systemPrompt instead of fixed opener text.
        "greeting": "" if generated_opener_enabled else (assistant.opener or ""),
        "generatedOpenerEnabled": generated_opener_enabled,
        "manualOpenerToolCalls": manual_opener_tool_calls,
        "output": {"mode": "audio" if assistant.voice_output_enabled else "text"},
        "bargeIn": {
            # Barge-in is on unless the bot is marked uninterruptible.
            "enabled": not bool(assistant.bot_cannot_be_interrupted),
            "minDurationMs": int(assistant.interruption_sensitivity or 500),
        },
        "services": {},
        "tools": _resolve_runtime_tools(db, assistant.tools or [], warnings),
        "history": {
            "assistantId": assistant.id,
            "userId": int(assistant.user_id or 1),
            "source": "debug",
        },
    }

    config_mode = str(assistant.config_mode or "platform").strip().lower()

    # --- LLM service ---------------------------------------------------
    if config_mode in {"dify", "fastgpt"}:
        # External-platform mode: use the assistant's own API endpoint/key.
        metadata["services"]["llm"] = {
            "provider": "openai",
            "model": "",
            "apiKey": assistant.api_key,
            "baseUrl": assistant.api_url,
        }
        if not (assistant.api_url or "").strip():
            warnings.append(f"External LLM API URL is empty for mode: {assistant.config_mode}")
        if not (assistant.api_key or "").strip():
            warnings.append(f"External LLM API key is empty for mode: {assistant.config_mode}")
    elif assistant.llm_model_id:
        llm = db.query(LLMModel).filter(LLMModel.id == assistant.llm_model_id).first()
        if llm:
            metadata["services"]["llm"] = {
                "provider": "openai",
                "model": llm.model_name or llm.name,
                "apiKey": llm.api_key,
                "baseUrl": llm.base_url,
            }
        else:
            warnings.append(f"LLM model not found: {assistant.llm_model_id}")

    # --- ASR service ---------------------------------------------------
    # Interim results flag is always emitted, even without a model row.
    asr_runtime: Dict[str, Any] = {
        "enableInterim": bool(assistant.asr_interim_enabled),
    }
    if assistant.asr_model_id:
        asr = db.query(ASRModel).filter(ASRModel.id == assistant.asr_model_id).first()
        if asr:
            if _is_dashscope_vendor(asr.vendor):
                asr_provider = "dashscope"
            elif _is_openai_compatible_vendor(asr.vendor):
                asr_provider = "openai_compatible"
            else:
                asr_provider = "buffered"
            asr_runtime.update({
                "provider": asr_provider,
                "model": asr.model_name or asr.name,
                # Credentials are only forwarded for remote providers.
                "apiKey": asr.api_key if asr_provider in {"openai_compatible", "dashscope"} else None,
                "baseUrl": asr.base_url if asr_provider in {"openai_compatible", "dashscope"} else None,
            })
        else:
            warnings.append(f"ASR model not found: {assistant.asr_model_id}")
    metadata["services"]["asr"] = asr_runtime

    # --- TTS service ---------------------------------------------------
    if not assistant.voice_output_enabled:
        metadata["services"]["tts"] = {"enabled": False}
    elif assistant.voice:
        voice = db.query(Voice).filter(Voice.id == assistant.voice).first()
        if voice:
            if _is_dashscope_vendor(voice.vendor):
                tts_provider = "dashscope"
            elif _is_openai_compatible_vendor(voice.vendor):
                tts_provider = "openai_compatible"
            else:
                tts_provider = "edge"
            model = voice.model
            runtime_voice = voice.voice_key or voice.id
            if tts_provider == "openai_compatible":
                # This vendor expects "model:voice" keys; normalize accordingly.
                model = model or OPENAI_COMPATIBLE_DEFAULT_MODEL
                runtime_voice = _normalize_openai_compatible_voice_key(runtime_voice, model)
            metadata["services"]["tts"] = {
                "enabled": True,
                "provider": tts_provider,
                "model": model,
                "apiKey": voice.api_key if tts_provider in {"openai_compatible", "dashscope"} else None,
                "baseUrl": voice.base_url if tts_provider in {"openai_compatible", "dashscope"} else None,
                "voice": runtime_voice,
                # Per-assistant speed overrides the voice's configured speed.
                "speed": assistant.speed or voice.speed,
            }
        else:
            # Keep assistant.voice as direct voice identifier fallback
            metadata["services"]["tts"] = {
                "enabled": True,
                "voice": assistant.voice,
                "speed": assistant.speed or 1.0,
            }
            warnings.append(f"Voice resource not found: {assistant.voice}")

    # --- Knowledge base --------------------------------------------------
    if assistant.knowledge_base_id:
        metadata["knowledgeBaseId"] = assistant.knowledge_base_id
        metadata["knowledge"] = {
            "enabled": True,
            "kbId": assistant.knowledge_base_id,
            "nResults": 5,
        }
    # --- Pre-generated opener audio --------------------------------------
    opener_audio = assistant.opener_audio
    # "Ready" requires a row, a recorded path, and the PCM file on disk.
    opener_audio_ready = bool(opener_audio and opener_audio.file_path and Path(opener_audio.file_path).exists())
    metadata["openerAudio"] = {
        "enabled": bool(opener_audio.enabled) if opener_audio else False,
        "ready": opener_audio_ready,
        "encoding": opener_audio.encoding if opener_audio else "pcm_s16le",
        "sampleRateHz": int(opener_audio.sample_rate_hz) if opener_audio else 16000,
        "channels": int(opener_audio.channels) if opener_audio else 1,
        "durationMs": int(opener_audio.duration_ms) if opener_audio else 0,
        "textHash": opener_audio.text_hash if opener_audio else None,
        "ttsFingerprint": opener_audio.tts_fingerprint if opener_audio else None,
        "pcmUrl": f"/api/assistants/{assistant.id}/opener-audio/pcm" if opener_audio_ready else None,
    }
    return metadata, warnings
|
|
|
|
|
|
def _build_engine_assistant_config(db: Session, assistant: Assistant) -> Dict[str, Any]:
    """Assemble the engine-facing config payload for a single assistant."""
    session_metadata, warnings = _resolve_runtime_metadata(db, assistant)
    version_id = _config_version_id(assistant)

    # The "assistant" section is the session metadata plus identity fields.
    assistant_section = {
        **session_metadata,
        "assistantId": assistant.id,
        "configVersionId": version_id,
    }

    return {
        "assistantId": assistant.id,
        "configVersionId": version_id,
        "assistant": assistant_section,
        "sessionStartMetadata": session_metadata,
        "sources": {
            "llmModelId": assistant.llm_model_id,
            "asrModelId": assistant.asr_model_id,
            "voiceId": assistant.voice,
            "knowledgeBaseId": assistant.knowledge_base_id,
        },
        "warnings": warnings,
    }
|
|
|
|
|
|
def assistant_to_dict(assistant: Assistant) -> dict:
    """Serialize an Assistant ORM row into the camelCase API response dict."""
    opener_audio = assistant.opener_audio
    # Audio is "ready" only when a row exists and its PCM file is still on disk.
    opener_audio_ready = bool(opener_audio and opener_audio.file_path and Path(opener_audio.file_path).exists())
    return {
        "id": assistant.id,
        "name": assistant.name,
        "callCount": assistant.call_count,
        "firstTurnMode": assistant.first_turn_mode or "bot_first",
        "opener": assistant.opener or "",
        "manualOpenerToolCalls": _normalize_manual_opener_tool_calls(assistant.manual_opener_tool_calls),
        "generatedOpenerEnabled": bool(assistant.generated_opener_enabled),
        "openerAudioEnabled": bool(opener_audio.enabled) if opener_audio else False,
        "openerAudioReady": opener_audio_ready,
        "openerAudioDurationMs": int(opener_audio.duration_ms) if opener_audio else 0,
        "openerAudioUpdatedAt": opener_audio.updated_at if opener_audio else None,
        "prompt": assistant.prompt or "",
        "knowledgeBaseId": assistant.knowledge_base_id,
        "language": assistant.language,
        "voiceOutputEnabled": assistant.voice_output_enabled,
        "voice": assistant.voice,
        "speed": assistant.speed,
        "hotwords": assistant.hotwords or [],
        "tools": _normalize_assistant_tool_ids(assistant.tools),
        "asrInterimEnabled": bool(assistant.asr_interim_enabled),
        "botCannotBeInterrupted": bool(assistant.bot_cannot_be_interrupted),
        "interruptionSensitivity": assistant.interruption_sensitivity,
        "configMode": assistant.config_mode,
        "apiUrl": assistant.api_url,
        "apiKey": assistant.api_key,
        "llmModelId": assistant.llm_model_id,
        "asrModelId": assistant.asr_model_id,
        "embeddingModelId": assistant.embedding_model_id,
        "rerankModelId": assistant.rerank_model_id,
        "created_at": assistant.created_at,
        "updated_at": assistant.updated_at,
    }
|
|
|
|
|
|
def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None:
    """Copy camelCase payload fields onto the Assistant's snake_case attributes.

    Keys not present in the mapping are assumed to already match the ORM
    attribute name and are set as-is.
    """
    camel_to_attr = {
        "knowledgeBaseId": "knowledge_base_id",
        "firstTurnMode": "first_turn_mode",
        "manualOpenerToolCalls": "manual_opener_tool_calls",
        "interruptionSensitivity": "interruption_sensitivity",
        "asrInterimEnabled": "asr_interim_enabled",
        "botCannotBeInterrupted": "bot_cannot_be_interrupted",
        "configMode": "config_mode",
        "voiceOutputEnabled": "voice_output_enabled",
        "generatedOpenerEnabled": "generated_opener_enabled",
        "apiUrl": "api_url",
        "apiKey": "api_key",
        "llmModelId": "llm_model_id",
        "asrModelId": "asr_model_id",
        "embeddingModelId": "embedding_model_id",
        "rerankModelId": "rerank_model_id",
    }
    for key, value in update_data.items():
        attr_name = camel_to_attr.get(key, key)
        setattr(assistant, attr_name, value)
|
|
|
|
|
|
def _ensure_assistant_opener_audio(db: Session, assistant: Assistant) -> AssistantOpenerAudio:
    """Return the assistant's opener-audio row, creating a disabled one if absent."""
    existing = assistant.opener_audio
    if existing:
        return existing
    created = AssistantOpenerAudio(assistant_id=assistant.id, enabled=False)
    db.add(created)
    # Flush (without committing) so the new row is usable by the caller.
    db.flush()
    return created
|
|
|
|
|
|
def _resolve_tts_runtime_for_assistant(db: Session, assistant: Assistant) -> tuple[Dict[str, Any], Optional[Voice]]:
    """Return the resolved TTS service config plus the referenced Voice row (if any)."""
    metadata, _warnings = _resolve_runtime_metadata(db, assistant)

    services = metadata.get("services")
    if not isinstance(services, dict):
        services = {}
    tts_cfg = services.get("tts")
    if not isinstance(tts_cfg, dict):
        tts_cfg = {}

    voice = None
    if assistant.voice:
        voice = db.query(Voice).filter(Voice.id == assistant.voice).first()
    return tts_cfg, voice
|
|
|
|
|
|
def _tts_fingerprint(tts_cfg: Dict[str, Any], opener_text: str) -> str:
|
|
identity = {
|
|
"provider": tts_cfg.get("provider"),
|
|
"model": tts_cfg.get("model"),
|
|
"voice": tts_cfg.get("voice"),
|
|
"speed": tts_cfg.get("speed"),
|
|
"text": opener_text,
|
|
}
|
|
return hashlib.sha256(str(identity).encode("utf-8")).hexdigest()
|
|
|
|
|
|
def _synthesize_openai_compatible_wav(
    *,
    text: str,
    model: str,
    voice_key: str,
    speed: float,
    api_key: str,
    base_url: str,
) -> bytes:
    """Call an OpenAI-compatible ``/audio/speech`` endpoint and return WAV bytes.

    Raises HTTPException(502) with the vendor's error message on any non-200
    response.
    """
    request_body = {
        "model": model or OPENAI_COMPATIBLE_DEFAULT_MODEL,
        "input": text,
        "voice": voice_key,
        "response_format": "wav",
        "speed": speed,
    }
    endpoint = f"{base_url.rstrip('/')}/audio/speech"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    with httpx.Client(timeout=45.0) as client:
        response = client.post(endpoint, headers=headers, json=request_body)
        if response.status_code == 200:
            return response.content
        # Prefer the structured vendor message when the error body is JSON.
        detail = response.text
        try:
            error_body = response.json()
            detail = error_body.get("error", {}).get("message") or error_body.get("detail") or detail
        except Exception:
            pass
    raise HTTPException(status_code=502, detail=f"TTS vendor error: {detail}")
|
|
|
|
|
|
def _wav_to_pcm16_mono_16k(wav_bytes: bytes) -> tuple[bytes, int]:
|
|
with wave.open(io.BytesIO(wav_bytes), "rb") as wav_file:
|
|
channels = wav_file.getnchannels()
|
|
sample_width = wav_file.getsampwidth()
|
|
sample_rate = wav_file.getframerate()
|
|
frames = wav_file.getnframes()
|
|
raw = wav_file.readframes(frames)
|
|
|
|
if sample_width != 2:
|
|
raise HTTPException(status_code=400, detail=f"Unsupported WAV sample width: {sample_width * 8}bit")
|
|
|
|
if channels > 1:
|
|
raw = audioop.tomono(raw, sample_width, 0.5, 0.5)
|
|
|
|
if sample_rate != 16000:
|
|
raw, _ = audioop.ratecv(raw, sample_width, 1, sample_rate, 16000, None)
|
|
|
|
duration_ms = int((len(raw) / (16000 * 2)) * 1000)
|
|
return raw, duration_ms
|
|
|
|
|
|
def _persist_opener_audio_pcm(assistant_id: str, pcm_bytes: bytes) -> str:
    """Write the opener PCM blob under OPENER_AUDIO_DIR and return its path."""
    OPENER_AUDIO_DIR.mkdir(parents=True, exist_ok=True)
    target = OPENER_AUDIO_DIR / f"{assistant_id}.pcm"
    target.write_bytes(pcm_bytes)
    return str(target)
|
|
|
|
|
|
def _opener_audio_out(record: Optional[AssistantOpenerAudio]) -> AssistantOpenerAudioOut:
    """Map an opener-audio ORM row to its API schema (defaults when missing)."""
    if not record:
        # No row yet: schema defaults represent "disabled / not ready".
        return AssistantOpenerAudioOut()
    file_ready = bool(record.file_path) and Path(record.file_path).exists()
    return AssistantOpenerAudioOut(
        enabled=bool(record.enabled),
        ready=file_ready,
        encoding=record.encoding,
        sample_rate_hz=record.sample_rate_hz,
        channels=record.channels,
        duration_ms=record.duration_ms,
        updated_at=record.updated_at,
        text_hash=record.text_hash,
        tts_fingerprint=record.tts_fingerprint,
    )
|
|
|
|
|
|
# ============ Assistants ============
|
|
@router.get("")
def list_assistants(
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    """Return a paginated list of assistants, newest first."""
    _ensure_assistant_schema(db)
    base_query = db.query(Assistant)
    total = base_query.count()
    offset = (page - 1) * limit
    rows = (
        base_query.order_by(Assistant.created_at.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )
    return {
        "total": total,
        "page": page,
        "limit": limit,
        "list": [assistant_to_dict(row) for row in rows],
    }
|
|
|
|
|
|
@router.get("/{id}", response_model=AssistantOut)
def get_assistant(id: str, db: Session = Depends(get_db)):
    """Fetch a single assistant by id; 404 when it does not exist."""
    _ensure_assistant_schema(db)
    record = db.query(Assistant).filter(Assistant.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Assistant not found")
    return assistant_to_dict(record)
|
|
|
|
|
|
@router.get("/{id}/config", response_model=AssistantEngineConfigResponse)
def get_assistant_config(id: str, db: Session = Depends(get_db)):
    """Canonical engine config endpoint consumed by engine backend adapter."""
    _ensure_assistant_schema(db)
    record = db.query(Assistant).filter(Assistant.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Assistant not found")
    return _build_engine_assistant_config(db, record)
|
|
|
|
|
|
@router.get("/{id}/runtime-config", response_model=AssistantEngineConfigResponse)
def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)):
    """Legacy alias for the resolved engine runtime config (see /config)."""
    _ensure_assistant_schema(db)
    record = db.query(Assistant).filter(Assistant.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Assistant not found")
    return _build_engine_assistant_config(db, record)
|
|
|
|
|
|
@router.post("", response_model=AssistantOut)
def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
    """Create a new assistant from the camelCase request payload."""
    _ensure_assistant_schema(db)
    assistant = Assistant(
        # Short random id (first 8 chars of a UUID4).
        id=str(uuid.uuid4())[:8],
        user_id=1,  # Default user; authentication to be added later
        name=data.name,
        first_turn_mode=data.firstTurnMode,
        opener=data.opener,
        manual_opener_tool_calls=_normalize_manual_opener_tool_calls(data.manualOpenerToolCalls),
        generated_opener_enabled=data.generatedOpenerEnabled,
        prompt=data.prompt,
        knowledge_base_id=data.knowledgeBaseId,
        language=data.language,
        voice_output_enabled=data.voiceOutputEnabled,
        voice=data.voice,
        speed=data.speed,
        hotwords=data.hotwords,
        tools=_normalize_assistant_tool_ids(data.tools),
        asr_interim_enabled=data.asrInterimEnabled,
        bot_cannot_be_interrupted=data.botCannotBeInterrupted,
        interruption_sensitivity=data.interruptionSensitivity,
        config_mode=data.configMode,
        api_url=data.apiUrl,
        api_key=data.apiKey,
        llm_model_id=data.llmModelId,
        asr_model_id=data.asrModelId,
        embedding_model_id=data.embeddingModelId,
        rerank_model_id=data.rerankModelId,
    )
    db.add(assistant)
    db.commit()
    db.refresh(assistant)
    # The opener-audio flag lives on a sibling row, created after the assistant
    # exists so the foreign key can be set.
    opener_audio = _ensure_assistant_opener_audio(db, assistant)
    opener_audio.enabled = bool(data.openerAudioEnabled)
    opener_audio.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(assistant)
    return assistant_to_dict(assistant)
|
|
|
|
|
|
@router.get("/{id}/opener-audio", response_model=AssistantOpenerAudioOut)
def get_assistant_opener_audio(id: str, db: Session = Depends(get_db)):
    """Return the opener-audio state for an assistant."""
    _ensure_assistant_schema(db)
    record = db.query(Assistant).filter(Assistant.id == id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Assistant not found")
    return _opener_audio_out(record.opener_audio)
|
|
|
|
|
|
@router.get("/{id}/opener-audio/pcm")
def get_assistant_opener_audio_pcm(id: str, db: Session = Depends(get_db)):
    """Serve the pre-generated opener PCM file for an assistant."""
    _ensure_assistant_schema(db)
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if target is None:
        raise HTTPException(status_code=404, detail="Assistant not found")

    audio = target.opener_audio
    if not audio or not audio.file_path:
        raise HTTPException(status_code=404, detail="Opener audio not generated")

    pcm_path = Path(audio.file_path)
    if not pcm_path.exists():
        raise HTTPException(status_code=404, detail="Opener audio file missing")

    return FileResponse(
        str(pcm_path),
        media_type="application/octet-stream",
        filename=f"{target.id}.pcm",
    )
|
|
|
|
|
|
@router.post("/{id}/opener-audio/generate", response_model=AssistantOpenerAudioOut)
def generate_assistant_opener_audio(
    id: str,
    data: AssistantOpenerAudioGenerateRequest,
    db: Session = Depends(get_db),
):
    """Synthesize the opener text to 16 kHz mono PCM and persist it.

    Uses the assistant's resolved TTS config; only the openai_compatible and
    dashscope providers are supported. Returns the updated opener-audio state.
    Raises 400 on missing prerequisites and 502 on vendor failures.
    """
    _ensure_assistant_schema(db)
    assistant = db.query(Assistant).filter(Assistant.id == id).first()
    if not assistant:
        raise HTTPException(status_code=404, detail="Assistant not found")
    if not assistant.voice_output_enabled:
        raise HTTPException(status_code=400, detail="Voice output is disabled")

    # Explicit request text takes precedence over the stored opener.
    opener_text = (data.text if data.text is not None else assistant.opener or "").strip()
    if not opener_text:
        raise HTTPException(status_code=400, detail="Opener text is empty")

    tts_cfg, voice = _resolve_tts_runtime_for_assistant(db, assistant)
    provider = str(tts_cfg.get("provider") or "").strip().lower()
    if provider not in {"openai_compatible", "dashscope"}:
        raise HTTPException(status_code=400, detail=f"Unsupported provider for preloaded opener audio: {provider or 'unknown'}")

    speed = float(tts_cfg.get("speed") or assistant.speed or 1.0)
    voice_key = str(tts_cfg.get("voice") or "").strip()
    model = str(tts_cfg.get("model") or "").strip() or OPENAI_COMPATIBLE_DEFAULT_MODEL
    api_key = str(tts_cfg.get("apiKey") or "").strip()
    base_url = str(tts_cfg.get("baseUrl") or "").strip()

    if provider == "openai_compatible":
        # Credential fallback chain: resolved config -> Voice row -> env vars.
        if not api_key:
            if voice and voice.api_key:
                api_key = voice.api_key.strip()
            if not api_key:
                api_key = (os.getenv("SILICONFLOW_API_KEY", "") or os.getenv("TTS_API_KEY", "")).strip()
        if not api_key:
            raise HTTPException(status_code=400, detail="TTS API key is missing")
        if not base_url:
            base_url = OPENAI_COMPATIBLE_DEFAULT_BASE_URL
        wav_bytes = _synthesize_openai_compatible_wav(
            text=opener_text,
            model=model,
            voice_key=voice_key,
            speed=speed,
            api_key=api_key,
            base_url=base_url,
        )
    else:
        # Imported lazily to avoid a circular import between router modules.
        from .voices import _synthesize_dashscope_preview, DASHSCOPE_DEFAULT_BASE_URL, DASHSCOPE_DEFAULT_MODEL, DASHSCOPE_DEFAULT_VOICE_KEY
        if not api_key:
            if voice and voice.api_key:
                api_key = voice.api_key.strip()
            if not api_key:
                api_key = (os.getenv("DASHSCOPE_API_KEY", "") or os.getenv("TTS_API_KEY", "")).strip()
        if not api_key:
            raise HTTPException(status_code=400, detail="DashScope API key is missing")
        if not base_url:
            base_url = DASHSCOPE_DEFAULT_BASE_URL
        if not model:
            model = DASHSCOPE_DEFAULT_MODEL
        if not voice_key:
            voice_key = DASHSCOPE_DEFAULT_VOICE_KEY
        try:
            wav_bytes = _synthesize_dashscope_preview(
                text=opener_text,
                api_key=api_key,
                base_url=base_url,
                model=model,
                voice_key=voice_key,
                speed=speed,
            )
        except Exception as exc:
            raise HTTPException(status_code=502, detail=f"DashScope opener audio generation failed: {exc}") from exc

    # Convert vendor WAV to the engine's canonical PCM format and persist it.
    pcm_bytes, duration_ms = _wav_to_pcm16_mono_16k(wav_bytes)
    record = _ensure_assistant_opener_audio(db, assistant)
    record.enabled = True
    record.file_path = _persist_opener_audio_pcm(assistant.id, pcm_bytes)
    record.encoding = "pcm_s16le"
    record.sample_rate_hz = 16000
    record.channels = 1
    record.duration_ms = duration_ms
    # Hashes allow callers to detect stale audio after text/TTS config changes.
    record.text_hash = hashlib.sha256(opener_text.encode("utf-8")).hexdigest()
    record.tts_fingerprint = _tts_fingerprint(tts_cfg, opener_text)
    now = datetime.utcnow()
    if not record.created_at:
        record.created_at = now
    record.updated_at = now
    assistant.updated_at = now
    db.commit()
    db.refresh(assistant)
    return _opener_audio_out(assistant.opener_audio)
|
|
|
|
|
|
@router.put("/{id}")
def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_db)):
    """Apply a partial update to an assistant and return its full state."""
    _ensure_assistant_schema(db)
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if target is None:
        raise HTTPException(status_code=404, detail="Assistant not found")

    payload = data.model_dump(exclude_unset=True)
    # The opener-audio flag lives on a sibling row, not on the assistant.
    opener_audio_enabled = payload.pop("openerAudioEnabled", None)
    if "manualOpenerToolCalls" in payload:
        payload["manualOpenerToolCalls"] = _normalize_manual_opener_tool_calls(payload.get("manualOpenerToolCalls"))
    if "tools" in payload:
        payload["tools"] = _normalize_assistant_tool_ids(payload.get("tools"))
    _apply_assistant_update(target, payload)

    if opener_audio_enabled is not None:
        audio = _ensure_assistant_opener_audio(db, target)
        audio.enabled = bool(opener_audio_enabled)
        audio.updated_at = datetime.utcnow()

    target.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(target)
    return assistant_to_dict(target)
|
|
|
|
|
|
@router.delete("/{id}")
def delete_assistant(id: str, db: Session = Depends(get_db)):
    """Delete an assistant by id; 404 when it does not exist."""
    _ensure_assistant_schema(db)
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if target is None:
        raise HTTPException(status_code=404, detail="Assistant not found")
    db.delete(target)
    db.commit()
    return {"message": "Deleted successfully"}
|