Add parameter schema and defaults to ToolResource model and schemas. Implement runtime tool resolution in assistants and tools routers, ensuring proper handling of tool parameters. Update tests to validate new functionality and ensure correct integration of parameter handling in the API.

This commit is contained in:
Xin Wang
2026-02-27 14:44:28 +08:00
parent d942c85eff
commit 5f768edf68
13 changed files with 397 additions and 9 deletions

View File

@@ -13,7 +13,7 @@ import uuid
from datetime import datetime
from ..db import get_db
from ..models import Assistant, AssistantOpenerAudio, LLMModel, ASRModel, Voice
from ..models import Assistant, AssistantOpenerAudio, LLMModel, ASRModel, Voice, ToolResource
from ..schemas import (
AssistantCreate,
AssistantUpdate,
@@ -22,6 +22,7 @@ from ..schemas import (
AssistantOpenerAudioGenerateRequest,
AssistantOpenerAudioOut,
)
from .tools import TOOL_REGISTRY, TOOL_CATEGORY_MAP, TOOL_PARAMETER_DEFAULTS, _ensure_tool_resource_schema
router = APIRouter(prefix="/assistants", tags=["Assistants"])
@@ -78,7 +79,75 @@ def _config_version_id(assistant: Assistant) -> str:
return f"asst_{assistant.id}_{updated.strftime('%Y%m%d%H%M%S')}"
def _normalize_runtime_tool_schema(tool_id: str, raw_schema: Any) -> Dict[str, Any]:
schema = dict(raw_schema) if isinstance(raw_schema, dict) else {}
if not schema:
fallback = TOOL_REGISTRY.get(tool_id, {}).get("parameters")
if isinstance(fallback, dict):
schema = dict(fallback)
schema.setdefault("type", "object")
if not isinstance(schema.get("properties"), dict):
schema["properties"] = {}
required = schema.get("required")
if required is None or not isinstance(required, list):
schema["required"] = []
return schema
def _resolve_runtime_tools(db: Session, selected_tool_ids: List[str], warnings: List[str]) -> List[Dict[str, Any]]:
_ensure_tool_resource_schema(db)
ids = [str(tool_id).strip() for tool_id in selected_tool_ids if str(tool_id).strip()]
if not ids:
return []
resources = (
db.query(ToolResource)
.filter(ToolResource.id.in_(ids))
.all()
)
by_id = {str(item.id): item for item in resources}
runtime_tools: List[Dict[str, Any]] = []
for tool_id in ids:
resource = by_id.get(tool_id)
if resource and resource.enabled is False:
warnings.append(f"Tool is disabled and skipped in runtime config: {tool_id}")
continue
category = str(resource.category if resource else TOOL_CATEGORY_MAP.get(tool_id, "query"))
description = (
str(resource.description or resource.name or "").strip()
if resource
else str(TOOL_REGISTRY.get(tool_id, {}).get("description") or "").strip()
)
schema = _normalize_runtime_tool_schema(
tool_id,
resource.parameter_schema if resource else TOOL_REGISTRY.get(tool_id, {}).get("parameters"),
)
defaults_raw = resource.parameter_defaults if resource else TOOL_PARAMETER_DEFAULTS.get(tool_id)
defaults = dict(defaults_raw) if isinstance(defaults_raw, dict) else {}
if not resource and tool_id not in TOOL_REGISTRY:
warnings.append(f"Tool resource not found: {tool_id}")
runtime_tool: Dict[str, Any] = {
"type": "function",
"executor": "client" if category == "system" else "server",
"function": {
"name": tool_id,
"description": description or tool_id,
"parameters": schema,
},
}
if defaults:
runtime_tool["defaultArgs"] = defaults
runtime_tools.append(runtime_tool)
return runtime_tools
def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[str, Any], List[str]]:
warnings: List[str] = []
metadata: Dict[str, Any] = {
"systemPrompt": assistant.prompt or "",
"firstTurnMode": assistant.first_turn_mode or "bot_first",
@@ -90,14 +159,13 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[s
"minDurationMs": int(assistant.interruption_sensitivity or 500),
},
"services": {},
"tools": assistant.tools or [],
"tools": _resolve_runtime_tools(db, assistant.tools or [], warnings),
"history": {
"assistantId": assistant.id,
"userId": int(assistant.user_id or 1),
"source": "debug",
},
}
warnings: List[str] = []
config_mode = str(assistant.config_mode or "platform").strip().lower()

View File

