Files
AI-VideoAssistant/engine/core/tool_executor.py
2026-02-11 11:04:05 +08:00

260 lines
8.1 KiB
Python

"""Server-side tool execution helpers."""
import ast
import math
import operator
from datetime import datetime
from typing import Any, Dict
# Binary operators accepted by the evaluators in this module, keyed by AST
# operator node type and mapped to the function that implements them.
# Operators absent from this table (for example ast.Pow) are rejected with
# "unsupported operator".
_BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Mod: operator.mod,
}
# Unary operators accepted by the evaluators in this module (sign only).
# Anything else, such as ``not`` or ``~``, is rejected with
# "unsupported unary operator".
_UNARY_OPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}
# Whitelist of callables for the "code_interpreter" expression evaluator.
# The keys are the only identifiers and call targets the AST validator accepts,
# and a copy of this mapping is the only namespace handed to ``eval``.
_SAFE_EVAL_FUNCS = {
    "abs": abs,
    "round": round,
    "min": min,
    "max": max,
    "sum": sum,
    "len": len,
}
def _validate_safe_expr(node: ast.AST) -> None:
"""Allow only a constrained subset of Python expression nodes."""
if isinstance(node, ast.Expression):
_validate_safe_expr(node.body)
return
if isinstance(node, ast.Constant):
return
if isinstance(node, (ast.List, ast.Tuple, ast.Set)):
for elt in node.elts:
_validate_safe_expr(elt)
return
if isinstance(node, ast.Dict):
for key in node.keys:
if key is not None:
_validate_safe_expr(key)
for value in node.values:
_validate_safe_expr(value)
return
if isinstance(node, ast.BinOp):
if type(node.op) not in _BIN_OPS:
raise ValueError("unsupported operator")
_validate_safe_expr(node.left)
_validate_safe_expr(node.right)
return
if isinstance(node, ast.UnaryOp):
if type(node.op) not in _UNARY_OPS:
raise ValueError("unsupported unary operator")
_validate_safe_expr(node.operand)
return
if isinstance(node, ast.BoolOp):
for value in node.values:
_validate_safe_expr(value)
return
if isinstance(node, ast.Compare):
_validate_safe_expr(node.left)
for comp in node.comparators:
_validate_safe_expr(comp)
return
if isinstance(node, ast.Name):
if node.id not in _SAFE_EVAL_FUNCS and node.id not in {"True", "False", "None"}:
raise ValueError("unknown symbol")
return
if isinstance(node, ast.Call):
if not isinstance(node.func, ast.Name):
raise ValueError("unsafe call target")
if node.func.id not in _SAFE_EVAL_FUNCS:
raise ValueError("function not allowed")
for arg in node.args:
_validate_safe_expr(arg)
for kw in node.keywords:
_validate_safe_expr(kw.value)
return
# Explicitly reject high-risk nodes (import/attribute/subscript/comprehensions/lambda, etc.)
raise ValueError("unsupported expression")
def _safe_eval_python_expr(expression: str) -> Any:
    """Evaluate one Python expression after whitelisting its syntax tree.

    The text is parsed in ``eval`` mode, checked by ``_validate_safe_expr``
    and only then compiled and evaluated with empty builtins and a namespace
    holding just the functions in ``_SAFE_EVAL_FUNCS``.

    Raises:
        SyntaxError: if *expression* is not a valid Python expression.
        ValueError: if the expression uses syntax outside the whitelist.
        Exception: whatever the evaluated expression itself raises.
    """
    parsed = ast.parse(expression, mode="eval")
    _validate_safe_expr(parsed)

    code_obj = compile(parsed, "<code_interpreter>", "eval")
    sandbox_globals = {"__builtins__": {}}
    sandbox_locals = dict(_SAFE_EVAL_FUNCS)  # copy, so the whitelist is never mutated

    # SECURITY NOTE(review): this is ``eval`` on caller-supplied text. The AST
    # whitelist blocks names, attributes and imports, but it does not bound
    # operand sizes, so an expression such as a very large sequence repetition
    # can still consume a lot of memory or CPU. Confirm that callers cap input
    # length and that this is acceptable for the deployment.
    return eval(code_obj, sandbox_globals, sandbox_locals)  # noqa: S307 - validated AST + empty builtins
