Import tool setting

This commit is contained in:
Xin Wang
2026-02-11 11:04:05 +08:00
parent 180a69ca67
commit 9304927fe9
5 changed files with 344 additions and 27 deletions

View File

@@ -72,6 +72,15 @@ TOOL_REGISTRY = {
"required": ["query"] "required": ["query"]
} }
}, },
"current_time": {
"name": "当前时间",
"description": "获取当前本地时间",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
},
"code_interpreter": { "code_interpreter": {
"name": "代码执行", "name": "代码执行",
"description": "安全地执行Python代码", "description": "安全地执行Python代码",
@@ -83,6 +92,37 @@ TOOL_REGISTRY = {
"required": ["code"] "required": ["code"]
} }
}, },
"take_phone": {
"name": "接听电话",
"description": "执行接听电话命令",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
},
"increase_volume": {
"name": "调高音量",
"description": "提升设备音量",
"parameters": {
"type": "object",
"properties": {
"step": {"type": "integer", "description": "调整步进默认1"}
},
"required": []
}
},
"decrease_volume": {
"name": "调低音量",
"description": "降低设备音量",
"parameters": {
"type": "object",
"properties": {
"step": {"type": "integer", "description": "调整步进默认1"}
},
"required": []
}
},
} }
TOOL_CATEGORY_MAP = { TOOL_CATEGORY_MAP = {
@@ -90,8 +130,12 @@ TOOL_CATEGORY_MAP = {
"weather": "query", "weather": "query",
"translate": "query", "translate": "query",
"knowledge": "query", "knowledge": "query",
"calculator": "system", "calculator": "query",
"code_interpreter": "system", "current_time": "query",
"code_interpreter": "query",
"take_phone": "system",
"increase_volume": "system",
"decrease_volume": "system",
} }
TOOL_ICON_MAP = { TOOL_ICON_MAP = {
@@ -99,26 +143,50 @@ TOOL_ICON_MAP = {
"weather": "CloudSun", "weather": "CloudSun",
"translate": "Globe", "translate": "Globe",
"knowledge": "Box", "knowledge": "Box",
"current_time": "Calendar",
"calculator": "Terminal", "calculator": "Terminal",
"code_interpreter": "Terminal", "code_interpreter": "Terminal",
"take_phone": "Phone",
"increase_volume": "Volume2",
"decrease_volume": "Volume2",
} }
def _seed_default_tools_if_empty(db: Session) -> None: def _sync_default_tools(db: Session) -> None:
"""Seed default tools into DB when tool_resources is empty.""" """Ensure built-in tools exist and keep system tool metadata aligned."""
if db.query(ToolResource).count() > 0: changed = False
return
for tool_id, payload in TOOL_REGISTRY.items(): for tool_id, payload in TOOL_REGISTRY.items():
row = db.query(ToolResource).filter(ToolResource.id == tool_id).first()
category = TOOL_CATEGORY_MAP.get(tool_id, "system")
icon = TOOL_ICON_MAP.get(tool_id, "Wrench")
if not row:
db.add(ToolResource( db.add(ToolResource(
id=tool_id, id=tool_id,
user_id=1, user_id=1,
name=payload.get("name", tool_id), name=payload.get("name", tool_id),
description=payload.get("description", ""), description=payload.get("description", ""),
category=TOOL_CATEGORY_MAP.get(tool_id, "system"), category=category,
icon=TOOL_ICON_MAP.get(tool_id, "Wrench"), icon=icon,
enabled=True, enabled=True,
is_system=True, is_system=True,
)) ))
changed = True
continue
if row.is_system:
new_name = payload.get("name", row.name)
new_description = payload.get("description", row.description)
if row.name != new_name:
row.name = new_name
changed = True
if row.description != new_description:
row.description = new_description
changed = True
if row.category != category:
row.category = category
changed = True
if row.icon != icon:
row.icon = icon
changed = True
if changed:
db.commit() db.commit()
@@ -147,7 +215,7 @@ def list_tool_resources(
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""获取工具资源列表。system/query 仅表示工具执行类型,不代表权限。""" """获取工具资源列表。system/query 仅表示工具执行类型,不代表权限。"""
_seed_default_tools_if_empty(db) _sync_default_tools(db)
query = db.query(ToolResource) query = db.query(ToolResource)
if not include_system: if not include_system:
query = query.filter(ToolResource.is_system == False) query = query.filter(ToolResource.is_system == False)
@@ -163,7 +231,7 @@ def list_tool_resources(
@router.get("/resources/{id}", response_model=ToolResourceOut) @router.get("/resources/{id}", response_model=ToolResourceOut)
def get_tool_resource(id: str, db: Session = Depends(get_db)): def get_tool_resource(id: str, db: Session = Depends(get_db)):
"""获取单个工具资源详情。""" """获取单个工具资源详情。"""
_seed_default_tools_if_empty(db) _sync_default_tools(db)
item = db.query(ToolResource).filter(ToolResource.id == id).first() item = db.query(ToolResource).filter(ToolResource.id == id).first()
if not item: if not item:
raise HTTPException(status_code=404, detail="Tool resource not found") raise HTTPException(status_code=404, detail="Tool resource not found")
@@ -173,7 +241,7 @@ def get_tool_resource(id: str, db: Session = Depends(get_db)):
@router.post("/resources", response_model=ToolResourceOut) @router.post("/resources", response_model=ToolResourceOut)
def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)): def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)):
"""创建自定义工具资源。""" """创建自定义工具资源。"""
_seed_default_tools_if_empty(db) _sync_default_tools(db)
candidate_id = (data.id or "").strip() candidate_id = (data.id or "").strip()
if candidate_id and db.query(ToolResource).filter(ToolResource.id == candidate_id).first(): if candidate_id and db.query(ToolResource).filter(ToolResource.id == candidate_id).first():
raise HTTPException(status_code=400, detail="Tool ID already exists") raise HTTPException(status_code=400, detail="Tool ID already exists")
@@ -197,7 +265,7 @@ def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)
@router.put("/resources/{id}", response_model=ToolResourceOut) @router.put("/resources/{id}", response_model=ToolResourceOut)
def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)): def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)):
"""更新工具资源。""" """更新工具资源。"""
_seed_default_tools_if_empty(db) _sync_default_tools(db)
item = db.query(ToolResource).filter(ToolResource.id == id).first() item = db.query(ToolResource).filter(ToolResource.id == id).first()
if not item: if not item:
raise HTTPException(status_code=404, detail="Tool resource not found") raise HTTPException(status_code=404, detail="Tool resource not found")
@@ -215,7 +283,7 @@ def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depend
@router.delete("/resources/{id}") @router.delete("/resources/{id}")
def delete_tool_resource(id: str, db: Session = Depends(get_db)): def delete_tool_resource(id: str, db: Session = Depends(get_db)):
"""删除工具资源。""" """删除工具资源。"""
_seed_default_tools_if_empty(db) _sync_default_tools(db)
item = db.query(ToolResource).filter(ToolResource.id == id).first() item = db.query(ToolResource).filter(ToolResource.id == id).first()
if not item: if not item:
raise HTTPException(status_code=404, detail="Tool resource not found") raise HTTPException(status_code=404, detail="Tool resource not found")

View File

@@ -111,6 +111,51 @@ class DuplexPipeline:
"required": ["query"], "required": ["query"],
}, },
}, },
"current_time": {
"name": "current_time",
"description": "Get current local time",
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
"code_interpreter": {
"name": "code_interpreter",
"description": "Execute Python code in a controlled environment",
"parameters": {
"type": "object",
"properties": {"code": {"type": "string"}},
"required": ["code"],
},
},
"take_phone": {
"name": "take_phone",
"description": "Take or answer a phone call",
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
"increase_volume": {
"name": "increase_volume",
"description": "Increase speaker volume",
"parameters": {
"type": "object",
"properties": {"step": {"type": "integer"}},
"required": [],
},
},
"decrease_volume": {
"name": "decrease_volume",
"description": "Decrease speaker volume",
"parameters": {
"type": "object",
"properties": {"step": {"type": "integer"}},
"required": [],
},
},
} }
def __init__( def __init__(

View File

@@ -2,6 +2,7 @@
import ast import ast
import operator import operator
from datetime import datetime
from typing import Any, Dict from typing import Any, Dict
@@ -18,6 +19,101 @@ _UNARY_OPS = {
ast.USub: operator.neg, ast.USub: operator.neg,
} }
_SAFE_EVAL_FUNCS = {
"abs": abs,
"round": round,
"min": min,
"max": max,
"sum": sum,
"len": len,
}
def _validate_safe_expr(node: ast.AST) -> None:
"""Allow only a constrained subset of Python expression nodes."""
if isinstance(node, ast.Expression):
_validate_safe_expr(node.body)
return
if isinstance(node, ast.Constant):
return
if isinstance(node, (ast.List, ast.Tuple, ast.Set)):
for elt in node.elts:
_validate_safe_expr(elt)
return
if isinstance(node, ast.Dict):
for key in node.keys:
if key is not None:
_validate_safe_expr(key)
for value in node.values:
_validate_safe_expr(value)
return
if isinstance(node, ast.BinOp):
if type(node.op) not in _BIN_OPS:
raise ValueError("unsupported operator")
_validate_safe_expr(node.left)
_validate_safe_expr(node.right)
return
if isinstance(node, ast.UnaryOp):
if type(node.op) not in _UNARY_OPS:
raise ValueError("unsupported unary operator")
_validate_safe_expr(node.operand)
return
if isinstance(node, ast.BoolOp):
for value in node.values:
_validate_safe_expr(value)
return
if isinstance(node, ast.Compare):
_validate_safe_expr(node.left)
for comp in node.comparators:
_validate_safe_expr(comp)
return
if isinstance(node, ast.Name):
if node.id not in _SAFE_EVAL_FUNCS and node.id not in {"True", "False", "None"}:
raise ValueError("unknown symbol")
return
if isinstance(node, ast.Call):
if not isinstance(node.func, ast.Name):
raise ValueError("unsafe call target")
if node.func.id not in _SAFE_EVAL_FUNCS:
raise ValueError("function not allowed")
for arg in node.args:
_validate_safe_expr(arg)
for kw in node.keywords:
_validate_safe_expr(kw.value)
return
# Explicitly reject high-risk nodes (import/attribute/subscript/comprehensions/lambda, etc.)
raise ValueError("unsupported expression")
def _safe_eval_python_expr(expression: str) -> Any:
tree = ast.parse(expression, mode="eval")
_validate_safe_expr(tree)
return eval( # noqa: S307 - validated AST + empty builtins
compile(tree, "<code_interpreter>", "eval"),
{"__builtins__": {}},
dict(_SAFE_EVAL_FUNCS),
)
def _json_safe(value: Any) -> Any:
if isinstance(value, (str, int, float, bool)) or value is None:
return value
if isinstance(value, (list, tuple)):
return [_json_safe(v) for v in value]
if isinstance(value, dict):
return {str(k): _json_safe(v) for k, v in value.items()}
return repr(value)
def _safe_eval_expr(expression: str) -> float: def _safe_eval_expr(expression: str) -> float:
tree = ast.parse(expression, mode="eval") tree = ast.parse(expression, mode="eval")
@@ -110,6 +206,51 @@ async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
"status": {"code": 422, "message": "invalid_expression"}, "status": {"code": 422, "message": "invalid_expression"},
} }
if tool_name == "current_time":
now = datetime.now().astimezone()
return {
"tool_call_id": call_id,
"name": tool_name,
"output": {
"iso": now.isoformat(),
"local": now.strftime("%Y-%m-%d %H:%M:%S %Z"),
"unix": int(now.timestamp()),
},
"status": {"code": 200, "message": "ok"},
}
if tool_name == "code_interpreter":
code = str(args.get("code") or args.get("expression") or "").strip()
if not code:
return {
"tool_call_id": call_id,
"name": tool_name,
"output": {"error": "missing code"},
"status": {"code": 400, "message": "bad_request"},
}
if len(code) > 500:
return {
"tool_call_id": call_id,
"name": tool_name,
"output": {"error": "code too long"},
"status": {"code": 422, "message": "invalid_code"},
}
try:
result = _safe_eval_python_expr(code)
return {
"tool_call_id": call_id,
"name": tool_name,
"output": {"code": code, "result": _json_safe(result)},
"status": {"code": 200, "message": "ok"},
}
except Exception as exc:
return {
"tool_call_id": call_id,
"name": tool_name,
"output": {"code": code, "error": str(exc)},
"status": {"code": 422, "message": "invalid_code"},
}
return { return {
"tool_call_id": call_id, "tool_call_id": call_id,
"name": tool_name or "unknown_tool", "name": tool_name or "unknown_tool",

View File

@@ -0,0 +1,33 @@
import pytest
from core.tool_executor import execute_server_tool
@pytest.mark.asyncio
async def test_code_interpreter_simple_expression():
result = await execute_server_tool(
{
"id": "call_ci_ok",
"function": {
"name": "code_interpreter",
"arguments": '{"code":"sum([1, 2, 3]) + 4"}',
},
}
)
assert result["status"]["code"] == 200
assert result["output"]["result"] == 10
@pytest.mark.asyncio
async def test_code_interpreter_blocks_import_and_io():
result = await execute_server_tool(
{
"id": "call_ci_bad",
"function": {
"name": "code_interpreter",
"arguments": '{"code":"__import__(\\"os\\").system(\\"ls\\")"}',
},
}
)
assert result["status"]["code"] == 422
assert result["status"]["message"] == "invalid_code"

View File

@@ -42,6 +42,8 @@ const renderToolIcon = (icon: string) => {
Images: <Images className={className} />, Images: <Images className={className} />,
CloudSun: <CloudSun className={className} />, CloudSun: <CloudSun className={className} />,
Calendar: <Calendar className={className} />, Calendar: <Calendar className={className} />,
Phone: <Phone className={className} />,
Volume2: <Volume2 className={className} />,
TrendingUp: <TrendingUp className={className} />, TrendingUp: <TrendingUp className={className} />,
Coins: <Coins className={className} />, Coins: <Coins className={className} />,
Terminal: <Terminal className={className} />, Terminal: <Terminal className={className} />,
@@ -1017,6 +1019,37 @@ const TOOL_PARAMETER_HINTS: Record<string, any> = {
}, },
required: ['query'], required: ['query'],
}, },
current_time: {
type: 'object',
properties: {},
required: [],
},
take_phone: {
type: 'object',
properties: {},
required: [],
},
increase_volume: {
type: 'object',
properties: {
step: { type: 'integer', description: 'Volume step, default 1' },
},
required: [],
},
decrease_volume: {
type: 'object',
properties: {
step: { type: 'integer', description: 'Volume step, default 1' },
},
required: [],
},
code_interpreter: {
type: 'object',
properties: {
code: { type: 'string', description: 'Python code' },
},
required: ['code'],
},
}; };
const getDefaultToolParameters = (toolId: string) => const getDefaultToolParameters = (toolId: string) =>
@@ -1158,10 +1191,7 @@ export const DebugDrawer: React.FC<{
return ids.map((id) => { return ids.map((id) => {
const item = byId.get(id); const item = byId.get(id);
const toolId = item?.id || id; const toolId = item?.id || id;
const isClientTool = const isClientTool = (item?.category || 'query') === 'system';
toolId.startsWith('client_') ||
toolId.startsWith('browser_') ||
['camera', 'microphone', 'page_control', 'local_file'].includes(toolId);
return { return {
type: 'function', type: 'function',
executor: isClientTool ? 'client' : 'server', executor: isClientTool ? 'client' : 'server',