Import tool setting
This commit is contained in:
@@ -111,6 +111,51 @@ class DuplexPipeline:
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
"current_time": {
|
||||
"name": "current_time",
|
||||
"description": "Get current local time",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
"code_interpreter": {
|
||||
"name": "code_interpreter",
|
||||
"description": "Execute Python code in a controlled environment",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"code": {"type": "string"}},
|
||||
"required": ["code"],
|
||||
},
|
||||
},
|
||||
"take_phone": {
|
||||
"name": "take_phone",
|
||||
"description": "Take or answer a phone call",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
"increase_volume": {
|
||||
"name": "increase_volume",
|
||||
"description": "Increase speaker volume",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"step": {"type": "integer"}},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
"decrease_volume": {
|
||||
"name": "decrease_volume",
|
||||
"description": "Decrease speaker volume",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"step": {"type": "integer"}},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import ast
|
||||
import operator
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
@@ -18,6 +19,101 @@ _UNARY_OPS = {
|
||||
ast.USub: operator.neg,
|
||||
}
|
||||
|
||||
_SAFE_EVAL_FUNCS = {
|
||||
"abs": abs,
|
||||
"round": round,
|
||||
"min": min,
|
||||
"max": max,
|
||||
"sum": sum,
|
||||
"len": len,
|
||||
}
|
||||
|
||||
|
||||
def _validate_safe_expr(node: ast.AST) -> None:
|
||||
"""Allow only a constrained subset of Python expression nodes."""
|
||||
if isinstance(node, ast.Expression):
|
||||
_validate_safe_expr(node.body)
|
||||
return
|
||||
|
||||
if isinstance(node, ast.Constant):
|
||||
return
|
||||
|
||||
if isinstance(node, (ast.List, ast.Tuple, ast.Set)):
|
||||
for elt in node.elts:
|
||||
_validate_safe_expr(elt)
|
||||
return
|
||||
|
||||
if isinstance(node, ast.Dict):
|
||||
for key in node.keys:
|
||||
if key is not None:
|
||||
_validate_safe_expr(key)
|
||||
for value in node.values:
|
||||
_validate_safe_expr(value)
|
||||
return
|
||||
|
||||
if isinstance(node, ast.BinOp):
|
||||
if type(node.op) not in _BIN_OPS:
|
||||
raise ValueError("unsupported operator")
|
||||
_validate_safe_expr(node.left)
|
||||
_validate_safe_expr(node.right)
|
||||
return
|
||||
|
||||
if isinstance(node, ast.UnaryOp):
|
||||
if type(node.op) not in _UNARY_OPS:
|
||||
raise ValueError("unsupported unary operator")
|
||||
_validate_safe_expr(node.operand)
|
||||
return
|
||||
|
||||
if isinstance(node, ast.BoolOp):
|
||||
for value in node.values:
|
||||
_validate_safe_expr(value)
|
||||
return
|
||||
|
||||
if isinstance(node, ast.Compare):
|
||||
_validate_safe_expr(node.left)
|
||||
for comp in node.comparators:
|
||||
_validate_safe_expr(comp)
|
||||
return
|
||||
|
||||
if isinstance(node, ast.Name):
|
||||
if node.id not in _SAFE_EVAL_FUNCS and node.id not in {"True", "False", "None"}:
|
||||
raise ValueError("unknown symbol")
|
||||
return
|
||||
|
||||
if isinstance(node, ast.Call):
|
||||
if not isinstance(node.func, ast.Name):
|
||||
raise ValueError("unsafe call target")
|
||||
if node.func.id not in _SAFE_EVAL_FUNCS:
|
||||
raise ValueError("function not allowed")
|
||||
for arg in node.args:
|
||||
_validate_safe_expr(arg)
|
||||
for kw in node.keywords:
|
||||
_validate_safe_expr(kw.value)
|
||||
return
|
||||
|
||||
# Explicitly reject high-risk nodes (import/attribute/subscript/comprehensions/lambda, etc.)
|
||||
raise ValueError("unsupported expression")
|
||||
|
||||
|
||||
def _safe_eval_python_expr(expression: str) -> Any:
|
||||
tree = ast.parse(expression, mode="eval")
|
||||
_validate_safe_expr(tree)
|
||||
return eval( # noqa: S307 - validated AST + empty builtins
|
||||
compile(tree, "<code_interpreter>", "eval"),
|
||||
{"__builtins__": {}},
|
||||
dict(_SAFE_EVAL_FUNCS),
|
||||
)
|
||||
|
||||
|
||||
def _json_safe(value: Any) -> Any:
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
return value
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_json_safe(v) for v in value]
|
||||
if isinstance(value, dict):
|
||||
return {str(k): _json_safe(v) for k, v in value.items()}
|
||||
return repr(value)
|
||||
|
||||
|
||||
def _safe_eval_expr(expression: str) -> float:
|
||||
tree = ast.parse(expression, mode="eval")
|
||||
@@ -110,6 +206,51 @@ async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"status": {"code": 422, "message": "invalid_expression"},
|
||||
}
|
||||
|
||||
if tool_name == "current_time":
|
||||
now = datetime.now().astimezone()
|
||||
return {
|
||||
"tool_call_id": call_id,
|
||||
"name": tool_name,
|
||||
"output": {
|
||||
"iso": now.isoformat(),
|
||||
"local": now.strftime("%Y-%m-%d %H:%M:%S %Z"),
|
||||
"unix": int(now.timestamp()),
|
||||
},
|
||||
"status": {"code": 200, "message": "ok"},
|
||||
}
|
||||
|
||||
if tool_name == "code_interpreter":
|
||||
code = str(args.get("code") or args.get("expression") or "").strip()
|
||||
if not code:
|
||||
return {
|
||||
"tool_call_id": call_id,
|
||||
"name": tool_name,
|
||||
"output": {"error": "missing code"},
|
||||
"status": {"code": 400, "message": "bad_request"},
|
||||
}
|
||||
if len(code) > 500:
|
||||
return {
|
||||
"tool_call_id": call_id,
|
||||
"name": tool_name,
|
||||
"output": {"error": "code too long"},
|
||||
"status": {"code": 422, "message": "invalid_code"},
|
||||
}
|
||||
try:
|
||||
result = _safe_eval_python_expr(code)
|
||||
return {
|
||||
"tool_call_id": call_id,
|
||||
"name": tool_name,
|
||||
"output": {"code": code, "result": _json_safe(result)},
|
||||
"status": {"code": 200, "message": "ok"},
|
||||
}
|
||||
except Exception as exc:
|
||||
return {
|
||||
"tool_call_id": call_id,
|
||||
"name": tool_name,
|
||||
"output": {"code": code, "error": str(exc)},
|
||||
"status": {"code": 422, "message": "invalid_code"},
|
||||
}
|
||||
|
||||
return {
|
||||
"tool_call_id": call_id,
|
||||
"name": tool_name or "unknown_tool",
|
||||
|
||||
33
engine/tests/test_tool_executor.py
Normal file
33
engine/tests/test_tool_executor.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import pytest
|
||||
|
||||
from core.tool_executor import execute_server_tool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_interpreter_simple_expression():
|
||||
result = await execute_server_tool(
|
||||
{
|
||||
"id": "call_ci_ok",
|
||||
"function": {
|
||||
"name": "code_interpreter",
|
||||
"arguments": '{"code":"sum([1, 2, 3]) + 4"}',
|
||||
},
|
||||
}
|
||||
)
|
||||
assert result["status"]["code"] == 200
|
||||
assert result["output"]["result"] == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_interpreter_blocks_import_and_io():
|
||||
result = await execute_server_tool(
|
||||
{
|
||||
"id": "call_ci_bad",
|
||||
"function": {
|
||||
"name": "code_interpreter",
|
||||
"arguments": '{"code":"__import__(\\"os\\").system(\\"ls\\")"}',
|
||||
},
|
||||
}
|
||||
)
|
||||
assert result["status"]["code"] == 422
|
||||
assert result["status"]["message"] == "invalid_code"
|
||||
Reference in New Issue
Block a user