def _json_safe(value: Any) -> Any:
if isinstance(value, (str, int, float, bool)) or value is None:
return value
if isinstance(value, (list, tuple)):
return [_json_safe(v) for v in value]
if isinstance(value, dict):
return {str(k): _json_safe(v) for k, v in value.items()}
return repr(value)
def _safe_eval_expr(expression: str) -> float:
tree = ast.parse(expression, mode="eval")
def _eval(node: ast.AST) -> float:
if isinstance(node, ast.Expression):
return _eval(node.body)
if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
return float(node.value)
if isinstance(node, ast.BinOp):
op = _BIN_OPS.get(type(node.op))
if not op:
raise ValueError("unsupported operator")
return float(op(_eval(node.left), _eval(node.right)))
if isinstance(node, ast.UnaryOp):
op = _UNARY_OPS.get(type(node.op))
if not op:
raise ValueError("unsupported unary operator")
return float(op(_eval(node.operand)))
raise ValueError("unsupported expression")
return _eval(tree)
def _extract_tool_name(tool_call: Dict[str, Any]) -> str:
function_payload = tool_call.get("function")
if isinstance(function_payload, dict):
return str(function_payload.get("name") or "").strip()
return ""
def _extract_tool_args(tool_call: Dict[str, Any]) -> Dict[str, Any]:
function_payload = tool_call.get("function")
if not isinstance(function_payload, dict):
return {}
raw = function_payload.get("arguments")
if isinstance(raw, dict):
return raw
if not isinstance(raw, str):
return {}
text = raw.strip()
if not text:
return {}
try:
import json
parsed = json.loads(text)
return parsed if isinstance(parsed, dict) else {}
except Exception:
return {}
def _tool_result(
    call_id: str,
    name: str,
    output: Dict[str, Any],
    code: int,
    message: str,
) -> Dict[str, Any]:
    """Assemble the normalized result envelope shared by every server tool."""
    return {
        "tool_call_id": call_id,
        "name": name,
        "output": output,
        "status": {"code": code, "message": message},
    }


async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
    """Execute a server-side tool and return normalized result payload.

    Args:
        tool_call: mapping read as ``{"id": ..., "function": {"name": ...,
            "arguments": ...}}``; ``arguments`` may be a dict or a JSON string.

    Returns:
        A dict with the keys ``tool_call_id``, ``name``, ``output`` and
        ``status`` (``{"code": int, "message": str}``). Failures are reported
        through ``status`` instead of being raised:

        * 200 ``ok``: the tool ran; the result is in ``output``.
        * 400 ``bad_request``: the required argument is missing or empty.
        * 422 ``invalid_expression`` / ``invalid_code``: the input is too
          long, was rejected, failed to evaluate, or (calculator only) gave a
          non-finite result.
        * 501 ``not_implemented``: the tool name is not handled here.

    Supported tools: ``calculator`` (argument ``expression``, at most 200
    characters), ``current_time`` (no arguments) and ``code_interpreter``
    (argument ``code`` or ``expression``, at most 500 characters). The body
    contains no ``await``; evaluation runs synchronously inside the coroutine.
    """
    call_id = str(tool_call.get("id") or "").strip()
    tool_name = _extract_tool_name(tool_call)
    args = _extract_tool_args(tool_call)

    if tool_name == "calculator":
        expression = str(args.get("expression") or "").strip()
        if not expression:
            return _tool_result(
                call_id, tool_name, {"error": "missing expression"}, 400, "bad_request"
            )
        if len(expression) > 200:
            return _tool_result(
                call_id,
                tool_name,
                {"expression": expression, "error": "expression too long"},
                422,
                "invalid_expression",
            )
        try:
            number = _safe_eval_expr(expression)
            # Overflowing arithmetic such as 1e308*10 gives inf/nan, which
            # JSON cannot represent; report it as an invalid expression
            # instead of returning it with status 200.
            if not math.isfinite(number):
                raise ValueError("result is not a finite number")
            # Show whole-number results as ints (4 rather than 4.0).
            result: Any = int(number) if number.is_integer() else number
        except Exception as exc:
            # Tool boundary: any parse or evaluation failure becomes a 422.
            return _tool_result(
                call_id,
                tool_name,
                {"expression": expression, "error": str(exc)},
                422,
                "invalid_expression",
            )
        return _tool_result(
            call_id, tool_name, {"expression": expression, "result": result}, 200, "ok"
        )

    if tool_name == "current_time":
        # Timezone-aware timestamp in the server's local zone.
        now = datetime.now().astimezone()
        return _tool_result(
            call_id,
            tool_name,
            {
                "iso": now.isoformat(),
                "local": now.strftime("%Y-%m-%d %H:%M:%S %Z"),
                "unix": int(now.timestamp()),
            },
            200,
            "ok",
        )

    if tool_name == "code_interpreter":
        # "expression" is accepted as an alias for "code".
        code = str(args.get("code") or args.get("expression") or "").strip()
        if not code:
            return _tool_result(
                call_id, tool_name, {"error": "missing code"}, 400, "bad_request"
            )
        if len(code) > 500:
            return _tool_result(
                call_id, tool_name, {"error": "code too long"}, 422, "invalid_code"
            )
        try:
            evaluated = _safe_eval_python_expr(code)
            output = {"code": code, "result": _json_safe(evaluated)}
        except Exception as exc:
            # Tool boundary: rejected syntax and runtime errors become a 422.
            return _tool_result(
                call_id, tool_name, {"code": code, "error": str(exc)}, 422, "invalid_code"
            )
        return _tool_result(call_id, tool_name, output, 200, "ok")

    return _tool_result(
        call_id,
        tool_name or "unknown_tool",
        {"message": "server tool not implemented"},
        501,
        "not_implemented",
    )