Import tool setting
This commit is contained in:
@@ -72,6 +72,15 @@ TOOL_REGISTRY = {
|
||||
"required": ["query"]
|
||||
}
|
||||
},
|
||||
"current_time": {
|
||||
"name": "当前时间",
|
||||
"description": "获取当前本地时间",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
},
|
||||
"code_interpreter": {
|
||||
"name": "代码执行",
|
||||
"description": "安全地执行Python代码",
|
||||
@@ -83,6 +92,37 @@ TOOL_REGISTRY = {
|
||||
"required": ["code"]
|
||||
}
|
||||
},
|
||||
"take_phone": {
|
||||
"name": "接听电话",
|
||||
"description": "执行接听电话命令",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
},
|
||||
"increase_volume": {
|
||||
"name": "调高音量",
|
||||
"description": "提升设备音量",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"step": {"type": "integer", "description": "调整步进,默认1"}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
},
|
||||
"decrease_volume": {
|
||||
"name": "调低音量",
|
||||
"description": "降低设备音量",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"step": {"type": "integer", "description": "调整步进,默认1"}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
TOOL_CATEGORY_MAP = {
|
||||
@@ -90,8 +130,12 @@ TOOL_CATEGORY_MAP = {
|
||||
"weather": "query",
|
||||
"translate": "query",
|
||||
"knowledge": "query",
|
||||
"calculator": "system",
|
||||
"code_interpreter": "system",
|
||||
"calculator": "query",
|
||||
"current_time": "query",
|
||||
"code_interpreter": "query",
|
||||
"take_phone": "system",
|
||||
"increase_volume": "system",
|
||||
"decrease_volume": "system",
|
||||
}
|
||||
|
||||
TOOL_ICON_MAP = {
|
||||
@@ -99,26 +143,50 @@ TOOL_ICON_MAP = {
|
||||
"weather": "CloudSun",
|
||||
"translate": "Globe",
|
||||
"knowledge": "Box",
|
||||
"current_time": "Calendar",
|
||||
"calculator": "Terminal",
|
||||
"code_interpreter": "Terminal",
|
||||
"take_phone": "Phone",
|
||||
"increase_volume": "Volume2",
|
||||
"decrease_volume": "Volume2",
|
||||
}
|
||||
|
||||
def _seed_default_tools_if_empty(db: Session) -> None:
|
||||
"""Seed default tools into DB when tool_resources is empty."""
|
||||
if db.query(ToolResource).count() > 0:
|
||||
return
|
||||
|
||||
def _sync_default_tools(db: Session) -> None:
|
||||
"""Ensure built-in tools exist and keep system tool metadata aligned."""
|
||||
changed = False
|
||||
for tool_id, payload in TOOL_REGISTRY.items():
|
||||
row = db.query(ToolResource).filter(ToolResource.id == tool_id).first()
|
||||
category = TOOL_CATEGORY_MAP.get(tool_id, "system")
|
||||
icon = TOOL_ICON_MAP.get(tool_id, "Wrench")
|
||||
if not row:
|
||||
db.add(ToolResource(
|
||||
id=tool_id,
|
||||
user_id=1,
|
||||
name=payload.get("name", tool_id),
|
||||
description=payload.get("description", ""),
|
||||
category=TOOL_CATEGORY_MAP.get(tool_id, "system"),
|
||||
icon=TOOL_ICON_MAP.get(tool_id, "Wrench"),
|
||||
category=category,
|
||||
icon=icon,
|
||||
enabled=True,
|
||||
is_system=True,
|
||||
))
|
||||
changed = True
|
||||
continue
|
||||
if row.is_system:
|
||||
new_name = payload.get("name", row.name)
|
||||
new_description = payload.get("description", row.description)
|
||||
if row.name != new_name:
|
||||
row.name = new_name
|
||||
changed = True
|
||||
if row.description != new_description:
|
||||
row.description = new_description
|
||||
changed = True
|
||||
if row.category != category:
|
||||
row.category = category
|
||||
changed = True
|
||||
if row.icon != icon:
|
||||
row.icon = icon
|
||||
changed = True
|
||||
if changed:
|
||||
db.commit()
|
||||
|
||||
|
||||
@@ -147,7 +215,7 @@ def list_tool_resources(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取工具资源列表。system/query 仅表示工具执行类型,不代表权限。"""
|
||||
_seed_default_tools_if_empty(db)
|
||||
_sync_default_tools(db)
|
||||
query = db.query(ToolResource)
|
||||
if not include_system:
|
||||
query = query.filter(ToolResource.is_system == False)
|
||||
@@ -163,7 +231,7 @@ def list_tool_resources(
|
||||
@router.get("/resources/{id}", response_model=ToolResourceOut)
|
||||
def get_tool_resource(id: str, db: Session = Depends(get_db)):
|
||||
"""获取单个工具资源详情。"""
|
||||
_seed_default_tools_if_empty(db)
|
||||
_sync_default_tools(db)
|
||||
item = db.query(ToolResource).filter(ToolResource.id == id).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Tool resource not found")
|
||||
@@ -173,7 +241,7 @@ def get_tool_resource(id: str, db: Session = Depends(get_db)):
|
||||
@router.post("/resources", response_model=ToolResourceOut)
|
||||
def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)):
|
||||
"""创建自定义工具资源。"""
|
||||
_seed_default_tools_if_empty(db)
|
||||
_sync_default_tools(db)
|
||||
candidate_id = (data.id or "").strip()
|
||||
if candidate_id and db.query(ToolResource).filter(ToolResource.id == candidate_id).first():
|
||||
raise HTTPException(status_code=400, detail="Tool ID already exists")
|
||||
@@ -197,7 +265,7 @@ def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)
|
||||
@router.put("/resources/{id}", response_model=ToolResourceOut)
|
||||
def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)):
|
||||
"""更新工具资源。"""
|
||||
_seed_default_tools_if_empty(db)
|
||||
_sync_default_tools(db)
|
||||
item = db.query(ToolResource).filter(ToolResource.id == id).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Tool resource not found")
|
||||
@@ -215,7 +283,7 @@ def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depend
|
||||
@router.delete("/resources/{id}")
|
||||
def delete_tool_resource(id: str, db: Session = Depends(get_db)):
|
||||
"""删除工具资源。"""
|
||||
_seed_default_tools_if_empty(db)
|
||||
_sync_default_tools(db)
|
||||
item = db.query(ToolResource).filter(ToolResource.id == id).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Tool resource not found")
|
||||
|
||||
@@ -111,6 +111,51 @@ class DuplexPipeline:
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
"current_time": {
|
||||
"name": "current_time",
|
||||
"description": "Get current local time",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
"code_interpreter": {
|
||||
"name": "code_interpreter",
|
||||
"description": "Execute Python code in a controlled environment",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"code": {"type": "string"}},
|
||||
"required": ["code"],
|
||||
},
|
||||
},
|
||||
"take_phone": {
|
||||
"name": "take_phone",
|
||||
"description": "Take or answer a phone call",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
"increase_volume": {
|
||||
"name": "increase_volume",
|
||||
"description": "Increase speaker volume",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"step": {"type": "integer"}},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
"decrease_volume": {
|
||||
"name": "decrease_volume",
|
||||
"description": "Decrease speaker volume",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"step": {"type": "integer"}},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import ast
|
||||
import operator
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
@@ -18,6 +19,101 @@ _UNARY_OPS = {
|
||||
ast.USub: operator.neg,
|
||||
}
|
||||
|
||||
_SAFE_EVAL_FUNCS = {
|
||||
"abs": abs,
|
||||
"round": round,
|
||||
"min": min,
|
||||
"max": max,
|
||||
"sum": sum,
|
||||
"len": len,
|
||||
}
|
||||
|
||||
|
||||
def _validate_safe_expr(node: ast.AST) -> None:
|
||||
"""Allow only a constrained subset of Python expression nodes."""
|
||||
if isinstance(node, ast.Expression):
|
||||
_validate_safe_expr(node.body)
|
||||
return
|
||||
|
||||
if isinstance(node, ast.Constant):
|
||||
return
|
||||
|
||||
if isinstance(node, (ast.List, ast.Tuple, ast.Set)):
|
||||
for elt in node.elts:
|
||||
_validate_safe_expr(elt)
|
||||
return
|
||||
|
||||
if isinstance(node, ast.Dict):
|
||||
for key in node.keys:
|
||||
if key is not None:
|
||||
_validate_safe_expr(key)
|
||||
for value in node.values:
|
||||
_validate_safe_expr(value)
|
||||
return
|
||||
|
||||
if isinstance(node, ast.BinOp):
|
||||
if type(node.op) not in _BIN_OPS:
|
||||
raise ValueError("unsupported operator")
|
||||
_validate_safe_expr(node.left)
|
||||
_validate_safe_expr(node.right)
|
||||
return
|
||||
|
||||
if isinstance(node, ast.UnaryOp):
|
||||
if type(node.op) not in _UNARY_OPS:
|
||||
raise ValueError("unsupported unary operator")
|
||||
_validate_safe_expr(node.operand)
|
||||
return
|
||||
|
||||
if isinstance(node, ast.BoolOp):
|
||||
for value in node.values:
|
||||
_validate_safe_expr(value)
|
||||
return
|
||||
|
||||
if isinstance(node, ast.Compare):
|
||||
_validate_safe_expr(node.left)
|
||||
for comp in node.comparators:
|
||||
_validate_safe_expr(comp)
|
||||
return
|
||||
|
||||
if isinstance(node, ast.Name):
|
||||
if node.id not in _SAFE_EVAL_FUNCS and node.id not in {"True", "False", "None"}:
|
||||
raise ValueError("unknown symbol")
|
||||
return
|
||||
|
||||
if isinstance(node, ast.Call):
|
||||
if not isinstance(node.func, ast.Name):
|
||||
raise ValueError("unsafe call target")
|
||||
if node.func.id not in _SAFE_EVAL_FUNCS:
|
||||
raise ValueError("function not allowed")
|
||||
for arg in node.args:
|
||||
_validate_safe_expr(arg)
|
||||
for kw in node.keywords:
|
||||
_validate_safe_expr(kw.value)
|
||||
return
|
||||
|
||||
# Explicitly reject high-risk nodes (import/attribute/subscript/comprehensions/lambda, etc.)
|
||||
raise ValueError("unsupported expression")
|
||||
|
||||
|
||||
def _safe_eval_python_expr(expression: str) -> Any:
|
||||
tree = ast.parse(expression, mode="eval")
|
||||
_validate_safe_expr(tree)
|
||||
return eval( # noqa: S307 - validated AST + empty builtins
|
||||
compile(tree, "<code_interpreter>", "eval"),
|
||||
{"__builtins__": {}},
|
||||
dict(_SAFE_EVAL_FUNCS),
|
||||
)
|
||||
|
||||
|
||||
def _json_safe(value: Any) -> Any:
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
return value
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_json_safe(v) for v in value]
|
||||
if isinstance(value, dict):
|
||||
return {str(k): _json_safe(v) for k, v in value.items()}
|
||||
return repr(value)
|
||||
|
||||
|
||||
def _safe_eval_expr(expression: str) -> float:
|
||||
tree = ast.parse(expression, mode="eval")
|
||||
@@ -110,6 +206,51 @@ async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"status": {"code": 422, "message": "invalid_expression"},
|
||||
}
|
||||
|
||||
if tool_name == "current_time":
|
||||
now = datetime.now().astimezone()
|
||||
return {
|
||||
"tool_call_id": call_id,
|
||||
"name": tool_name,
|
||||
"output": {
|
||||
"iso": now.isoformat(),
|
||||
"local": now.strftime("%Y-%m-%d %H:%M:%S %Z"),
|
||||
"unix": int(now.timestamp()),
|
||||
},
|
||||
"status": {"code": 200, "message": "ok"},
|
||||
}
|
||||
|
||||
if tool_name == "code_interpreter":
|
||||
code = str(args.get("code") or args.get("expression") or "").strip()
|
||||
if not code:
|
||||
return {
|
||||
"tool_call_id": call_id,
|
||||
"name": tool_name,
|
||||
"output": {"error": "missing code"},
|
||||
"status": {"code": 400, "message": "bad_request"},
|
||||
}
|
||||
if len(code) > 500:
|
||||
return {
|
||||
"tool_call_id": call_id,
|
||||
"name": tool_name,
|
||||
"output": {"error": "code too long"},
|
||||
"status": {"code": 422, "message": "invalid_code"},
|
||||
}
|
||||
try:
|
||||
result = _safe_eval_python_expr(code)
|
||||
return {
|
||||
"tool_call_id": call_id,
|
||||
"name": tool_name,
|
||||
"output": {"code": code, "result": _json_safe(result)},
|
||||
"status": {"code": 200, "message": "ok"},
|
||||
}
|
||||
except Exception as exc:
|
||||
return {
|
||||
"tool_call_id": call_id,
|
||||
"name": tool_name,
|
||||
"output": {"code": code, "error": str(exc)},
|
||||
"status": {"code": 422, "message": "invalid_code"},
|
||||
}
|
||||
|
||||
return {
|
||||
"tool_call_id": call_id,
|
||||
"name": tool_name or "unknown_tool",
|
||||
|
||||
33
engine/tests/test_tool_executor.py
Normal file
33
engine/tests/test_tool_executor.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import pytest
|
||||
|
||||
from core.tool_executor import execute_server_tool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_interpreter_simple_expression():
|
||||
result = await execute_server_tool(
|
||||
{
|
||||
"id": "call_ci_ok",
|
||||
"function": {
|
||||
"name": "code_interpreter",
|
||||
"arguments": '{"code":"sum([1, 2, 3]) + 4"}',
|
||||
},
|
||||
}
|
||||
)
|
||||
assert result["status"]["code"] == 200
|
||||
assert result["output"]["result"] == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_interpreter_blocks_import_and_io():
|
||||
result = await execute_server_tool(
|
||||
{
|
||||
"id": "call_ci_bad",
|
||||
"function": {
|
||||
"name": "code_interpreter",
|
||||
"arguments": '{"code":"__import__(\\"os\\").system(\\"ls\\")"}',
|
||||
},
|
||||
}
|
||||
)
|
||||
assert result["status"]["code"] == 422
|
||||
assert result["status"]["message"] == "invalid_code"
|
||||
@@ -42,6 +42,8 @@ const renderToolIcon = (icon: string) => {
|
||||
Images: <Images className={className} />,
|
||||
CloudSun: <CloudSun className={className} />,
|
||||
Calendar: <Calendar className={className} />,
|
||||
Phone: <Phone className={className} />,
|
||||
Volume2: <Volume2 className={className} />,
|
||||
TrendingUp: <TrendingUp className={className} />,
|
||||
Coins: <Coins className={className} />,
|
||||
Terminal: <Terminal className={className} />,
|
||||
@@ -1017,6 +1019,37 @@ const TOOL_PARAMETER_HINTS: Record<string, any> = {
|
||||
},
|
||||
required: ['query'],
|
||||
},
|
||||
current_time: {
|
||||
type: 'object',
|
||||
properties: {},
|
||||
required: [],
|
||||
},
|
||||
take_phone: {
|
||||
type: 'object',
|
||||
properties: {},
|
||||
required: [],
|
||||
},
|
||||
increase_volume: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
step: { type: 'integer', description: 'Volume step, default 1' },
|
||||
},
|
||||
required: [],
|
||||
},
|
||||
decrease_volume: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
step: { type: 'integer', description: 'Volume step, default 1' },
|
||||
},
|
||||
required: [],
|
||||
},
|
||||
code_interpreter: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
code: { type: 'string', description: 'Python code' },
|
||||
},
|
||||
required: ['code'],
|
||||
},
|
||||
};
|
||||
|
||||
const getDefaultToolParameters = (toolId: string) =>
|
||||
@@ -1158,10 +1191,7 @@ export const DebugDrawer: React.FC<{
|
||||
return ids.map((id) => {
|
||||
const item = byId.get(id);
|
||||
const toolId = item?.id || id;
|
||||
const isClientTool =
|
||||
toolId.startsWith('client_') ||
|
||||
toolId.startsWith('browser_') ||
|
||||
['camera', 'microphone', 'page_control', 'local_file'].includes(toolId);
|
||||
const isClientTool = (item?.category || 'query') === 'system';
|
||||
return {
|
||||
type: 'function',
|
||||
executor: isClientTool ? 'client' : 'server',
|
||||
|
||||
Reference in New Issue
Block a user