Add manual opener tool calls to Assistant model and API

- Introduced `manual_opener_tool_calls` field in the Assistant model to support custom tool calls.
- Updated AssistantBase and AssistantUpdate schemas to include the new field.
- Implemented normalization and migration logic for handling manual opener tool calls in the API.
- Enhanced runtime metadata to include manual opener tool calls in responses.
- Updated tests to validate the new functionality and ensure proper handling of tool calls.
- Refactored tool ID normalization to support legacy tool names for backward compatibility.
This commit is contained in:
Xin Wang
2026-03-02 12:34:42 +08:00
parent b5cdb76e52
commit 00b88c5afa
14 changed files with 806 additions and 74 deletions

View File

@@ -7,6 +7,7 @@ from pathlib import Path
import httpx
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
from sqlalchemy import inspect, text
from sqlalchemy.orm import Session
from typing import Any, Dict, List, Optional
import uuid
@@ -27,6 +28,7 @@ from .tools import (
TOOL_CATEGORY_MAP,
TOOL_PARAMETER_DEFAULTS,
TOOL_WAIT_FOR_RESPONSE_DEFAULTS,
normalize_tool_id,
_ensure_tool_resource_schema,
)
@@ -111,9 +113,91 @@ def _compose_runtime_system_prompt(base_prompt: Optional[str]) -> str:
return f"{raw}\n\n{tool_policy}" if raw else tool_policy
def _ensure_assistant_schema(db: Session) -> None:
"""Apply lightweight SQLite migrations for newly added assistants columns."""
bind = db.get_bind()
inspector = inspect(bind)
try:
columns = {col["name"] for col in inspector.get_columns("assistants")}
except Exception:
return
altered = False
if "manual_opener_tool_calls" not in columns:
db.execute(text("ALTER TABLE assistants ADD COLUMN manual_opener_tool_calls JSON"))
altered = True
if altered:
db.commit()
def _normalize_manual_opener_tool_calls(raw: Any, warnings: Optional[List[str]] = None) -> List[Dict[str, Any]]:
normalized: List[Dict[str, Any]] = []
if not isinstance(raw, list):
return normalized
for idx, item in enumerate(raw):
if not isinstance(item, dict):
if warnings is not None:
warnings.append(f"Ignored invalid manual opener tool call at index {idx}: not an object")
continue
tool_name = normalize_tool_id(str(
item.get("toolName")
or item.get("tool_name")
or item.get("name")
or ""
).strip())
if not tool_name:
if warnings is not None:
warnings.append(f"Ignored invalid manual opener tool call at index {idx}: missing toolName")
continue
args_raw = item.get("arguments")
args: Dict[str, Any] = {}
if isinstance(args_raw, dict):
args = dict(args_raw)
elif isinstance(args_raw, str):
text_value = args_raw.strip()
if text_value:
try:
parsed = json.loads(text_value)
if isinstance(parsed, dict):
args = parsed
else:
if warnings is not None:
warnings.append(
f"Ignored non-object arguments for manual opener tool call '{tool_name}' at index {idx}"
)
except Exception:
if warnings is not None:
warnings.append(f"Ignored invalid JSON arguments for manual opener tool call '{tool_name}' at index {idx}")
elif args_raw is not None and warnings is not None:
warnings.append(f"Ignored unsupported arguments type for manual opener tool call '{tool_name}' at index {idx}")
normalized.append({"toolName": tool_name, "arguments": args})
# Keep opener sequence intentionally short to avoid long pre-dialog delays.
return normalized[:8]
def _normalize_assistant_tool_ids(raw: Any) -> List[str]:
if not isinstance(raw, list):
return []
normalized: List[str] = []
seen: set[str] = set()
for item in raw:
tool_id = normalize_tool_id(item)
if not tool_id or tool_id in seen:
continue
seen.add(tool_id)
normalized.append(tool_id)
return normalized
def _resolve_runtime_tools(db: Session, selected_tool_ids: List[str], warnings: List[str]) -> List[Dict[str, Any]]:
_ensure_tool_resource_schema(db)
ids = [str(tool_id).strip() for tool_id in selected_tool_ids if str(tool_id).strip()]
ids = _normalize_assistant_tool_ids(selected_tool_ids)
if not ids:
return []
@@ -183,12 +267,17 @@ def _resolve_runtime_tools(db: Session, selected_tool_ids: List[str], warnings:
def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[str, Any], List[str]]:
warnings: List[str] = []
generated_opener_enabled = bool(assistant.generated_opener_enabled)
manual_opener_tool_calls = _normalize_manual_opener_tool_calls(
assistant.manual_opener_tool_calls,
warnings=warnings,
)
metadata: Dict[str, Any] = {
"systemPrompt": _compose_runtime_system_prompt(assistant.prompt),
"firstTurnMode": assistant.first_turn_mode or "bot_first",
# Generated opener should rely on systemPrompt instead of fixed opener text.
"greeting": "" if generated_opener_enabled else (assistant.opener or ""),
"generatedOpenerEnabled": generated_opener_enabled,
"manualOpenerToolCalls": manual_opener_tool_calls,
"output": {"mode": "audio" if assistant.voice_output_enabled else "text"},
"bargeIn": {
"enabled": not bool(assistant.bot_cannot_be_interrupted),
@@ -329,6 +418,7 @@ def assistant_to_dict(assistant: Assistant) -> dict:
"callCount": assistant.call_count,
"firstTurnMode": assistant.first_turn_mode or "bot_first",
"opener": assistant.opener or "",
"manualOpenerToolCalls": _normalize_manual_opener_tool_calls(assistant.manual_opener_tool_calls),
"generatedOpenerEnabled": bool(assistant.generated_opener_enabled),
"openerAudioEnabled": bool(opener_audio.enabled) if opener_audio else False,
"openerAudioReady": opener_audio_ready,
@@ -341,7 +431,7 @@ def assistant_to_dict(assistant: Assistant) -> dict:
"voice": assistant.voice,
"speed": assistant.speed,
"hotwords": assistant.hotwords or [],
"tools": assistant.tools or [],
"tools": _normalize_assistant_tool_ids(assistant.tools),
"botCannotBeInterrupted": bool(assistant.bot_cannot_be_interrupted),
"interruptionSensitivity": assistant.interruption_sensitivity,
"configMode": assistant.config_mode,
@@ -360,6 +450,7 @@ def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None:
field_map = {
"knowledgeBaseId": "knowledge_base_id",
"firstTurnMode": "first_turn_mode",
"manualOpenerToolCalls": "manual_opener_tool_calls",
"interruptionSensitivity": "interruption_sensitivity",
"botCannotBeInterrupted": "bot_cannot_be_interrupted",
"configMode": "config_mode",
@@ -492,6 +583,7 @@ def list_assistants(
db: Session = Depends(get_db)
):
"""获取助手列表"""
_ensure_assistant_schema(db)
query = db.query(Assistant)
total = query.count()
assistants = query.order_by(Assistant.created_at.desc()) \
@@ -507,6 +599,7 @@ def list_assistants(
@router.get("/{id}", response_model=AssistantOut)
def get_assistant(id: str, db: Session = Depends(get_db)):
"""获取单个助手详情"""
_ensure_assistant_schema(db)
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
@@ -516,6 +609,7 @@ def get_assistant(id: str, db: Session = Depends(get_db)):
@router.get("/{id}/config", response_model=AssistantEngineConfigResponse)
def get_assistant_config(id: str, db: Session = Depends(get_db)):
"""Canonical engine config endpoint consumed by engine backend adapter."""
_ensure_assistant_schema(db)
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
@@ -525,6 +619,7 @@ def get_assistant_config(id: str, db: Session = Depends(get_db)):
@router.get("/{id}/runtime-config", response_model=AssistantEngineConfigResponse)
def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)):
"""Legacy alias for resolved engine runtime config."""
_ensure_assistant_schema(db)
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
@@ -534,12 +629,14 @@ def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)):
@router.post("", response_model=AssistantOut)
def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
"""创建新助手"""
_ensure_assistant_schema(db)
assistant = Assistant(
id=str(uuid.uuid4())[:8],
user_id=1, # 默认用户,后续添加认证
name=data.name,
first_turn_mode=data.firstTurnMode,
opener=data.opener,
manual_opener_tool_calls=_normalize_manual_opener_tool_calls(data.manualOpenerToolCalls),
generated_opener_enabled=data.generatedOpenerEnabled,
prompt=data.prompt,
knowledge_base_id=data.knowledgeBaseId,
@@ -548,7 +645,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
voice=data.voice,
speed=data.speed,
hotwords=data.hotwords,
tools=data.tools,
tools=_normalize_assistant_tool_ids(data.tools),
bot_cannot_be_interrupted=data.botCannotBeInterrupted,
interruption_sensitivity=data.interruptionSensitivity,
config_mode=data.configMode,
@@ -572,6 +669,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
@router.get("/{id}/opener-audio", response_model=AssistantOpenerAudioOut)
def get_assistant_opener_audio(id: str, db: Session = Depends(get_db)):
_ensure_assistant_schema(db)
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
@@ -580,6 +678,7 @@ def get_assistant_opener_audio(id: str, db: Session = Depends(get_db)):
@router.get("/{id}/opener-audio/pcm")
def get_assistant_opener_audio_pcm(id: str, db: Session = Depends(get_db)):
_ensure_assistant_schema(db)
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
@@ -602,6 +701,7 @@ def generate_assistant_opener_audio(
data: AssistantOpenerAudioGenerateRequest,
db: Session = Depends(get_db),
):
_ensure_assistant_schema(db)
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
@@ -691,12 +791,17 @@ def generate_assistant_opener_audio(
@router.put("/{id}")
def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_db)):
"""更新助手"""
_ensure_assistant_schema(db)
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
update_data = data.model_dump(exclude_unset=True)
opener_audio_enabled = update_data.pop("openerAudioEnabled", None)
if "manualOpenerToolCalls" in update_data:
update_data["manualOpenerToolCalls"] = _normalize_manual_opener_tool_calls(update_data.get("manualOpenerToolCalls"))
if "tools" in update_data:
update_data["tools"] = _normalize_assistant_tool_ids(update_data.get("tools"))
_apply_assistant_update(assistant, update_data)
if opener_audio_enabled is not None:
record = _ensure_assistant_opener_audio(db, assistant)
@@ -712,6 +817,7 @@ def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_d
@router.delete("/{id}")
def delete_assistant(id: str, db: Session = Depends(get_db)):
"""删除助手"""
_ensure_assistant_schema(db)
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")