Add manual opener tool calls to Assistant model and API
- Introduced `manual_opener_tool_calls` field in the Assistant model to support custom tool calls. - Updated AssistantBase and AssistantUpdate schemas to include the new field. - Implemented normalization and migration logic for handling manual opener tool calls in the API. - Enhanced runtime metadata to include manual opener tool calls in responses. - Updated tests to validate the new functionality and ensure proper handling of tool calls. - Refactored tool ID normalization to support legacy tool names for backward compatibility.
This commit is contained in:
@@ -7,6 +7,7 @@ from pathlib import Path
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy import inspect, text
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
@@ -27,6 +28,7 @@ from .tools import (
|
||||
TOOL_CATEGORY_MAP,
|
||||
TOOL_PARAMETER_DEFAULTS,
|
||||
TOOL_WAIT_FOR_RESPONSE_DEFAULTS,
|
||||
normalize_tool_id,
|
||||
_ensure_tool_resource_schema,
|
||||
)
|
||||
|
||||
@@ -111,9 +113,91 @@ def _compose_runtime_system_prompt(base_prompt: Optional[str]) -> str:
|
||||
return f"{raw}\n\n{tool_policy}" if raw else tool_policy
|
||||
|
||||
|
||||
def _ensure_assistant_schema(db: Session) -> None:
|
||||
"""Apply lightweight SQLite migrations for newly added assistants columns."""
|
||||
bind = db.get_bind()
|
||||
inspector = inspect(bind)
|
||||
try:
|
||||
columns = {col["name"] for col in inspector.get_columns("assistants")}
|
||||
except Exception:
|
||||
return
|
||||
|
||||
altered = False
|
||||
if "manual_opener_tool_calls" not in columns:
|
||||
db.execute(text("ALTER TABLE assistants ADD COLUMN manual_opener_tool_calls JSON"))
|
||||
altered = True
|
||||
|
||||
if altered:
|
||||
db.commit()
|
||||
|
||||
|
||||
def _normalize_manual_opener_tool_calls(raw: Any, warnings: Optional[List[str]] = None) -> List[Dict[str, Any]]:
|
||||
normalized: List[Dict[str, Any]] = []
|
||||
if not isinstance(raw, list):
|
||||
return normalized
|
||||
|
||||
for idx, item in enumerate(raw):
|
||||
if not isinstance(item, dict):
|
||||
if warnings is not None:
|
||||
warnings.append(f"Ignored invalid manual opener tool call at index {idx}: not an object")
|
||||
continue
|
||||
|
||||
tool_name = normalize_tool_id(str(
|
||||
item.get("toolName")
|
||||
or item.get("tool_name")
|
||||
or item.get("name")
|
||||
or ""
|
||||
).strip())
|
||||
if not tool_name:
|
||||
if warnings is not None:
|
||||
warnings.append(f"Ignored invalid manual opener tool call at index {idx}: missing toolName")
|
||||
continue
|
||||
|
||||
args_raw = item.get("arguments")
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(args_raw, dict):
|
||||
args = dict(args_raw)
|
||||
elif isinstance(args_raw, str):
|
||||
text_value = args_raw.strip()
|
||||
if text_value:
|
||||
try:
|
||||
parsed = json.loads(text_value)
|
||||
if isinstance(parsed, dict):
|
||||
args = parsed
|
||||
else:
|
||||
if warnings is not None:
|
||||
warnings.append(
|
||||
f"Ignored non-object arguments for manual opener tool call '{tool_name}' at index {idx}"
|
||||
)
|
||||
except Exception:
|
||||
if warnings is not None:
|
||||
warnings.append(f"Ignored invalid JSON arguments for manual opener tool call '{tool_name}' at index {idx}")
|
||||
elif args_raw is not None and warnings is not None:
|
||||
warnings.append(f"Ignored unsupported arguments type for manual opener tool call '{tool_name}' at index {idx}")
|
||||
|
||||
normalized.append({"toolName": tool_name, "arguments": args})
|
||||
|
||||
# Keep opener sequence intentionally short to avoid long pre-dialog delays.
|
||||
return normalized[:8]
|
||||
|
||||
|
||||
def _normalize_assistant_tool_ids(raw: Any) -> List[str]:
|
||||
if not isinstance(raw, list):
|
||||
return []
|
||||
normalized: List[str] = []
|
||||
seen: set[str] = set()
|
||||
for item in raw:
|
||||
tool_id = normalize_tool_id(item)
|
||||
if not tool_id or tool_id in seen:
|
||||
continue
|
||||
seen.add(tool_id)
|
||||
normalized.append(tool_id)
|
||||
return normalized
|
||||
|
||||
|
||||
def _resolve_runtime_tools(db: Session, selected_tool_ids: List[str], warnings: List[str]) -> List[Dict[str, Any]]:
|
||||
_ensure_tool_resource_schema(db)
|
||||
ids = [str(tool_id).strip() for tool_id in selected_tool_ids if str(tool_id).strip()]
|
||||
ids = _normalize_assistant_tool_ids(selected_tool_ids)
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
@@ -183,12 +267,17 @@ def _resolve_runtime_tools(db: Session, selected_tool_ids: List[str], warnings:
|
||||
def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[str, Any], List[str]]:
|
||||
warnings: List[str] = []
|
||||
generated_opener_enabled = bool(assistant.generated_opener_enabled)
|
||||
manual_opener_tool_calls = _normalize_manual_opener_tool_calls(
|
||||
assistant.manual_opener_tool_calls,
|
||||
warnings=warnings,
|
||||
)
|
||||
metadata: Dict[str, Any] = {
|
||||
"systemPrompt": _compose_runtime_system_prompt(assistant.prompt),
|
||||
"firstTurnMode": assistant.first_turn_mode or "bot_first",
|
||||
# Generated opener should rely on systemPrompt instead of fixed opener text.
|
||||
"greeting": "" if generated_opener_enabled else (assistant.opener or ""),
|
||||
"generatedOpenerEnabled": generated_opener_enabled,
|
||||
"manualOpenerToolCalls": manual_opener_tool_calls,
|
||||
"output": {"mode": "audio" if assistant.voice_output_enabled else "text"},
|
||||
"bargeIn": {
|
||||
"enabled": not bool(assistant.bot_cannot_be_interrupted),
|
||||
@@ -329,6 +418,7 @@ def assistant_to_dict(assistant: Assistant) -> dict:
|
||||
"callCount": assistant.call_count,
|
||||
"firstTurnMode": assistant.first_turn_mode or "bot_first",
|
||||
"opener": assistant.opener or "",
|
||||
"manualOpenerToolCalls": _normalize_manual_opener_tool_calls(assistant.manual_opener_tool_calls),
|
||||
"generatedOpenerEnabled": bool(assistant.generated_opener_enabled),
|
||||
"openerAudioEnabled": bool(opener_audio.enabled) if opener_audio else False,
|
||||
"openerAudioReady": opener_audio_ready,
|
||||
@@ -341,7 +431,7 @@ def assistant_to_dict(assistant: Assistant) -> dict:
|
||||
"voice": assistant.voice,
|
||||
"speed": assistant.speed,
|
||||
"hotwords": assistant.hotwords or [],
|
||||
"tools": assistant.tools or [],
|
||||
"tools": _normalize_assistant_tool_ids(assistant.tools),
|
||||
"botCannotBeInterrupted": bool(assistant.bot_cannot_be_interrupted),
|
||||
"interruptionSensitivity": assistant.interruption_sensitivity,
|
||||
"configMode": assistant.config_mode,
|
||||
@@ -360,6 +450,7 @@ def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None:
|
||||
field_map = {
|
||||
"knowledgeBaseId": "knowledge_base_id",
|
||||
"firstTurnMode": "first_turn_mode",
|
||||
"manualOpenerToolCalls": "manual_opener_tool_calls",
|
||||
"interruptionSensitivity": "interruption_sensitivity",
|
||||
"botCannotBeInterrupted": "bot_cannot_be_interrupted",
|
||||
"configMode": "config_mode",
|
||||
@@ -492,6 +583,7 @@ def list_assistants(
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取助手列表"""
|
||||
_ensure_assistant_schema(db)
|
||||
query = db.query(Assistant)
|
||||
total = query.count()
|
||||
assistants = query.order_by(Assistant.created_at.desc()) \
|
||||
@@ -507,6 +599,7 @@ def list_assistants(
|
||||
@router.get("/{id}", response_model=AssistantOut)
|
||||
def get_assistant(id: str, db: Session = Depends(get_db)):
|
||||
"""获取单个助手详情"""
|
||||
_ensure_assistant_schema(db)
|
||||
assistant = db.query(Assistant).filter(Assistant.id == id).first()
|
||||
if not assistant:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
@@ -516,6 +609,7 @@ def get_assistant(id: str, db: Session = Depends(get_db)):
|
||||
@router.get("/{id}/config", response_model=AssistantEngineConfigResponse)
|
||||
def get_assistant_config(id: str, db: Session = Depends(get_db)):
|
||||
"""Canonical engine config endpoint consumed by engine backend adapter."""
|
||||
_ensure_assistant_schema(db)
|
||||
assistant = db.query(Assistant).filter(Assistant.id == id).first()
|
||||
if not assistant:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
@@ -525,6 +619,7 @@ def get_assistant_config(id: str, db: Session = Depends(get_db)):
|
||||
@router.get("/{id}/runtime-config", response_model=AssistantEngineConfigResponse)
|
||||
def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)):
|
||||
"""Legacy alias for resolved engine runtime config."""
|
||||
_ensure_assistant_schema(db)
|
||||
assistant = db.query(Assistant).filter(Assistant.id == id).first()
|
||||
if not assistant:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
@@ -534,12 +629,14 @@ def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)):
|
||||
@router.post("", response_model=AssistantOut)
|
||||
def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
|
||||
"""创建新助手"""
|
||||
_ensure_assistant_schema(db)
|
||||
assistant = Assistant(
|
||||
id=str(uuid.uuid4())[:8],
|
||||
user_id=1, # 默认用户,后续添加认证
|
||||
name=data.name,
|
||||
first_turn_mode=data.firstTurnMode,
|
||||
opener=data.opener,
|
||||
manual_opener_tool_calls=_normalize_manual_opener_tool_calls(data.manualOpenerToolCalls),
|
||||
generated_opener_enabled=data.generatedOpenerEnabled,
|
||||
prompt=data.prompt,
|
||||
knowledge_base_id=data.knowledgeBaseId,
|
||||
@@ -548,7 +645,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
|
||||
voice=data.voice,
|
||||
speed=data.speed,
|
||||
hotwords=data.hotwords,
|
||||
tools=data.tools,
|
||||
tools=_normalize_assistant_tool_ids(data.tools),
|
||||
bot_cannot_be_interrupted=data.botCannotBeInterrupted,
|
||||
interruption_sensitivity=data.interruptionSensitivity,
|
||||
config_mode=data.configMode,
|
||||
@@ -572,6 +669,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)):
|
||||
|
||||
@router.get("/{id}/opener-audio", response_model=AssistantOpenerAudioOut)
|
||||
def get_assistant_opener_audio(id: str, db: Session = Depends(get_db)):
|
||||
_ensure_assistant_schema(db)
|
||||
assistant = db.query(Assistant).filter(Assistant.id == id).first()
|
||||
if not assistant:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
@@ -580,6 +678,7 @@ def get_assistant_opener_audio(id: str, db: Session = Depends(get_db)):
|
||||
|
||||
@router.get("/{id}/opener-audio/pcm")
|
||||
def get_assistant_opener_audio_pcm(id: str, db: Session = Depends(get_db)):
|
||||
_ensure_assistant_schema(db)
|
||||
assistant = db.query(Assistant).filter(Assistant.id == id).first()
|
||||
if not assistant:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
@@ -602,6 +701,7 @@ def generate_assistant_opener_audio(
|
||||
data: AssistantOpenerAudioGenerateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
_ensure_assistant_schema(db)
|
||||
assistant = db.query(Assistant).filter(Assistant.id == id).first()
|
||||
if not assistant:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
@@ -691,12 +791,17 @@ def generate_assistant_opener_audio(
|
||||
@router.put("/{id}")
|
||||
def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_db)):
|
||||
"""更新助手"""
|
||||
_ensure_assistant_schema(db)
|
||||
assistant = db.query(Assistant).filter(Assistant.id == id).first()
|
||||
if not assistant:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
opener_audio_enabled = update_data.pop("openerAudioEnabled", None)
|
||||
if "manualOpenerToolCalls" in update_data:
|
||||
update_data["manualOpenerToolCalls"] = _normalize_manual_opener_tool_calls(update_data.get("manualOpenerToolCalls"))
|
||||
if "tools" in update_data:
|
||||
update_data["tools"] = _normalize_assistant_tool_ids(update_data.get("tools"))
|
||||
_apply_assistant_update(assistant, update_data)
|
||||
if opener_audio_enabled is not None:
|
||||
record = _ensure_assistant_opener_audio(db, assistant)
|
||||
@@ -712,6 +817,7 @@ def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_d
|
||||
@router.delete("/{id}")
|
||||
def delete_assistant(id: str, db: Session = Depends(get_db)):
|
||||
"""删除助手"""
|
||||
_ensure_assistant_schema(db)
|
||||
assistant = db.query(Assistant).filter(Assistant.id == id).first()
|
||||
if not assistant:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
@@ -14,6 +14,19 @@ from ..schemas import ToolResourceCreate, ToolResourceOut, ToolResourceUpdate
|
||||
router = APIRouter(prefix="/tools", tags=["Tools & Autotest"])
|
||||
|
||||
|
||||
TOOL_ID_ALIASES: Dict[str, str] = {
|
||||
# legacy -> canonical
|
||||
"voice_message_prompt": "voice_msg_prompt",
|
||||
}
|
||||
|
||||
|
||||
def normalize_tool_id(tool_id: Optional[str]) -> str:
|
||||
raw = str(tool_id or "").strip()
|
||||
if not raw:
|
||||
return ""
|
||||
return TOOL_ID_ALIASES.get(raw, raw)
|
||||
|
||||
|
||||
# ============ Available Tools ============
|
||||
TOOL_REGISTRY = {
|
||||
"calculator": {
|
||||
@@ -87,7 +100,7 @@ TOOL_REGISTRY = {
|
||||
"required": []
|
||||
}
|
||||
},
|
||||
"voice_message_prompt": {
|
||||
"voice_msg_prompt": {
|
||||
"name": "语音消息提示",
|
||||
"description": "播报一条语音提示消息",
|
||||
"parameters": {
|
||||
@@ -180,7 +193,8 @@ TOOL_CATEGORY_MAP = {
|
||||
"turn_off_camera": "system",
|
||||
"increase_volume": "system",
|
||||
"decrease_volume": "system",
|
||||
"voice_message_prompt": "system",
|
||||
"voice_msg_prompt": "system",
|
||||
"voice_message_prompt": "system", # backward compatibility
|
||||
"text_msg_prompt": "system",
|
||||
"voice_choice_prompt": "system",
|
||||
"text_choice_prompt": "system",
|
||||
@@ -194,7 +208,8 @@ TOOL_ICON_MAP = {
|
||||
"turn_off_camera": "CameraOff",
|
||||
"increase_volume": "Volume2",
|
||||
"decrease_volume": "Volume2",
|
||||
"voice_message_prompt": "Volume2",
|
||||
"voice_msg_prompt": "Volume2",
|
||||
"voice_message_prompt": "Volume2", # backward compatibility
|
||||
"text_msg_prompt": "Terminal",
|
||||
"voice_choice_prompt": "Volume2",
|
||||
"text_choice_prompt": "Terminal",
|
||||
@@ -284,9 +299,49 @@ def _validate_query_http_config(*, category: str, tool_id: Optional[str], http_u
|
||||
raise HTTPException(status_code=400, detail="http_url is required for query tools (except calculator/code_interpreter)")
|
||||
|
||||
|
||||
def _migrate_legacy_system_tool_ids(db: Session) -> None:
|
||||
"""Rename legacy built-in system tool IDs to their canonical IDs."""
|
||||
changed = False
|
||||
for legacy_id, canonical_id in TOOL_ID_ALIASES.items():
|
||||
if legacy_id == canonical_id:
|
||||
continue
|
||||
legacy_item = (
|
||||
db.query(ToolResource)
|
||||
.filter(ToolResource.id == legacy_id)
|
||||
.first()
|
||||
)
|
||||
if not legacy_item or not bool(legacy_item.is_system):
|
||||
continue
|
||||
|
||||
canonical_item = (
|
||||
db.query(ToolResource)
|
||||
.filter(ToolResource.id == canonical_id)
|
||||
.first()
|
||||
)
|
||||
if canonical_item:
|
||||
db.delete(legacy_item)
|
||||
changed = True
|
||||
continue
|
||||
|
||||
legacy_item.id = canonical_id
|
||||
legacy_item.updated_at = datetime.utcnow()
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
db.commit()
|
||||
|
||||
|
||||
def _seed_default_tools_if_empty(db: Session) -> None:
|
||||
"""Ensure built-in tools exist in tool_resources without overriding custom edits."""
|
||||
_ensure_tool_resource_schema(db)
|
||||
_migrate_legacy_system_tool_ids(db)
|
||||
existing_system_count = (
|
||||
db.query(ToolResource.id)
|
||||
.filter(ToolResource.is_system.is_(True))
|
||||
.count()
|
||||
)
|
||||
if existing_system_count > 0:
|
||||
return
|
||||
existing_ids = {
|
||||
str(item[0])
|
||||
for item in db.query(ToolResource.id).all()
|
||||
@@ -335,9 +390,10 @@ def list_available_tools():
|
||||
@router.get("/list/{tool_id}")
|
||||
def get_tool_detail(tool_id: str):
|
||||
"""获取工具详情"""
|
||||
if tool_id not in TOOL_REGISTRY:
|
||||
canonical_tool_id = normalize_tool_id(tool_id)
|
||||
if canonical_tool_id not in TOOL_REGISTRY:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
return TOOL_REGISTRY[tool_id]
|
||||
return TOOL_REGISTRY[canonical_tool_id]
|
||||
|
||||
|
||||
# ============ Tool Resource CRUD ============
|
||||
@@ -369,6 +425,10 @@ def get_tool_resource(id: str, db: Session = Depends(get_db)):
|
||||
"""获取单个工具资源详情。"""
|
||||
_seed_default_tools_if_empty(db)
|
||||
item = db.query(ToolResource).filter(ToolResource.id == id).first()
|
||||
if not item:
|
||||
canonical_id = normalize_tool_id(id)
|
||||
if canonical_id and canonical_id != id:
|
||||
item = db.query(ToolResource).filter(ToolResource.id == canonical_id).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Tool resource not found")
|
||||
return item
|
||||
@@ -378,7 +438,7 @@ def get_tool_resource(id: str, db: Session = Depends(get_db)):
|
||||
def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)):
|
||||
"""创建自定义工具资源。"""
|
||||
_seed_default_tools_if_empty(db)
|
||||
candidate_id = (data.id or "").strip()
|
||||
candidate_id = normalize_tool_id((data.id or "").strip())
|
||||
if candidate_id and db.query(ToolResource).filter(ToolResource.id == candidate_id).first():
|
||||
raise HTTPException(status_code=400, detail="Tool ID already exists")
|
||||
|
||||
@@ -413,7 +473,10 @@ def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)
|
||||
def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)):
|
||||
"""更新工具资源。"""
|
||||
_seed_default_tools_if_empty(db)
|
||||
canonical_id = normalize_tool_id(id)
|
||||
item = db.query(ToolResource).filter(ToolResource.id == id).first()
|
||||
if not item and canonical_id and canonical_id != id:
|
||||
item = db.query(ToolResource).filter(ToolResource.id == canonical_id).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Tool resource not found")
|
||||
|
||||
@@ -421,14 +484,14 @@ def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depend
|
||||
|
||||
new_category = update_data.get("category", item.category)
|
||||
new_http_url = update_data.get("http_url", item.http_url)
|
||||
_validate_query_http_config(category=new_category, tool_id=id, http_url=new_http_url)
|
||||
_validate_query_http_config(category=new_category, tool_id=item.id, http_url=new_http_url)
|
||||
|
||||
if "http_method" in update_data:
|
||||
update_data["http_method"] = _normalize_http_method(update_data.get("http_method"))
|
||||
if "http_timeout_ms" in update_data and update_data.get("http_timeout_ms") is not None:
|
||||
update_data["http_timeout_ms"] = max(1000, int(update_data["http_timeout_ms"]))
|
||||
if "parameter_schema" in update_data:
|
||||
update_data["parameter_schema"] = _normalize_parameter_schema(update_data.get("parameter_schema"), tool_id=id)
|
||||
update_data["parameter_schema"] = _normalize_parameter_schema(update_data.get("parameter_schema"), tool_id=item.id)
|
||||
if "parameter_defaults" in update_data:
|
||||
update_data["parameter_defaults"] = _normalize_parameter_defaults(update_data.get("parameter_defaults"))
|
||||
if new_category != "system":
|
||||
@@ -447,7 +510,10 @@ def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depend
|
||||
def delete_tool_resource(id: str, db: Session = Depends(get_db)):
|
||||
"""删除工具资源。"""
|
||||
_seed_default_tools_if_empty(db)
|
||||
canonical_id = normalize_tool_id(id)
|
||||
item = db.query(ToolResource).filter(ToolResource.id == id).first()
|
||||
if not item and canonical_id and canonical_id != id:
|
||||
item = db.query(ToolResource).filter(ToolResource.id == canonical_id).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Tool resource not found")
|
||||
db.delete(item)
|
||||
|
||||
Reference in New Issue
Block a user