Add parameter schema and defaults to ToolResource model and schemas. Implement runtime tool resolution in assistants and tools routers, ensuring proper handling of tool parameters. Update tests to validate new functionality and ensure correct integration of parameter handling in the API.
This commit is contained in:
@@ -96,6 +96,8 @@ class ToolResource(Base):
|
||||
http_url: Mapped[Optional[str]] = mapped_column(String(1024), nullable=True)
|
||||
http_headers: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
http_timeout_ms: Mapped[int] = mapped_column(Integer, default=10000)
|
||||
parameter_schema: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
parameter_defaults: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
enabled: Mapped[bool] = mapped_column(default=True)
|
||||
is_system: Mapped[bool] = mapped_column(default=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
|
||||
@@ -13,7 +13,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from ..db import get_db
|
||||
from ..models import Assistant, AssistantOpenerAudio, LLMModel, ASRModel, Voice
|
||||
from ..models import Assistant, AssistantOpenerAudio, LLMModel, ASRModel, Voice, ToolResource
|
||||
from ..schemas import (
|
||||
AssistantCreate,
|
||||
AssistantUpdate,
|
||||
@@ -22,6 +22,7 @@ from ..schemas import (
|
||||
AssistantOpenerAudioGenerateRequest,
|
||||
AssistantOpenerAudioOut,
|
||||
)
|
||||
from .tools import TOOL_REGISTRY, TOOL_CATEGORY_MAP, TOOL_PARAMETER_DEFAULTS, _ensure_tool_resource_schema
|
||||
|
||||
router = APIRouter(prefix="/assistants", tags=["Assistants"])
|
||||
|
||||
@@ -78,7 +79,75 @@ def _config_version_id(assistant: Assistant) -> str:
|
||||
return f"asst_{assistant.id}_{updated.strftime('%Y%m%d%H%M%S')}"
|
||||
|
||||
|
||||
def _normalize_runtime_tool_schema(tool_id: str, raw_schema: Any) -> Dict[str, Any]:
|
||||
schema = dict(raw_schema) if isinstance(raw_schema, dict) else {}
|
||||
if not schema:
|
||||
fallback = TOOL_REGISTRY.get(tool_id, {}).get("parameters")
|
||||
if isinstance(fallback, dict):
|
||||
schema = dict(fallback)
|
||||
schema.setdefault("type", "object")
|
||||
if not isinstance(schema.get("properties"), dict):
|
||||
schema["properties"] = {}
|
||||
required = schema.get("required")
|
||||
if required is None or not isinstance(required, list):
|
||||
schema["required"] = []
|
||||
return schema
|
||||
|
||||
|
||||
def _resolve_runtime_tools(db: Session, selected_tool_ids: List[str], warnings: List[str]) -> List[Dict[str, Any]]:
|
||||
_ensure_tool_resource_schema(db)
|
||||
ids = [str(tool_id).strip() for tool_id in selected_tool_ids if str(tool_id).strip()]
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
resources = (
|
||||
db.query(ToolResource)
|
||||
.filter(ToolResource.id.in_(ids))
|
||||
.all()
|
||||
)
|
||||
by_id = {str(item.id): item for item in resources}
|
||||
|
||||
runtime_tools: List[Dict[str, Any]] = []
|
||||
for tool_id in ids:
|
||||
resource = by_id.get(tool_id)
|
||||
if resource and resource.enabled is False:
|
||||
warnings.append(f"Tool is disabled and skipped in runtime config: {tool_id}")
|
||||
continue
|
||||
|
||||
category = str(resource.category if resource else TOOL_CATEGORY_MAP.get(tool_id, "query"))
|
||||
description = (
|
||||
str(resource.description or resource.name or "").strip()
|
||||
if resource
|
||||
else str(TOOL_REGISTRY.get(tool_id, {}).get("description") or "").strip()
|
||||
)
|
||||
schema = _normalize_runtime_tool_schema(
|
||||
tool_id,
|
||||
resource.parameter_schema if resource else TOOL_REGISTRY.get(tool_id, {}).get("parameters"),
|
||||
)
|
||||
defaults_raw = resource.parameter_defaults if resource else TOOL_PARAMETER_DEFAULTS.get(tool_id)
|
||||
defaults = dict(defaults_raw) if isinstance(defaults_raw, dict) else {}
|
||||
|
||||
if not resource and tool_id not in TOOL_REGISTRY:
|
||||
warnings.append(f"Tool resource not found: {tool_id}")
|
||||
|
||||
runtime_tool: Dict[str, Any] = {
|
||||
"type": "function",
|
||||
"executor": "client" if category == "system" else "server",
|
||||
"function": {
|
||||
"name": tool_id,
|
||||
"description": description or tool_id,
|
||||
"parameters": schema,
|
||||
},
|
||||
}
|
||||
if defaults:
|
||||
runtime_tool["defaultArgs"] = defaults
|
||||
runtime_tools.append(runtime_tool)
|
||||
|
||||
return runtime_tools
|
||||
|
||||
|
||||
def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[str, Any], List[str]]:
|
||||
warnings: List[str] = []
|
||||
metadata: Dict[str, Any] = {
|
||||
"systemPrompt": assistant.prompt or "",
|
||||
"firstTurnMode": assistant.first_turn_mode or "bot_first",
|
||||
@@ -90,14 +159,13 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[s
|
||||
"minDurationMs": int(assistant.interruption_sensitivity or 500),
|
||||
},
|
||||
"services": {},
|
||||
"tools": assistant.tools or [],
|
||||
"tools": _resolve_runtime_tools(db, assistant.tools or [], warnings),
|
||||
"history": {
|
||||
"assistantId": assistant.id,
|
||||
"userId": int(assistant.user_id or 1),
|
||||
"source": "debug",
|
||||
},
|
||||
}
|
||||
warnings: List[str] = []
|
||||
|
||||
config_mode = str(assistant.config_mode or "platform").strip().lower()
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import inspect, text
|
||||
from typing import Optional, Dict, Any, List
|
||||
import time
|
||||
import uuid
|
||||
@@ -111,6 +112,61 @@ TOOL_ICON_MAP = {
|
||||
TOOL_HTTP_DEFAULTS = {
|
||||
}
|
||||
|
||||
TOOL_PARAMETER_DEFAULTS = {
|
||||
"increase_volume": {"step": 1},
|
||||
"decrease_volume": {"step": 1},
|
||||
}
|
||||
|
||||
|
||||
def _normalize_parameter_schema(value: Any, *, tool_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
if not isinstance(value, dict):
|
||||
value = {}
|
||||
normalized = dict(value)
|
||||
if not normalized:
|
||||
fallback = TOOL_REGISTRY.get(str(tool_id or "").strip(), {}).get("parameters")
|
||||
if isinstance(fallback, dict):
|
||||
normalized = dict(fallback)
|
||||
normalized.setdefault("type", "object")
|
||||
if normalized.get("type") != "object":
|
||||
raise HTTPException(status_code=400, detail="parameter_schema.type must be 'object'")
|
||||
properties = normalized.get("properties")
|
||||
if not isinstance(properties, dict):
|
||||
normalized["properties"] = {}
|
||||
required = normalized.get("required")
|
||||
if required is None:
|
||||
normalized["required"] = []
|
||||
elif not isinstance(required, list):
|
||||
raise HTTPException(status_code=400, detail="parameter_schema.required must be an array")
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_parameter_defaults(value: Any) -> Dict[str, Any]:
|
||||
if value is None:
|
||||
return {}
|
||||
if not isinstance(value, dict):
|
||||
raise HTTPException(status_code=400, detail="parameter_defaults must be an object")
|
||||
return dict(value)
|
||||
|
||||
|
||||
def _ensure_tool_resource_schema(db: Session) -> None:
|
||||
"""Apply lightweight SQLite migrations for newly added tool_resources columns."""
|
||||
bind = db.get_bind()
|
||||
inspector = inspect(bind)
|
||||
try:
|
||||
columns = {col["name"] for col in inspector.get_columns("tool_resources")}
|
||||
except Exception:
|
||||
return
|
||||
|
||||
altered = False
|
||||
if "parameter_schema" not in columns:
|
||||
db.execute(text("ALTER TABLE tool_resources ADD COLUMN parameter_schema JSON"))
|
||||
altered = True
|
||||
if "parameter_defaults" not in columns:
|
||||
db.execute(text("ALTER TABLE tool_resources ADD COLUMN parameter_defaults JSON"))
|
||||
altered = True
|
||||
if altered:
|
||||
db.commit()
|
||||
|
||||
|
||||
def _normalize_http_method(method: Optional[str]) -> str:
|
||||
normalized = str(method or "GET").strip().upper()
|
||||
@@ -127,8 +183,10 @@ def _validate_query_http_config(*, category: str, tool_id: Optional[str], http_u
|
||||
if _requires_http_request(category, tool_id) and not str(http_url or "").strip():
|
||||
raise HTTPException(status_code=400, detail="http_url is required for query tools (except calculator/code_interpreter)")
|
||||
|
||||
|
||||
def _seed_default_tools_if_empty(db: Session) -> None:
|
||||
"""Seed built-in tools only when tool_resources is empty."""
|
||||
_ensure_tool_resource_schema(db)
|
||||
if db.query(ToolResource).count() > 0:
|
||||
return
|
||||
for tool_id, payload in TOOL_REGISTRY.items():
|
||||
@@ -144,6 +202,8 @@ def _seed_default_tools_if_empty(db: Session) -> None:
|
||||
http_url=http_defaults.get("http_url"),
|
||||
http_headers=http_defaults.get("http_headers") or {},
|
||||
http_timeout_ms=int(http_defaults.get("http_timeout_ms") or 10000),
|
||||
parameter_schema=_normalize_parameter_schema(payload.get("parameters"), tool_id=tool_id),
|
||||
parameter_defaults=_normalize_parameter_defaults(TOOL_PARAMETER_DEFAULTS.get(tool_id)),
|
||||
enabled=True,
|
||||
is_system=True,
|
||||
))
|
||||
@@ -215,6 +275,8 @@ def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)
|
||||
raise HTTPException(status_code=400, detail="Tool ID already exists")
|
||||
|
||||
_validate_query_http_config(category=data.category, tool_id=candidate_id, http_url=data.http_url)
|
||||
parameter_schema = _normalize_parameter_schema(data.parameter_schema, tool_id=candidate_id)
|
||||
parameter_defaults = _normalize_parameter_defaults(data.parameter_defaults)
|
||||
|
||||
item = ToolResource(
|
||||
id=candidate_id or f"tool_{str(uuid.uuid4())[:8]}",
|
||||
@@ -227,6 +289,8 @@ def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)
|
||||
http_url=(data.http_url or "").strip() or None,
|
||||
http_headers=data.http_headers or {},
|
||||
http_timeout_ms=max(1000, int(data.http_timeout_ms or 10000)),
|
||||
parameter_schema=parameter_schema,
|
||||
parameter_defaults=parameter_defaults,
|
||||
enabled=data.enabled,
|
||||
is_system=False,
|
||||
)
|
||||
@@ -254,6 +318,10 @@ def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depend
|
||||
update_data["http_method"] = _normalize_http_method(update_data.get("http_method"))
|
||||
if "http_timeout_ms" in update_data and update_data.get("http_timeout_ms") is not None:
|
||||
update_data["http_timeout_ms"] = max(1000, int(update_data["http_timeout_ms"]))
|
||||
if "parameter_schema" in update_data:
|
||||
update_data["parameter_schema"] = _normalize_parameter_schema(update_data.get("parameter_schema"), tool_id=id)
|
||||
if "parameter_defaults" in update_data:
|
||||
update_data["parameter_defaults"] = _normalize_parameter_defaults(update_data.get("parameter_defaults"))
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(item, field, value)
|
||||
|
||||
@@ -239,6 +239,8 @@ class ToolResourceBase(BaseModel):
|
||||
http_url: Optional[str] = None
|
||||
http_headers: Dict[str, str] = Field(default_factory=dict)
|
||||
http_timeout_ms: int = 10000
|
||||
parameter_schema: Dict[str, Any] = Field(default_factory=dict)
|
||||
parameter_defaults: Dict[str, Any] = Field(default_factory=dict)
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
@@ -255,6 +257,8 @@ class ToolResourceUpdate(BaseModel):
|
||||
http_url: Optional[str] = None
|
||||
http_headers: Optional[Dict[str, str]] = None
|
||||
http_timeout_ms: Optional[int] = None
|
||||
parameter_schema: Optional[Dict[str, Any]] = None
|
||||
parameter_defaults: Optional[Dict[str, Any]] = None
|
||||
enabled: Optional[bool] = None
|
||||
|
||||
|
||||
|
||||
@@ -243,6 +243,26 @@ class TestAssistantAPI:
|
||||
assert payload["sessionStartMetadata"]["systemPrompt"] == sample_assistant_data["prompt"]
|
||||
assert payload["sessionStartMetadata"]["history"]["assistantId"] == assistant_id
|
||||
|
||||
def test_runtime_config_resolves_selected_tools_into_runtime_definitions(self, client, sample_assistant_data):
|
||||
sample_assistant_data["tools"] = ["increase_volume", "calculator"]
|
||||
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
|
||||
assert assistant_resp.status_code == 200
|
||||
assistant_id = assistant_resp.json()["id"]
|
||||
|
||||
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
|
||||
assert runtime_resp.status_code == 200
|
||||
metadata = runtime_resp.json()["sessionStartMetadata"]
|
||||
tools = metadata["tools"]
|
||||
assert isinstance(tools, list)
|
||||
assert len(tools) == 2
|
||||
|
||||
by_name = {item["function"]["name"]: item for item in tools}
|
||||
assert by_name["increase_volume"]["executor"] == "client"
|
||||
assert by_name["increase_volume"]["defaultArgs"]["step"] == 1
|
||||
assert by_name["calculator"]["executor"] == "server"
|
||||
assert by_name["calculator"]["function"]["parameters"]["type"] == "object"
|
||||
assert "expression" in by_name["calculator"]["function"]["parameters"]["properties"]
|
||||
|
||||
def test_runtime_config_text_mode_when_voice_output_disabled(self, client, sample_assistant_data):
|
||||
sample_assistant_data["voiceOutputEnabled"] = False
|
||||
assistant_resp = client.post("/api/assistants", json=sample_assistant_data)
|
||||
|
||||
@@ -281,6 +281,9 @@ class TestToolResourceCRUD:
|
||||
assert payload["total"] >= 1
|
||||
ids = [item["id"] for item in payload["list"]]
|
||||
assert "calculator" in ids
|
||||
calculator = next((item for item in payload["list"] if item["id"] == "calculator"), None)
|
||||
assert calculator is not None
|
||||
assert calculator["parameter_schema"]["type"] == "object"
|
||||
|
||||
def test_create_update_delete_tool_resource(self, client):
|
||||
create_resp = client.post("/api/tools/resources", json={
|
||||
@@ -292,6 +295,12 @@ class TestToolResourceCRUD:
|
||||
"http_url": "https://example.com/search",
|
||||
"http_headers": {},
|
||||
"http_timeout_ms": 10000,
|
||||
"parameter_schema": {
|
||||
"type": "object",
|
||||
"properties": {"keyword": {"type": "string"}},
|
||||
"required": ["keyword"]
|
||||
},
|
||||
"parameter_defaults": {"limit": 10},
|
||||
"enabled": True,
|
||||
})
|
||||
assert create_resp.status_code == 200
|
||||
@@ -299,15 +308,19 @@ class TestToolResourceCRUD:
|
||||
tool_id = created["id"]
|
||||
assert created["name"] == "自定义网页抓取"
|
||||
assert created["is_system"] is False
|
||||
assert created["parameter_schema"]["required"] == ["keyword"]
|
||||
assert created["parameter_defaults"]["limit"] == 10
|
||||
|
||||
update_resp = client.put(f"/api/tools/resources/{tool_id}", json={
|
||||
"name": "自定义网页检索",
|
||||
"category": "system",
|
||||
"parameter_defaults": {"limit": 20},
|
||||
})
|
||||
assert update_resp.status_code == 200
|
||||
updated = update_resp.json()
|
||||
assert updated["name"] == "自定义网页检索"
|
||||
assert updated["category"] == "system"
|
||||
assert updated["parameter_defaults"]["limit"] == 20
|
||||
|
||||
get_resp = client.get(f"/api/tools/resources/{tool_id}")
|
||||
assert get_resp.status_code == 200
|
||||
|
||||
@@ -287,6 +287,7 @@ class DuplexPipeline:
|
||||
raw_default_tools = settings.tools if isinstance(settings.tools, list) else []
|
||||
self._runtime_tools: List[Any] = list(raw_default_tools)
|
||||
self._runtime_tool_executor: Dict[str, str] = {}
|
||||
self._runtime_tool_default_args: Dict[str, Dict[str, Any]] = {}
|
||||
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
|
||||
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
|
||||
self._completed_tool_call_ids: set[str] = set()
|
||||
@@ -307,6 +308,7 @@ class DuplexPipeline:
|
||||
self._last_llm_delta_emit_ms: float = 0.0
|
||||
|
||||
self._runtime_tool_executor = self._resolved_tool_executor_map()
|
||||
self._runtime_tool_default_args = self._resolved_tool_default_args_map()
|
||||
self._initial_greeting_emitted = False
|
||||
|
||||
if self._server_tool_executor is None:
|
||||
@@ -408,9 +410,11 @@ class DuplexPipeline:
|
||||
if isinstance(tools_payload, list):
|
||||
self._runtime_tools = tools_payload
|
||||
self._runtime_tool_executor = self._resolved_tool_executor_map()
|
||||
self._runtime_tool_default_args = self._resolved_tool_default_args_map()
|
||||
elif "tools" in metadata:
|
||||
self._runtime_tools = []
|
||||
self._runtime_tool_executor = {}
|
||||
self._runtime_tool_default_args = {}
|
||||
|
||||
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
|
||||
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
||||
@@ -1473,6 +1477,25 @@ class DuplexPipeline:
|
||||
result[name] = executor
|
||||
return result
|
||||
|
||||
def _resolved_tool_default_args_map(self) -> Dict[str, Dict[str, Any]]:
|
||||
result: Dict[str, Dict[str, Any]] = {}
|
||||
for item in self._runtime_tools:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
fn = item.get("function")
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
name = str(fn.get("name")).strip()
|
||||
else:
|
||||
name = str(item.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
raw_defaults = item.get("defaultArgs")
|
||||
if raw_defaults is None:
|
||||
raw_defaults = item.get("default_args")
|
||||
if isinstance(raw_defaults, dict):
|
||||
result[name] = dict(raw_defaults)
|
||||
return result
|
||||
|
||||
def _resolved_tool_allowlist(self) -> List[str]:
|
||||
names: set[str] = set()
|
||||
for item in self._runtime_tools:
|
||||
@@ -1518,6 +1541,15 @@ class DuplexPipeline:
|
||||
return {"raw": raw}
|
||||
return {}
|
||||
|
||||
def _apply_tool_default_args(self, tool_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
defaults = self._runtime_tool_default_args.get(tool_name)
|
||||
if not isinstance(defaults, dict) or not defaults:
|
||||
return args
|
||||
merged = dict(defaults)
|
||||
if isinstance(args, dict):
|
||||
merged.update(args)
|
||||
return merged
|
||||
|
||||
def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
status = result.get("status") if isinstance(result.get("status"), dict) else {}
|
||||
status_code = int(status.get("code") or 0) if status else 0
|
||||
@@ -1702,15 +1734,27 @@ class DuplexPipeline:
|
||||
enriched_tool_call["executor"] = executor
|
||||
tool_name = self._tool_name(enriched_tool_call) or "unknown_tool"
|
||||
call_id = str(enriched_tool_call.get("id") or "").strip()
|
||||
fn_payload = enriched_tool_call.get("function")
|
||||
fn_payload = (
|
||||
dict(enriched_tool_call.get("function"))
|
||||
if isinstance(enriched_tool_call.get("function"), dict)
|
||||
else None
|
||||
)
|
||||
raw_args = str(fn_payload.get("arguments") or "") if isinstance(fn_payload, dict) else ""
|
||||
tool_arguments = self._tool_arguments(enriched_tool_call)
|
||||
merged_tool_arguments = self._apply_tool_default_args(tool_name, tool_arguments)
|
||||
try:
|
||||
merged_args_text = json.dumps(merged_tool_arguments, ensure_ascii=False)
|
||||
except Exception:
|
||||
merged_args_text = raw_args if raw_args else "{}"
|
||||
if isinstance(fn_payload, dict):
|
||||
fn_payload["arguments"] = merged_args_text
|
||||
enriched_tool_call["function"] = fn_payload
|
||||
args_preview = raw_args if len(raw_args) <= 160 else f"{raw_args[:160]}..."
|
||||
logger.info(
|
||||
f"[Tool] call requested name={tool_name} call_id={call_id} "
|
||||
f"executor={executor} args={args_preview}"
|
||||
f"executor={executor} args={args_preview} merged_args={merged_args_text}"
|
||||
)
|
||||
tool_calls.append(enriched_tool_call)
|
||||
tool_arguments = self._tool_arguments(enriched_tool_call)
|
||||
if executor == "client" and call_id:
|
||||
self._pending_client_tool_call_ids.add(call_id)
|
||||
await self._send_event(
|
||||
|
||||
@@ -187,6 +187,17 @@ async def execute_server_tool(
|
||||
tool_name = _extract_tool_name(tool_call)
|
||||
args = _extract_tool_args(tool_call)
|
||||
resource_fetcher = tool_resource_fetcher or fetch_tool_resource
|
||||
resource: Optional[Dict[str, Any]] = None
|
||||
if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}:
|
||||
try:
|
||||
resource = await resource_fetcher(tool_name)
|
||||
except Exception:
|
||||
resource = None
|
||||
defaults = resource.get("parameter_defaults") if isinstance(resource, dict) else None
|
||||
if isinstance(defaults, dict) and defaults:
|
||||
merged_args = dict(defaults)
|
||||
merged_args.update(args)
|
||||
args = merged_args
|
||||
|
||||
if tool_name == "calculator":
|
||||
expression = str(args.get("expression") or "").strip()
|
||||
@@ -269,7 +280,6 @@ async def execute_server_tool(
|
||||
}
|
||||
|
||||
if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}:
|
||||
resource = await resource_fetcher(tool_name)
|
||||
if resource and str(resource.get("category") or "") == "query":
|
||||
method = str(resource.get("http_method") or "GET").strip().upper()
|
||||
if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
@@ -143,6 +144,64 @@ def test_pipeline_assigns_default_client_executor_for_system_string_tools(monkey
|
||||
assert pipeline._tool_executor(tool_call) == "client"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_applies_default_args_to_tool_call(monkeypatch):
|
||||
pipeline, _events = _build_pipeline(
|
||||
monkeypatch,
|
||||
[
|
||||
[
|
||||
LLMStreamEvent(
|
||||
type="tool_call",
|
||||
tool_call={
|
||||
"id": "call_defaults",
|
||||
"type": "function",
|
||||
"function": {"name": "weather", "arguments": "{}"},
|
||||
},
|
||||
),
|
||||
LLMStreamEvent(type="done"),
|
||||
],
|
||||
[LLMStreamEvent(type="done")],
|
||||
],
|
||||
)
|
||||
pipeline.apply_runtime_overrides(
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"executor": "server",
|
||||
"defaultArgs": {"city": "Hangzhou", "unit": "c"},
|
||||
"function": {
|
||||
"name": "weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def _server_exec(call: Dict[str, Any]) -> Dict[str, Any]:
|
||||
captured["call"] = call
|
||||
return {
|
||||
"tool_call_id": str(call.get("id") or ""),
|
||||
"name": "weather",
|
||||
"output": {"ok": True},
|
||||
"status": {"code": 200, "message": "ok"},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(pipeline, "_server_tool_executor", _server_exec)
|
||||
await pipeline._handle_turn("weather?")
|
||||
|
||||
sent_call = captured.get("call")
|
||||
assert isinstance(sent_call, dict)
|
||||
args_raw = sent_call.get("function", {}).get("arguments")
|
||||
args = json.loads(args_raw) if isinstance(args_raw, str) else {}
|
||||
assert args.get("city") == "Hangzhou"
|
||||
assert args.get("unit") == "c"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_message_parses_tool_call_results():
|
||||
msg = parse_client_message(
|
||||
|
||||
@@ -2051,13 +2051,20 @@ export const DebugDrawer: React.FC<{
|
||||
const item = byId.get(id);
|
||||
const toolId = item?.id || id;
|
||||
const isClientTool = (item?.category || 'query') === 'system';
|
||||
const parameterSchema = (item?.parameterSchema && typeof item.parameterSchema === 'object')
|
||||
? item.parameterSchema
|
||||
: getDefaultToolParameters(toolId);
|
||||
const parameterDefaults = (item?.parameterDefaults && typeof item.parameterDefaults === 'object')
|
||||
? item.parameterDefaults
|
||||
: undefined;
|
||||
return {
|
||||
type: 'function',
|
||||
executor: isClientTool ? 'client' : 'server',
|
||||
...(parameterDefaults && Object.keys(parameterDefaults).length > 0 ? { defaultArgs: parameterDefaults } : {}),
|
||||
function: {
|
||||
name: toolId,
|
||||
description: item?.description || item?.name || id,
|
||||
parameters: getDefaultToolParameters(toolId),
|
||||
parameters: parameterSchema,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
@@ -20,6 +20,12 @@ const iconMap: Record<string, React.ReactNode> = {
|
||||
Volume2: <Volume2 className="w-5 h-5" />,
|
||||
};
|
||||
|
||||
const DEFAULT_PARAMETER_SCHEMA_TEXT = JSON.stringify(
|
||||
{ type: 'object', properties: {}, required: [] },
|
||||
null,
|
||||
2
|
||||
);
|
||||
|
||||
export const ToolLibraryPage: React.FC = () => {
|
||||
const [tools, setTools] = useState<Tool[]>([]);
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
@@ -37,6 +43,8 @@ export const ToolLibraryPage: React.FC = () => {
|
||||
const [toolHttpUrl, setToolHttpUrl] = useState('');
|
||||
const [toolHttpHeadersText, setToolHttpHeadersText] = useState('{}');
|
||||
const [toolHttpTimeoutMs, setToolHttpTimeoutMs] = useState(10000);
|
||||
const [toolParameterSchemaText, setToolParameterSchemaText] = useState(DEFAULT_PARAMETER_SCHEMA_TEXT);
|
||||
const [toolParameterDefaultsText, setToolParameterDefaultsText] = useState('{}');
|
||||
const [saving, setSaving] = useState(false);
|
||||
|
||||
const loadTools = async () => {
|
||||
@@ -66,6 +74,8 @@ export const ToolLibraryPage: React.FC = () => {
|
||||
setToolHttpUrl('');
|
||||
setToolHttpHeadersText('{}');
|
||||
setToolHttpTimeoutMs(10000);
|
||||
setToolParameterSchemaText(DEFAULT_PARAMETER_SCHEMA_TEXT);
|
||||
setToolParameterDefaultsText('{}');
|
||||
setIsToolModalOpen(true);
|
||||
};
|
||||
|
||||
@@ -80,6 +90,8 @@ export const ToolLibraryPage: React.FC = () => {
|
||||
setToolHttpUrl(tool.httpUrl || '');
|
||||
setToolHttpHeadersText(JSON.stringify(tool.httpHeaders || {}, null, 2));
|
||||
setToolHttpTimeoutMs(tool.httpTimeoutMs || 10000);
|
||||
setToolParameterSchemaText(JSON.stringify(tool.parameterSchema || { type: 'object', properties: {}, required: [] }, null, 2));
|
||||
setToolParameterDefaultsText(JSON.stringify(tool.parameterDefaults || {}, null, 2));
|
||||
setIsToolModalOpen(true);
|
||||
};
|
||||
|
||||
@@ -153,8 +165,52 @@ export const ToolLibraryPage: React.FC = () => {
|
||||
try {
|
||||
setSaving(true);
|
||||
let parsedHeaders: Record<string, string> = {};
|
||||
let parsedParameterSchema: Record<string, any> = {};
|
||||
let parsedParameterDefaults: Record<string, any> = {};
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(toolParameterSchemaText || '{}');
|
||||
if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) {
|
||||
throw new Error('schema must be object');
|
||||
}
|
||||
parsedParameterSchema = parsed as Record<string, any>;
|
||||
} catch {
|
||||
alert('参数 Schema 必须是合法 JSON 对象');
|
||||
setSaving(false);
|
||||
return;
|
||||
}
|
||||
if (parsedParameterSchema.type && parsedParameterSchema.type !== 'object') {
|
||||
alert("参数 Schema 的 type 必须是 'object'");
|
||||
setSaving(false);
|
||||
return;
|
||||
}
|
||||
if (!parsedParameterSchema.type) parsedParameterSchema.type = 'object';
|
||||
if (!parsedParameterSchema.properties || typeof parsedParameterSchema.properties !== 'object' || Array.isArray(parsedParameterSchema.properties)) {
|
||||
parsedParameterSchema.properties = {};
|
||||
}
|
||||
if (!Array.isArray(parsedParameterSchema.required)) {
|
||||
parsedParameterSchema.required = [];
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(toolParameterDefaultsText || '{}');
|
||||
if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) {
|
||||
throw new Error('defaults must be object');
|
||||
}
|
||||
parsedParameterDefaults = parsed as Record<string, any>;
|
||||
} catch {
|
||||
alert('参数默认值必须是合法 JSON 对象');
|
||||
setSaving(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (toolCategory === 'query') {
|
||||
if (!toolHttpUrl.trim() && editingTool?.id !== 'calculator' && editingTool?.id !== 'code_interpreter') {
|
||||
if (
|
||||
!toolHttpUrl.trim()
|
||||
&& editingTool?.id !== 'calculator'
|
||||
&& editingTool?.id !== 'code_interpreter'
|
||||
&& editingTool?.id !== 'current_time'
|
||||
) {
|
||||
alert('信息查询工具请填写 HTTP URL');
|
||||
setSaving(false);
|
||||
return;
|
||||
@@ -183,6 +239,8 @@ export const ToolLibraryPage: React.FC = () => {
|
||||
httpUrl: toolHttpUrl.trim(),
|
||||
httpHeaders: parsedHeaders,
|
||||
httpTimeoutMs: toolHttpTimeoutMs,
|
||||
parameterSchema: parsedParameterSchema,
|
||||
parameterDefaults: parsedParameterDefaults,
|
||||
enabled: toolEnabled,
|
||||
});
|
||||
setTools((prev) => prev.map((item) => (item.id === updated.id ? updated : item)));
|
||||
@@ -196,6 +254,8 @@ export const ToolLibraryPage: React.FC = () => {
|
||||
httpUrl: toolHttpUrl.trim(),
|
||||
httpHeaders: parsedHeaders,
|
||||
httpTimeoutMs: toolHttpTimeoutMs,
|
||||
parameterSchema: parsedParameterSchema,
|
||||
parameterDefaults: parsedParameterDefaults,
|
||||
enabled: toolEnabled,
|
||||
});
|
||||
setTools((prev) => [created, ...prev]);
|
||||
@@ -368,6 +428,31 @@ export const ToolLibraryPage: React.FC = () => {
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-4 rounded-md border border-white/10 bg-white/5 p-3">
|
||||
<div className="text-[10px] font-black uppercase tracking-widest text-emerald-300">Tool Parameters</div>
|
||||
<div className="space-y-1.5">
|
||||
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block">Schema (JSON Schema)</label>
|
||||
<textarea
|
||||
className="flex min-h-[110px] w-full rounded-md border border-white/10 bg-black/20 px-3 py-2 text-sm shadow-sm placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-primary/50 text-white font-mono"
|
||||
value={toolParameterSchemaText}
|
||||
onChange={(e) => setToolParameterSchemaText(e.target.value)}
|
||||
placeholder={DEFAULT_PARAMETER_SCHEMA_TEXT}
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-1.5">
|
||||
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block">Default Args (JSON)</label>
|
||||
<textarea
|
||||
className="flex min-h-[90px] w-full rounded-md border border-white/10 bg-black/20 px-3 py-2 text-sm shadow-sm placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-primary/50 text-white font-mono"
|
||||
value={toolParameterDefaultsText}
|
||||
onChange={(e) => setToolParameterDefaultsText(e.target.value)}
|
||||
placeholder='{"step": 1}'
|
||||
/>
|
||||
</div>
|
||||
<p className="text-[11px] text-muted-foreground">
|
||||
支持系统指令和信息查询两类工具。Default Args 会在模型未传值时自动补齐。
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{toolCategory === 'query' && (
|
||||
<div className="space-y-4 rounded-md border border-blue-500/20 bg-blue-500/5 p-3">
|
||||
<div className="text-[10px] font-black uppercase tracking-widest text-blue-300">HTTP Request Config</div>
|
||||
|
||||
@@ -123,6 +123,8 @@ const mapTool = (raw: AnyRecord): Tool => ({
|
||||
httpUrl: readField(raw, ['httpUrl', 'http_url'], ''),
|
||||
httpHeaders: readField(raw, ['httpHeaders', 'http_headers'], {}),
|
||||
httpTimeoutMs: Number(readField(raw, ['httpTimeoutMs', 'http_timeout_ms'], 10000)),
|
||||
parameterSchema: readField(raw, ['parameterSchema', 'parameter_schema'], {}),
|
||||
parameterDefaults: readField(raw, ['parameterDefaults', 'parameter_defaults'], {}),
|
||||
isSystem: Boolean(readField(raw, ['isSystem', 'is_system'], false)),
|
||||
enabled: Boolean(readField(raw, ['enabled'], true)),
|
||||
isCustom: !Boolean(readField(raw, ['isSystem', 'is_system'], false)),
|
||||
@@ -567,6 +569,8 @@ export const createTool = async (data: Partial<Tool>): Promise<Tool> => {
|
||||
http_url: data.httpUrl || null,
|
||||
http_headers: data.httpHeaders || {},
|
||||
http_timeout_ms: data.httpTimeoutMs ?? 10000,
|
||||
parameter_schema: data.parameterSchema || {},
|
||||
parameter_defaults: data.parameterDefaults || {},
|
||||
enabled: data.enabled ?? true,
|
||||
};
|
||||
const response = await apiRequest<AnyRecord>('/tools/resources', { method: 'POST', body: payload });
|
||||
@@ -583,6 +587,8 @@ export const updateTool = async (id: string, data: Partial<Tool>): Promise<Tool>
|
||||
http_url: data.httpUrl,
|
||||
http_headers: data.httpHeaders,
|
||||
http_timeout_ms: data.httpTimeoutMs,
|
||||
parameter_schema: data.parameterSchema,
|
||||
parameter_defaults: data.parameterDefaults,
|
||||
enabled: data.enabled,
|
||||
};
|
||||
const response = await apiRequest<AnyRecord>(`/tools/resources/${id}`, { method: 'PUT', body: payload });
|
||||
|
||||
@@ -197,6 +197,8 @@ export interface Tool {
|
||||
httpUrl?: string;
|
||||
httpHeaders?: Record<string, string>;
|
||||
httpTimeoutMs?: number;
|
||||
parameterSchema?: Record<string, any>;
|
||||
parameterDefaults?: Record<string, any>;
|
||||
isCustom?: boolean;
|
||||
isSystem?: boolean;
|
||||
enabled?: boolean;
|
||||
|
||||
Reference in New Issue
Block a user