diff --git a/api/app/models.py b/api/app/models.py index 3aa5bf6..7ceabaf 100644 --- a/api/app/models.py +++ b/api/app/models.py @@ -96,6 +96,8 @@ class ToolResource(Base): http_url: Mapped[Optional[str]] = mapped_column(String(1024), nullable=True) http_headers: Mapped[dict] = mapped_column(JSON, default=dict) http_timeout_ms: Mapped[int] = mapped_column(Integer, default=10000) + parameter_schema: Mapped[dict] = mapped_column(JSON, default=dict) + parameter_defaults: Mapped[dict] = mapped_column(JSON, default=dict) enabled: Mapped[bool] = mapped_column(default=True) is_system: Mapped[bool] = mapped_column(default=False) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) diff --git a/api/app/routers/assistants.py b/api/app/routers/assistants.py index fcb1932..66bb958 100644 --- a/api/app/routers/assistants.py +++ b/api/app/routers/assistants.py @@ -13,7 +13,7 @@ import uuid from datetime import datetime from ..db import get_db -from ..models import Assistant, AssistantOpenerAudio, LLMModel, ASRModel, Voice +from ..models import Assistant, AssistantOpenerAudio, LLMModel, ASRModel, Voice, ToolResource from ..schemas import ( AssistantCreate, AssistantUpdate, @@ -22,6 +22,7 @@ from ..schemas import ( AssistantOpenerAudioGenerateRequest, AssistantOpenerAudioOut, ) +from .tools import TOOL_REGISTRY, TOOL_CATEGORY_MAP, TOOL_PARAMETER_DEFAULTS, _ensure_tool_resource_schema router = APIRouter(prefix="/assistants", tags=["Assistants"]) @@ -78,7 +79,75 @@ def _config_version_id(assistant: Assistant) -> str: return f"asst_{assistant.id}_{updated.strftime('%Y%m%d%H%M%S')}" +def _normalize_runtime_tool_schema(tool_id: str, raw_schema: Any) -> Dict[str, Any]: + schema = dict(raw_schema) if isinstance(raw_schema, dict) else {} + if not schema: + fallback = TOOL_REGISTRY.get(tool_id, {}).get("parameters") + if isinstance(fallback, dict): + schema = dict(fallback) + schema.setdefault("type", "object") + if not isinstance(schema.get("properties"), dict): + schema["properties"] = {} + required = schema.get("required") + if required is None or not isinstance(required, list): + schema["required"] = [] + return schema + + +def _resolve_runtime_tools(db: Session, selected_tool_ids: List[str], warnings: List[str]) -> List[Dict[str, Any]]: + _ensure_tool_resource_schema(db) + ids = [str(tool_id).strip() for tool_id in selected_tool_ids if str(tool_id).strip()] + if not ids: + return [] + + resources = ( + db.query(ToolResource) + .filter(ToolResource.id.in_(ids)) + .all() + ) + by_id = {str(item.id): item for item in resources} + + runtime_tools: List[Dict[str, Any]] = [] + for tool_id in ids: + resource = by_id.get(tool_id) + if resource and resource.enabled is False: + warnings.append(f"Tool is disabled and skipped in runtime config: {tool_id}") + continue + + category = str(resource.category if resource else TOOL_CATEGORY_MAP.get(tool_id, "query")) + description = ( + str(resource.description or resource.name or "").strip() + if resource + else str(TOOL_REGISTRY.get(tool_id, {}).get("description") or "").strip() + ) + schema = _normalize_runtime_tool_schema( + tool_id, + resource.parameter_schema if resource else TOOL_REGISTRY.get(tool_id, {}).get("parameters"), + ) + defaults_raw = resource.parameter_defaults if resource else TOOL_PARAMETER_DEFAULTS.get(tool_id) + defaults = dict(defaults_raw) if isinstance(defaults_raw, dict) else {} + + if not resource and tool_id not in TOOL_REGISTRY: + warnings.append(f"Tool resource not found: {tool_id}") + + runtime_tool: Dict[str, Any] = { + "type": "function", + "executor": "client" if category == "system" else "server", + "function": { + "name": tool_id, + "description": description or tool_id, + "parameters": schema, + }, + } + if defaults: + runtime_tool["defaultArgs"] = defaults + runtime_tools.append(runtime_tool) + + return runtime_tools + + def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[str, Any], List[str]]: + warnings: List[str] = [] metadata: Dict[str, Any] = { "systemPrompt": assistant.prompt or "", "firstTurnMode": assistant.first_turn_mode or "bot_first", @@ -90,14 +159,13 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[s "minDurationMs": int(assistant.interruption_sensitivity or 500), }, "services": {}, - "tools": assistant.tools or [], + "tools": _resolve_runtime_tools(db, assistant.tools or [], warnings), "history": { "assistantId": assistant.id, "userId": int(assistant.user_id or 1), "source": "debug", }, } - warnings: List[str] = [] config_mode = str(assistant.config_mode or "platform").strip().lower() diff --git a/api/app/routers/tools.py b/api/app/routers/tools.py index 0af09e0..eab3f91 100644 --- a/api/app/routers/tools.py +++ b/api/app/routers/tools.py @@ -1,5 +1,6 @@ from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session +from sqlalchemy import inspect, text from typing import Optional, Dict, Any, List import time import uuid @@ -111,6 +112,61 @@ TOOL_ICON_MAP = { TOOL_HTTP_DEFAULTS = { } +TOOL_PARAMETER_DEFAULTS = { + "increase_volume": {"step": 1}, + "decrease_volume": {"step": 1}, +} + + +def _normalize_parameter_schema(value: Any, *, tool_id: Optional[str] = None) -> Dict[str, Any]: + if not isinstance(value, dict): + value = {} + normalized = dict(value) + if not normalized: + fallback = TOOL_REGISTRY.get(str(tool_id or "").strip(), {}).get("parameters") + if isinstance(fallback, dict): + normalized = dict(fallback) + normalized.setdefault("type", "object") + if normalized.get("type") != "object": + raise HTTPException(status_code=400, detail="parameter_schema.type must be 'object'") + properties = normalized.get("properties") + if not isinstance(properties, dict): + normalized["properties"] = {} + required = normalized.get("required") + if required is None: + normalized["required"] = [] + elif not isinstance(required, list): + raise HTTPException(status_code=400, detail="parameter_schema.required must be an array") + return normalized + + +def _normalize_parameter_defaults(value: Any) -> Dict[str, Any]: + if value is None: + return {} + if not isinstance(value, dict): + raise HTTPException(status_code=400, detail="parameter_defaults must be an object") + return dict(value) + + +def _ensure_tool_resource_schema(db: Session) -> None: + """Apply lightweight SQLite migrations for newly added tool_resources columns.""" + bind = db.get_bind() + inspector = inspect(bind) + try: + columns = {col["name"] for col in inspector.get_columns("tool_resources")} + except Exception: + return + + altered = False + if "parameter_schema" not in columns: + db.execute(text("ALTER TABLE tool_resources ADD COLUMN parameter_schema JSON")) + altered = True + if "parameter_defaults" not in columns: + db.execute(text("ALTER TABLE tool_resources ADD COLUMN parameter_defaults JSON")) + altered = True + if altered: + db.commit() + def _normalize_http_method(method: Optional[str]) -> str: normalized = str(method or "GET").strip().upper() @@ -127,8 +183,10 @@ def _validate_query_http_config(*, category: str, tool_id: Optional[str], http_u if _requires_http_request(category, tool_id) and not str(http_url or "").strip(): raise HTTPException(status_code=400, detail="http_url is required for query tools (except calculator/code_interpreter)") + def _seed_default_tools_if_empty(db: Session) -> None: """Seed built-in tools only when tool_resources is empty.""" + _ensure_tool_resource_schema(db) if db.query(ToolResource).count() > 0: return for tool_id, payload in TOOL_REGISTRY.items(): @@ -144,6 +202,8 @@ def _seed_default_tools_if_empty(db: Session) -> None: http_url=http_defaults.get("http_url"), http_headers=http_defaults.get("http_headers") or {}, http_timeout_ms=int(http_defaults.get("http_timeout_ms") or 10000), + parameter_schema=_normalize_parameter_schema(payload.get("parameters"), tool_id=tool_id), + parameter_defaults=_normalize_parameter_defaults(TOOL_PARAMETER_DEFAULTS.get(tool_id)), enabled=True, is_system=True, )) @@ -215,6 +275,8 @@ def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db) raise HTTPException(status_code=400, detail="Tool ID already exists") _validate_query_http_config(category=data.category, tool_id=candidate_id, http_url=data.http_url) + parameter_schema = _normalize_parameter_schema(data.parameter_schema, tool_id=candidate_id) + parameter_defaults = _normalize_parameter_defaults(data.parameter_defaults) item = ToolResource( id=candidate_id or f"tool_{str(uuid.uuid4())[:8]}", @@ -227,6 +289,8 @@ def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db) http_url=(data.http_url or "").strip() or None, http_headers=data.http_headers or {}, http_timeout_ms=max(1000, int(data.http_timeout_ms or 10000)), + parameter_schema=parameter_schema, + parameter_defaults=parameter_defaults, enabled=data.enabled, is_system=False, ) @@ -254,6 +318,10 @@ def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depend update_data["http_method"] = _normalize_http_method(update_data.get("http_method")) if "http_timeout_ms" in update_data and update_data.get("http_timeout_ms") is not None: update_data["http_timeout_ms"] = max(1000, int(update_data["http_timeout_ms"])) + if "parameter_schema" in update_data: + update_data["parameter_schema"] = _normalize_parameter_schema(update_data.get("parameter_schema"), tool_id=id) + if "parameter_defaults" in update_data: + update_data["parameter_defaults"] = _normalize_parameter_defaults(update_data.get("parameter_defaults")) for field, value in update_data.items(): setattr(item, field, value) diff --git a/api/app/schemas.py b/api/app/schemas.py index f81efc8..dbdb806 100644 --- a/api/app/schemas.py +++ b/api/app/schemas.py @@ -239,6 +239,8 @@ class ToolResourceBase(BaseModel): http_url: Optional[str] = None http_headers: Dict[str, str] = Field(default_factory=dict) http_timeout_ms: int = 10000 + parameter_schema: Dict[str, Any] = Field(default_factory=dict) + parameter_defaults: Dict[str, Any] = Field(default_factory=dict) enabled: bool = True @@ -255,6 +257,8 @@ class ToolResourceUpdate(BaseModel): http_url: Optional[str] = None http_headers: Optional[Dict[str, str]] = None http_timeout_ms: Optional[int] = None + parameter_schema: Optional[Dict[str, Any]] = None + parameter_defaults: Optional[Dict[str, Any]] = None enabled: Optional[bool] = None diff --git a/api/tests/test_assistants.py b/api/tests/test_assistants.py index 6b2d173..d1fc52d 100644 --- a/api/tests/test_assistants.py +++ b/api/tests/test_assistants.py @@ -243,6 +243,26 @@ class TestAssistantAPI: assert payload["sessionStartMetadata"]["systemPrompt"] == sample_assistant_data["prompt"] assert payload["sessionStartMetadata"]["history"]["assistantId"] == assistant_id + def test_runtime_config_resolves_selected_tools_into_runtime_definitions(self, client, sample_assistant_data): + sample_assistant_data["tools"] = ["increase_volume", "calculator"] + assistant_resp = client.post("/api/assistants", json=sample_assistant_data) + assert assistant_resp.status_code == 200 + assistant_id = assistant_resp.json()["id"] + + runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config") + assert runtime_resp.status_code == 200 + metadata = runtime_resp.json()["sessionStartMetadata"] + tools = metadata["tools"] + assert isinstance(tools, list) + assert len(tools) == 2 + + by_name = {item["function"]["name"]: item for item in tools} + assert by_name["increase_volume"]["executor"] == "client" + assert by_name["increase_volume"]["defaultArgs"]["step"] == 1 + assert by_name["calculator"]["executor"] == "server" + assert by_name["calculator"]["function"]["parameters"]["type"] == "object" + assert "expression" in by_name["calculator"]["function"]["parameters"]["properties"] + def test_runtime_config_text_mode_when_voice_output_disabled(self, client, sample_assistant_data): sample_assistant_data["voiceOutputEnabled"] = False assistant_resp = client.post("/api/assistants", json=sample_assistant_data) diff --git a/api/tests/test_tools.py b/api/tests/test_tools.py index c4b07d9..daa1e85 100644 --- a/api/tests/test_tools.py +++ b/api/tests/test_tools.py @@ -281,6 +281,9 @@ class TestToolResourceCRUD: assert payload["total"] >= 1 ids = [item["id"] for item in payload["list"]] assert "calculator" in ids + calculator = next((item for item in payload["list"] if item["id"] == "calculator"), None) + assert calculator is not None + assert calculator["parameter_schema"]["type"] == "object" def test_create_update_delete_tool_resource(self, client): create_resp = client.post("/api/tools/resources", json={ @@ -292,6 +295,12 @@ class TestToolResourceCRUD: "http_url": "https://example.com/search", "http_headers": {}, "http_timeout_ms": 10000, + "parameter_schema": { + "type": "object", + "properties": {"keyword": {"type": "string"}}, + "required": ["keyword"] + }, + "parameter_defaults": {"limit": 10}, "enabled": True, }) assert create_resp.status_code == 200 @@ -299,15 +308,19 @@ class TestToolResourceCRUD: tool_id = created["id"] assert created["name"] == "自定义网页抓取" assert created["is_system"] is False + assert created["parameter_schema"]["required"] == ["keyword"] + assert created["parameter_defaults"]["limit"] == 10 update_resp = client.put(f"/api/tools/resources/{tool_id}", json={ "name": "自定义网页检索", "category": "system", + "parameter_defaults": {"limit": 20}, }) assert update_resp.status_code == 200 updated = update_resp.json() assert updated["name"] == "自定义网页检索" assert updated["category"] == "system" + assert updated["parameter_defaults"]["limit"] == 20 get_resp = client.get(f"/api/tools/resources/{tool_id}") assert get_resp.status_code == 200 diff --git a/engine/core/duplex_pipeline.py b/engine/core/duplex_pipeline.py index e879012..1f2485b 100644 --- a/engine/core/duplex_pipeline.py +++ b/engine/core/duplex_pipeline.py @@ -287,6 +287,7 @@ class DuplexPipeline: raw_default_tools = settings.tools if isinstance(settings.tools, list) else [] self._runtime_tools: List[Any] = list(raw_default_tools) self._runtime_tool_executor: Dict[str, str] = {} + self._runtime_tool_default_args: Dict[str, Dict[str, Any]] = {} self._pending_tool_waiters: Dict[str, asyncio.Future] = {} self._early_tool_results: Dict[str, Dict[str, Any]] = {} self._completed_tool_call_ids: set[str] = set() @@ -307,6 +308,7 @@ class DuplexPipeline: self._last_llm_delta_emit_ms: float = 0.0 self._runtime_tool_executor = self._resolved_tool_executor_map() + self._runtime_tool_default_args = self._resolved_tool_default_args_map() self._initial_greeting_emitted = False if self._server_tool_executor is None: @@ -408,9 +410,11 @@ class DuplexPipeline: if isinstance(tools_payload, list): self._runtime_tools = tools_payload self._runtime_tool_executor = self._resolved_tool_executor_map() + self._runtime_tool_default_args = self._resolved_tool_default_args_map() elif "tools" in metadata: self._runtime_tools = [] self._runtime_tool_executor = {} + self._runtime_tool_default_args = {} if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"): self.llm_service.set_knowledge_config(self._resolved_knowledge_config()) @@ -1473,6 +1477,25 @@ class DuplexPipeline: result[name] = executor return result + def _resolved_tool_default_args_map(self) -> Dict[str, Dict[str, Any]]: + result: Dict[str, Dict[str, Any]] = {} + for item in self._runtime_tools: + if not isinstance(item, dict): + continue + fn = item.get("function") + if isinstance(fn, dict) and fn.get("name"): + name = str(fn.get("name")).strip() + else: + name = str(item.get("name") or "").strip() + if not name: + continue + raw_defaults = item.get("defaultArgs") + if raw_defaults is None: + raw_defaults = item.get("default_args") + if isinstance(raw_defaults, dict): + result[name] = dict(raw_defaults) + return result + def _resolved_tool_allowlist(self) -> List[str]: names: set[str] = set() for item in self._runtime_tools: @@ -1518,6 +1541,15 @@ class DuplexPipeline: return {"raw": raw} return {} + def _apply_tool_default_args(self, tool_name: str, args: Dict[str, Any]) -> Dict[str, Any]: + defaults = self._runtime_tool_default_args.get(tool_name) + if not isinstance(defaults, dict) or not defaults: + return args + merged = dict(defaults) + if isinstance(args, dict): + merged.update(args) + return merged + def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]: status = result.get("status") if isinstance(result.get("status"), dict) else {} status_code = int(status.get("code") or 0) if status else 0 @@ -1702,15 +1734,27 @@ class DuplexPipeline: enriched_tool_call["executor"] = executor tool_name = self._tool_name(enriched_tool_call) or "unknown_tool" call_id = str(enriched_tool_call.get("id") or "").strip() - fn_payload = enriched_tool_call.get("function") + fn_payload = ( + dict(enriched_tool_call.get("function")) + if isinstance(enriched_tool_call.get("function"), dict) + else None + ) raw_args = str(fn_payload.get("arguments") or "") if isinstance(fn_payload, dict) else "" + tool_arguments = self._tool_arguments(enriched_tool_call) + merged_tool_arguments = self._apply_tool_default_args(tool_name, tool_arguments) + try: + merged_args_text = json.dumps(merged_tool_arguments, ensure_ascii=False) + except Exception: + merged_args_text = raw_args if raw_args else "{}" + if isinstance(fn_payload, dict): + fn_payload["arguments"] = merged_args_text + enriched_tool_call["function"] = fn_payload args_preview = raw_args if len(raw_args) <= 160 else f"{raw_args[:160]}..." logger.info( f"[Tool] call requested name={tool_name} call_id={call_id} " - f"executor={executor} args={args_preview}" + f"executor={executor} args={args_preview} merged_args={merged_args_text}" ) tool_calls.append(enriched_tool_call) - tool_arguments = self._tool_arguments(enriched_tool_call) if executor == "client" and call_id: self._pending_client_tool_call_ids.add(call_id) await self._send_event( diff --git a/engine/core/tool_executor.py b/engine/core/tool_executor.py index 4505436..899d930 100644 --- a/engine/core/tool_executor.py +++ b/engine/core/tool_executor.py @@ -187,6 +187,17 @@ async def execute_server_tool( tool_name = _extract_tool_name(tool_call) args = _extract_tool_args(tool_call) resource_fetcher = tool_resource_fetcher or fetch_tool_resource + resource: Optional[Dict[str, Any]] = None + if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}: + try: + resource = await resource_fetcher(tool_name) + except Exception: + resource = None + defaults = resource.get("parameter_defaults") if isinstance(resource, dict) else None + if isinstance(defaults, dict) and defaults: + merged_args = dict(defaults) + merged_args.update(args) + args = merged_args if tool_name == "calculator": expression = str(args.get("expression") or "").strip() @@ -269,7 +280,6 @@ async def execute_server_tool( } if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}: - resource = await resource_fetcher(tool_name) if resource and str(resource.get("category") or "") == "query": method = str(resource.get("http_method") or "GET").strip().upper() if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}: diff --git a/engine/tests/test_tool_call_flow.py b/engine/tests/test_tool_call_flow.py index 8e44815..32810a2 100644 --- a/engine/tests/test_tool_call_flow.py +++ b/engine/tests/test_tool_call_flow.py @@ -1,4 +1,5 @@ import asyncio +import json from typing import Any, Dict, List import pytest @@ -143,6 +144,64 @@ def test_pipeline_assigns_default_client_executor_for_system_string_tools(monkey assert pipeline._tool_executor(tool_call) == "client" +@pytest.mark.asyncio +async def test_pipeline_applies_default_args_to_tool_call(monkeypatch): + pipeline, _events = _build_pipeline( + monkeypatch, + [ + [ + LLMStreamEvent( + type="tool_call", + tool_call={ + "id": "call_defaults", + "type": "function", + "function": {"name": "weather", "arguments": "{}"}, + }, + ), + LLMStreamEvent(type="done"), + ], + [LLMStreamEvent(type="done")], + ], + ) + pipeline.apply_runtime_overrides( + { + "tools": [ + { + "type": "function", + "executor": "server", + "defaultArgs": {"city": "Hangzhou", "unit": "c"}, + "function": { + "name": "weather", + "description": "Get weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + }, + } + ] + } + ) + + captured: Dict[str, Any] = {} + + async def _server_exec(call: Dict[str, Any]) -> Dict[str, Any]: + captured["call"] = call + return { + "tool_call_id": str(call.get("id") or ""), + "name": "weather", + "output": {"ok": True}, + "status": {"code": 200, "message": "ok"}, + } + + monkeypatch.setattr(pipeline, "_server_tool_executor", _server_exec) + await pipeline._handle_turn("weather?") + + sent_call = captured.get("call") + assert isinstance(sent_call, dict) + args_raw = sent_call.get("function", {}).get("arguments") + args = json.loads(args_raw) if isinstance(args_raw, str) else {} + assert args.get("city") == "Hangzhou" + assert args.get("unit") == "c" + + @pytest.mark.asyncio async def test_ws_message_parses_tool_call_results(): msg = parse_client_message( diff --git a/web/pages/Assistants.tsx b/web/pages/Assistants.tsx index ea72b26..90fdccc 100644 --- a/web/pages/Assistants.tsx +++ b/web/pages/Assistants.tsx @@ -2051,13 +2051,20 @@ export const DebugDrawer: React.FC<{ const item = byId.get(id); const toolId = item?.id || id; const isClientTool = (item?.category || 'query') === 'system'; + const parameterSchema = (item?.parameterSchema && typeof item.parameterSchema === 'object') + ? item.parameterSchema + : getDefaultToolParameters(toolId); + const parameterDefaults = (item?.parameterDefaults && typeof item.parameterDefaults === 'object') + ? item.parameterDefaults + : undefined; return { type: 'function', executor: isClientTool ? 'client' : 'server', + ...(parameterDefaults && Object.keys(parameterDefaults).length > 0 ? { defaultArgs: parameterDefaults } : {}), function: { name: toolId, description: item?.description || item?.name || id, - parameters: getDefaultToolParameters(toolId), + parameters: parameterSchema, }, }; }); diff --git a/web/pages/ToolLibrary.tsx b/web/pages/ToolLibrary.tsx index f3d27cf..4a07f06 100644 --- a/web/pages/ToolLibrary.tsx +++ b/web/pages/ToolLibrary.tsx @@ -20,6 +20,12 @@ const iconMap: Record = { Volume2: , }; +const DEFAULT_PARAMETER_SCHEMA_TEXT = JSON.stringify( + { type: 'object', properties: {}, required: [] }, + null, + 2 +); + export const ToolLibraryPage: React.FC = () => { const [tools, setTools] = useState([]); const [searchTerm, setSearchTerm] = useState(''); @@ -37,6 +43,8 @@ export const ToolLibraryPage: React.FC = () => { const [toolHttpUrl, setToolHttpUrl] = useState(''); const [toolHttpHeadersText, setToolHttpHeadersText] = useState('{}'); const [toolHttpTimeoutMs, setToolHttpTimeoutMs] = useState(10000); + const [toolParameterSchemaText, setToolParameterSchemaText] = useState(DEFAULT_PARAMETER_SCHEMA_TEXT); + const [toolParameterDefaultsText, setToolParameterDefaultsText] = useState('{}'); const [saving, setSaving] = useState(false); const loadTools = async () => { @@ -66,6 +74,8 @@ export const ToolLibraryPage: React.FC = () => { setToolHttpUrl(''); setToolHttpHeadersText('{}'); setToolHttpTimeoutMs(10000); + setToolParameterSchemaText(DEFAULT_PARAMETER_SCHEMA_TEXT); + setToolParameterDefaultsText('{}'); setIsToolModalOpen(true); }; @@ -80,6 +90,8 @@ export const ToolLibraryPage: React.FC = () => { setToolHttpUrl(tool.httpUrl || ''); setToolHttpHeadersText(JSON.stringify(tool.httpHeaders || {}, null, 2)); setToolHttpTimeoutMs(tool.httpTimeoutMs || 10000); + setToolParameterSchemaText(JSON.stringify(tool.parameterSchema || { type: 'object', properties: {}, required: [] }, null, 2)); + setToolParameterDefaultsText(JSON.stringify(tool.parameterDefaults || {}, null, 2)); setIsToolModalOpen(true); }; @@ -153,8 +165,52 @@ export const ToolLibraryPage: React.FC = () => { try { setSaving(true); let parsedHeaders: Record = {}; + let parsedParameterSchema: Record = {}; + let parsedParameterDefaults: Record = {}; + + try { + const parsed = JSON.parse(toolParameterSchemaText || '{}'); + if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) { + throw new Error('schema must be object'); + } + parsedParameterSchema = parsed as Record; + } catch { + alert('参数 Schema 必须是合法 JSON 对象'); + setSaving(false); + return; + } + if (parsedParameterSchema.type && parsedParameterSchema.type !== 'object') { + alert("参数 Schema 的 type 必须是 'object'"); + setSaving(false); + return; + } + if (!parsedParameterSchema.type) parsedParameterSchema.type = 'object'; + if (!parsedParameterSchema.properties || typeof parsedParameterSchema.properties !== 'object' || Array.isArray(parsedParameterSchema.properties)) { + parsedParameterSchema.properties = {}; + } + if (!Array.isArray(parsedParameterSchema.required)) { + parsedParameterSchema.required = []; + } + + try { + const parsed = JSON.parse(toolParameterDefaultsText || '{}'); + if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) { + throw new Error('defaults must be object'); + } + parsedParameterDefaults = parsed as Record; + } catch { + alert('参数默认值必须是合法 JSON 对象'); + setSaving(false); + return; + } + if (toolCategory === 'query') { - if (!toolHttpUrl.trim() && editingTool?.id !== 'calculator' && editingTool?.id !== 'code_interpreter') { + if ( + !toolHttpUrl.trim() + && editingTool?.id !== 'calculator' + && editingTool?.id !== 'code_interpreter' + && editingTool?.id !== 'current_time' + ) { alert('信息查询工具请填写 HTTP URL'); setSaving(false); return; @@ -183,6 +239,8 @@ export const ToolLibraryPage: React.FC = () => { httpUrl: toolHttpUrl.trim(), httpHeaders: parsedHeaders, httpTimeoutMs: toolHttpTimeoutMs, + parameterSchema: parsedParameterSchema, + parameterDefaults: parsedParameterDefaults, enabled: toolEnabled, }); setTools((prev) => prev.map((item) => (item.id === updated.id ? updated : item))); @@ -196,6 +254,8 @@ export const ToolLibraryPage: React.FC = () => { httpUrl: toolHttpUrl.trim(), httpHeaders: parsedHeaders, httpTimeoutMs: toolHttpTimeoutMs, + parameterSchema: parsedParameterSchema, + parameterDefaults: parsedParameterDefaults, enabled: toolEnabled, }); setTools((prev) => [created, ...prev]); @@ -368,6 +428,31 @@ export const ToolLibraryPage: React.FC = () => { /> +
+
Tool Parameters
+
+ +