From 9304927fe977d2ffda7d1f5dd21a47e637507f50 Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Wed, 11 Feb 2026 11:04:05 +0800 Subject: [PATCH] Import tool setting --- api/app/routers/tools.py | 114 ++++++++++++++++++----- engine/core/duplex_pipeline.py | 45 +++++++++ engine/core/tool_executor.py | 141 +++++++++++++++++++++++++++++ engine/tests/test_tool_executor.py | 33 +++++++ web/pages/Assistants.tsx | 38 +++++++- 5 files changed, 344 insertions(+), 27 deletions(-) create mode 100644 engine/tests/test_tool_executor.py diff --git a/api/app/routers/tools.py b/api/app/routers/tools.py index 3602ab6..6e5db6a 100644 --- a/api/app/routers/tools.py +++ b/api/app/routers/tools.py @@ -72,6 +72,15 @@ TOOL_REGISTRY = { "required": ["query"] } }, + "current_time": { + "name": "当前时间", + "description": "获取当前本地时间", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + }, "code_interpreter": { "name": "代码执行", "description": "安全地执行Python代码", @@ -83,6 +92,37 @@ TOOL_REGISTRY = { "required": ["code"] } }, + "take_phone": { + "name": "接听电话", + "description": "执行接听电话命令", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + }, + "increase_volume": { + "name": "调高音量", + "description": "提升设备音量", + "parameters": { + "type": "object", + "properties": { + "step": {"type": "integer", "description": "调整步进,默认1"} + }, + "required": [] + } + }, + "decrease_volume": { + "name": "调低音量", + "description": "降低设备音量", + "parameters": { + "type": "object", + "properties": { + "step": {"type": "integer", "description": "调整步进,默认1"} + }, + "required": [] + } + }, } TOOL_CATEGORY_MAP = { @@ -90,8 +130,12 @@ TOOL_CATEGORY_MAP = { "weather": "query", "translate": "query", "knowledge": "query", - "calculator": "system", - "code_interpreter": "system", + "calculator": "query", + "current_time": "query", + "code_interpreter": "query", + "take_phone": "system", + "increase_volume": "system", + "decrease_volume": "system", } TOOL_ICON_MAP = { @@ -99,27 +143,51 @@ TOOL_ICON_MAP = { "weather": "CloudSun", "translate": "Globe", "knowledge": "Box", + "current_time": "Calendar", "calculator": "Terminal", "code_interpreter": "Terminal", + "take_phone": "Phone", + "increase_volume": "Volume2", + "decrease_volume": "Volume2", } -def _seed_default_tools_if_empty(db: Session) -> None: - """Seed default tools into DB when tool_resources is empty.""" - if db.query(ToolResource).count() > 0: - return - +def _sync_default_tools(db: Session) -> None: + """Ensure built-in tools exist and keep system tool metadata aligned.""" + changed = False for tool_id, payload in TOOL_REGISTRY.items(): - db.add(ToolResource( - id=tool_id, - user_id=1, - name=payload.get("name", tool_id), - description=payload.get("description", ""), - category=TOOL_CATEGORY_MAP.get(tool_id, "system"), - icon=TOOL_ICON_MAP.get(tool_id, "Wrench"), - enabled=True, - is_system=True, - )) - db.commit() + row = db.query(ToolResource).filter(ToolResource.id == tool_id).first() + category = TOOL_CATEGORY_MAP.get(tool_id, "system") + icon = TOOL_ICON_MAP.get(tool_id, "Wrench") + if not row: + db.add(ToolResource( + id=tool_id, + user_id=1, + name=payload.get("name", tool_id), + description=payload.get("description", ""), + category=category, + icon=icon, + enabled=True, + is_system=True, + )) + changed = True + continue + if row.is_system: + new_name = payload.get("name", row.name) + new_description = payload.get("description", row.description) + if row.name != new_name: + row.name = new_name + changed = True + if row.description != new_description: + row.description = new_description + changed = True + if row.category != category: + row.category = category + changed = True + if row.icon != icon: + row.icon = icon + changed = True + if changed: + db.commit() @router.get("/list") @@ -147,7 +215,7 @@ def list_tool_resources( db: Session = Depends(get_db), ): """获取工具资源列表。system/query 仅表示工具执行类型,不代表权限。""" - _seed_default_tools_if_empty(db) + _sync_default_tools(db) query = db.query(ToolResource) if not include_system: query = query.filter(ToolResource.is_system == False) @@ -163,7 +231,7 @@ def list_tool_resources( @router.get("/resources/{id}", response_model=ToolResourceOut) def get_tool_resource(id: str, db: Session = Depends(get_db)): """获取单个工具资源详情。""" - _seed_default_tools_if_empty(db) + _sync_default_tools(db) item = db.query(ToolResource).filter(ToolResource.id == id).first() if not item: raise HTTPException(status_code=404, detail="Tool resource not found") @@ -173,7 +241,7 @@ def get_tool_resource(id: str, db: Session = Depends(get_db)): @router.post("/resources", response_model=ToolResourceOut) def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)): """创建自定义工具资源。""" - _seed_default_tools_if_empty(db) + _sync_default_tools(db) candidate_id = (data.id or "").strip() if candidate_id and db.query(ToolResource).filter(ToolResource.id == candidate_id).first(): raise HTTPException(status_code=400, detail="Tool ID already exists") @@ -197,7 +265,7 @@ def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db) @router.put("/resources/{id}", response_model=ToolResourceOut) def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)): """更新工具资源。""" - _seed_default_tools_if_empty(db) + _sync_default_tools(db) item = db.query(ToolResource).filter(ToolResource.id == id).first() if not item: raise HTTPException(status_code=404, detail="Tool resource not found") @@ -215,7 +283,7 @@ def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depend @router.delete("/resources/{id}") def delete_tool_resource(id: str, db: Session = Depends(get_db)): """删除工具资源。""" - _seed_default_tools_if_empty(db) + _sync_default_tools(db) item = db.query(ToolResource).filter(ToolResource.id == id).first() if not item: raise HTTPException(status_code=404, detail="Tool resource not found") diff --git a/engine/core/duplex_pipeline.py b/engine/core/duplex_pipeline.py index b6b0ab1..74d45ec 100644 --- a/engine/core/duplex_pipeline.py +++ b/engine/core/duplex_pipeline.py @@ -111,6 +111,51 @@ class DuplexPipeline: "required": ["query"], }, }, + "current_time": { + "name": "current_time", + "description": "Get current local time", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + "code_interpreter": { + "name": "code_interpreter", + "description": "Execute Python code in a controlled environment", + "parameters": { + "type": "object", + "properties": {"code": {"type": "string"}}, + "required": ["code"], + }, + }, + "take_phone": { + "name": "take_phone", + "description": "Take or answer a phone call", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + "increase_volume": { + "name": "increase_volume", + "description": "Increase speaker volume", + "parameters": { + "type": "object", + "properties": {"step": {"type": "integer"}}, + "required": [], + }, + }, + "decrease_volume": { + "name": "decrease_volume", + "description": "Decrease speaker volume", + "parameters": { + "type": "object", + "properties": {"step": {"type": "integer"}}, + "required": [], + }, + }, } def __init__( diff --git a/engine/core/tool_executor.py b/engine/core/tool_executor.py index e80c73a..e6fc8f2 100644 --- a/engine/core/tool_executor.py +++ b/engine/core/tool_executor.py @@ -2,6 +2,7 @@ import ast import operator +from datetime import datetime from typing import Any, Dict @@ -18,6 +19,101 @@ _UNARY_OPS = { ast.USub: operator.neg, } +_SAFE_EVAL_FUNCS = { + "abs": abs, + "round": round, + "min": min, + "max": max, + "sum": sum, + "len": len, +} + + +def _validate_safe_expr(node: ast.AST) -> None: + """Allow only a constrained subset of Python expression nodes.""" + if isinstance(node, ast.Expression): + _validate_safe_expr(node.body) + return + + if isinstance(node, ast.Constant): + return + + if isinstance(node, (ast.List, ast.Tuple, ast.Set)): + for elt in node.elts: + _validate_safe_expr(elt) + return + + if isinstance(node, ast.Dict): + for key in node.keys: + if key is not None: + _validate_safe_expr(key) + for value in node.values: + _validate_safe_expr(value) + return + + if isinstance(node, ast.BinOp): + if type(node.op) not in _BIN_OPS: + raise ValueError("unsupported operator") + _validate_safe_expr(node.left) + _validate_safe_expr(node.right) + return + + if isinstance(node, ast.UnaryOp): + if type(node.op) not in _UNARY_OPS: + raise ValueError("unsupported unary operator") + _validate_safe_expr(node.operand) + return + + if isinstance(node, ast.BoolOp): + for value in node.values: + _validate_safe_expr(value) + return + + if isinstance(node, ast.Compare): + _validate_safe_expr(node.left) + for comp in node.comparators: + _validate_safe_expr(comp) + return + + if isinstance(node, ast.Name): + if node.id not in _SAFE_EVAL_FUNCS and node.id not in {"True", "False", "None"}: + raise ValueError("unknown symbol") + return + + if isinstance(node, ast.Call): + if not isinstance(node.func, ast.Name): + raise ValueError("unsafe call target") + if node.func.id not in _SAFE_EVAL_FUNCS: + raise ValueError("function not allowed") + for arg in node.args: + _validate_safe_expr(arg) + for kw in node.keywords: + _validate_safe_expr(kw.value) + return + + # Explicitly reject high-risk nodes (import/attribute/subscript/comprehensions/lambda, etc.) + raise ValueError("unsupported expression") + + +def _safe_eval_python_expr(expression: str) -> Any: + tree = ast.parse(expression, mode="eval") + _validate_safe_expr(tree) + return eval( # noqa: S307 - validated AST + empty builtins + compile(tree, "", "eval"), + {"__builtins__": {}}, + dict(_SAFE_EVAL_FUNCS), + ) + + +def _json_safe(value: Any) -> Any: + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if isinstance(value, (list, tuple)): + return [_json_safe(v) for v in value] + if isinstance(value, dict): + return {str(k): _json_safe(v) for k, v in value.items()} + return repr(value) + def _safe_eval_expr(expression: str) -> float: tree = ast.parse(expression, mode="eval") @@ -110,6 +206,51 @@ async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]: "status": {"code": 422, "message": "invalid_expression"}, } + if tool_name == "current_time": + now = datetime.now().astimezone() + return { + "tool_call_id": call_id, + "name": tool_name, + "output": { + "iso": now.isoformat(), + "local": now.strftime("%Y-%m-%d %H:%M:%S %Z"), + "unix": int(now.timestamp()), + }, + "status": {"code": 200, "message": "ok"}, + } + + if tool_name == "code_interpreter": + code = str(args.get("code") or args.get("expression") or "").strip() + if not code: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"error": "missing code"}, + "status": {"code": 400, "message": "bad_request"}, + } + if len(code) > 500: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"error": "code too long"}, + "status": {"code": 422, "message": "invalid_code"}, + } + try: + result = _safe_eval_python_expr(code) + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"code": code, "result": _json_safe(result)}, + "status": {"code": 200, "message": "ok"}, + } + except Exception as exc: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"code": code, "error": str(exc)}, + "status": {"code": 422, "message": "invalid_code"}, + } + return { "tool_call_id": call_id, "name": tool_name or "unknown_tool", diff --git a/engine/tests/test_tool_executor.py b/engine/tests/test_tool_executor.py new file mode 100644 index 0000000..a6bde28 --- /dev/null +++ b/engine/tests/test_tool_executor.py @@ -0,0 +1,33 @@ +import pytest + +from core.tool_executor import execute_server_tool + + +@pytest.mark.asyncio +async def test_code_interpreter_simple_expression(): + result = await execute_server_tool( + { + "id": "call_ci_ok", + "function": { + "name": "code_interpreter", + "arguments": '{"code":"sum([1, 2, 3]) + 4"}', + }, + } + ) + assert result["status"]["code"] == 200 + assert result["output"]["result"] == 10 + + +@pytest.mark.asyncio +async def test_code_interpreter_blocks_import_and_io(): + result = await execute_server_tool( + { + "id": "call_ci_bad", + "function": { + "name": "code_interpreter", + "arguments": '{"code":"__import__(\\"os\\").system(\\"ls\\")"}', + }, + } + ) + assert result["status"]["code"] == 422 + assert result["status"]["message"] == "invalid_code" diff --git a/web/pages/Assistants.tsx b/web/pages/Assistants.tsx index 58b09ab..4a264f2 100644 --- a/web/pages/Assistants.tsx +++ b/web/pages/Assistants.tsx @@ -42,6 +42,8 @@ const renderToolIcon = (icon: string) => { Images: , CloudSun: , Calendar: , + Phone: , + Volume2: , TrendingUp: , Coins: , Terminal: , @@ -1017,6 +1019,37 @@ const TOOL_PARAMETER_HINTS: Record = { }, required: ['query'], }, + current_time: { + type: 'object', + properties: {}, + required: [], + }, + take_phone: { + type: 'object', + properties: {}, + required: [], + }, + increase_volume: { + type: 'object', + properties: { + step: { type: 'integer', description: 'Volume step, default 1' }, + }, + required: [], + }, + decrease_volume: { + type: 'object', + properties: { + step: { type: 'integer', description: 'Volume step, default 1' }, + }, + required: [], + }, + code_interpreter: { + type: 'object', + properties: { + code: { type: 'string', description: 'Python code' }, + }, + required: ['code'], + }, }; const getDefaultToolParameters = (toolId: string) => @@ -1158,10 +1191,7 @@ export const DebugDrawer: React.FC<{ return ids.map((id) => { const item = byId.get(id); const toolId = item?.id || id; - const isClientTool = - toolId.startsWith('client_') || - toolId.startsWith('browser_') || - ['camera', 'microphone', 'page_control', 'local_file'].includes(toolId); + const isClientTool = (item?.category || 'query') === 'system'; return { type: 'function', executor: isClientTool ? 'client' : 'server',