@@ -1,5 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from sqlalchemy import inspect, text
from typing import Optional, Dict, Any, List
import time
import uuid
@@ -111,6 +112,61 @@ TOOL_ICON_MAP = {
TOOL_HTTP_DEFAULTS = {
}
TOOL_PARAMETER_DEFAULTS = {
"increase_volume": {"step": 1},
"decrease_volume": {"step": 1},
}
def _normalize_parameter_schema(value: Any, *, tool_id: Optional[str] = None) -> Dict[str, Any]:
if not isinstance(value, dict):
value = {}
normalized = dict(value)
if not normalized:
fallback = TOOL_REGISTRY.get(str(tool_id or "").strip(), {}).get("parameters")
if isinstance(fallback, dict):
normalized = dict(fallback)
normalized.setdefault("type", "object")
if normalized.get("type") != "object":
raise HTTPException(status_code=400, detail="parameter_schema.type must be 'object'")
properties = normalized.get("properties")
if not isinstance(properties, dict):
normalized["properties"] = {}
required = normalized.get("required")
if required is None:
normalized["required"] = []
elif not isinstance(required, list):
raise HTTPException(status_code=400, detail="parameter_schema.required must be an array")
return normalized
def _normalize_parameter_defaults(value: Any) -> Dict[str, Any]:
if value is None:
return {}
if not isinstance(value, dict):
raise HTTPException(status_code=400, detail="parameter_defaults must be an object")
return dict(value)
def _ensure_tool_resource_schema(db: Session) -> None:
"""Apply lightweight SQLite migrations for newly added tool_resources columns."""
bind = db.get_bind()
inspector = inspect(bind)
try:
columns = {col["name"] for col in inspector.get_columns("tool_resources")}
except Exception:
return
altered = False
if "parameter_schema" not in columns:
db.execute(text("ALTER TABLE tool_resources ADD COLUMN parameter_schema JSON"))
altered = True
if "parameter_defaults" not in columns:
db.execute(text("ALTER TABLE tool_resources ADD COLUMN parameter_defaults JSON"))
altered = True
if altered:
db.commit()
def _normalize_http_method(method: Optional[str]) -> str:
normalized = str(method or "GET").strip().upper()
@@ -127,8 +183,10 @@ def _validate_query_http_config(*, category: str, tool_id: Optional[str], http_u
if _requires_http_request(category, tool_id) and not str(http_url or "").strip():
raise HTTPException(status_code=400, detail="http_url is required for query tools (except calculator/code_interpreter)")
def _seed_default_tools_if_empty(db: Session) -> None:
"""Seed built-in tools only when tool_resources is empty."""
_ensure_tool_resource_schema(db)
if db.query(ToolResource).count() > 0:
return
for tool_id, payload in TOOL_REGISTRY.items():
@@ -144,6 +202,8 @@ def _seed_default_tools_if_empty(db: Session) -> None:
http_url=http_defaults.get("http_url"),
http_headers=http_defaults.get("http_headers") or {},
http_timeout_ms=int(http_defaults.get("http_timeout_ms") or 10000),
parameter_schema=_normalize_parameter_schema(payload.get("parameters"), tool_id=tool_id),
parameter_defaults=_normalize_parameter_defaults(TOOL_PARAMETER_DEFAULTS.get(tool_id)),
enabled=True,
is_system=True,
))
@@ -215,6 +275,8 @@ def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)
raise HTTPException(status_code=400, detail="Tool ID already exists")
_validate_query_http_config(category=data.category, tool_id=candidate_id, http_url=data.http_url)
parameter_schema = _normalize_parameter_schema(data.parameter_schema, tool_id=candidate_id)
parameter_defaults = _normalize_parameter_defaults(data.parameter_defaults)
item = ToolResource(
id=candidate_id or f"tool_{str(uuid.uuid4())[:8]}",
@@ -227,6 +289,8 @@ def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)
http_url=(data.http_url or "").strip() or None,
http_headers=data.http_headers or {},
http_timeout_ms=max(1000, int(data.http_timeout_ms or 10000)),
parameter_schema=parameter_schema,
parameter_defaults=parameter_defaults,
enabled=data.enabled,
is_system=False,
)
@@ -254,6 +318,10 @@ def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depend
update_data["http_method"] = _normalize_http_method(update_data.get("http_method"))
if "http_timeout_ms" in update_data and update_data.get("http_timeout_ms") is not None:
update_data["http_timeout_ms"] = max(1000, int(update_data["http_timeout_ms"]))
if "parameter_schema" in update_data:
update_data["parameter_schema"] = _normalize_parameter_schema(update_data.get("parameter_schema"), tool_id=id)
if "parameter_defaults" in update_data:
update_data["parameter_defaults"] = _normalize_parameter_defaults(update_data.get("parameter_defaults"))
for field, value in update_data.items():
setattr(item, field, value)