"""Server-side tool execution helpers.""" import asyncio import ast import operator from datetime import datetime from typing import Any, Dict import aiohttp from app.backend_client import fetch_tool_resource _BIN_OPS = { ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul, ast.Div: operator.truediv, ast.Mod: operator.mod, } _UNARY_OPS = { ast.UAdd: operator.pos, ast.USub: operator.neg, } _SAFE_EVAL_FUNCS = { "abs": abs, "round": round, "min": min, "max": max, "sum": sum, "len": len, } def _validate_safe_expr(node: ast.AST) -> None: """Allow only a constrained subset of Python expression nodes.""" if isinstance(node, ast.Expression): _validate_safe_expr(node.body) return if isinstance(node, ast.Constant): return if isinstance(node, (ast.List, ast.Tuple, ast.Set)): for elt in node.elts: _validate_safe_expr(elt) return if isinstance(node, ast.Dict): for key in node.keys: if key is not None: _validate_safe_expr(key) for value in node.values: _validate_safe_expr(value) return if isinstance(node, ast.BinOp): if type(node.op) not in _BIN_OPS: raise ValueError("unsupported operator") _validate_safe_expr(node.left) _validate_safe_expr(node.right) return if isinstance(node, ast.UnaryOp): if type(node.op) not in _UNARY_OPS: raise ValueError("unsupported unary operator") _validate_safe_expr(node.operand) return if isinstance(node, ast.BoolOp): for value in node.values: _validate_safe_expr(value) return if isinstance(node, ast.Compare): _validate_safe_expr(node.left) for comp in node.comparators: _validate_safe_expr(comp) return if isinstance(node, ast.Name): if node.id not in _SAFE_EVAL_FUNCS and node.id not in {"True", "False", "None"}: raise ValueError("unknown symbol") return if isinstance(node, ast.Call): if not isinstance(node.func, ast.Name): raise ValueError("unsafe call target") if node.func.id not in _SAFE_EVAL_FUNCS: raise ValueError("function not allowed") for arg in node.args: _validate_safe_expr(arg) for kw in node.keywords: _validate_safe_expr(kw.value) return # Explicitly reject high-risk nodes (import/attribute/subscript/comprehensions/lambda, etc.) raise ValueError("unsupported expression") def _safe_eval_python_expr(expression: str) -> Any: tree = ast.parse(expression, mode="eval") _validate_safe_expr(tree) return eval( # noqa: S307 - validated AST + empty builtins compile(tree, "", "eval"), {"__builtins__": {}}, dict(_SAFE_EVAL_FUNCS), ) def _json_safe(value: Any) -> Any: if isinstance(value, (str, int, float, bool)) or value is None: return value if isinstance(value, (list, tuple)): return [_json_safe(v) for v in value] if isinstance(value, dict): return {str(k): _json_safe(v) for k, v in value.items()} return repr(value) def _safe_eval_expr(expression: str) -> float: tree = ast.parse(expression, mode="eval") def _eval(node: ast.AST) -> float: if isinstance(node, ast.Expression): return _eval(node.body) if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)): return float(node.value) if isinstance(node, ast.BinOp): op = _BIN_OPS.get(type(node.op)) if not op: raise ValueError("unsupported operator") return float(op(_eval(node.left), _eval(node.right))) if isinstance(node, ast.UnaryOp): op = _UNARY_OPS.get(type(node.op)) if not op: raise ValueError("unsupported unary operator") return float(op(_eval(node.operand))) raise ValueError("unsupported expression") return _eval(tree) def _extract_tool_name(tool_call: Dict[str, Any]) -> str: function_payload = tool_call.get("function") if isinstance(function_payload, dict): return str(function_payload.get("name") or "").strip() return "" def _extract_tool_args(tool_call: Dict[str, Any]) -> Dict[str, Any]: function_payload = tool_call.get("function") if not isinstance(function_payload, dict): return {} raw = function_payload.get("arguments") if isinstance(raw, dict): return raw if not isinstance(raw, str): return {} text = raw.strip() if not text: return {} try: import json parsed = json.loads(text) return parsed if isinstance(parsed, dict) else {} except Exception: return {} async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]: """Execute a server-side tool and return normalized result payload.""" call_id = str(tool_call.get("id") or "").strip() tool_name = _extract_tool_name(tool_call) args = _extract_tool_args(tool_call) if tool_name == "calculator": expression = str(args.get("expression") or "").strip() if not expression: return { "tool_call_id": call_id, "name": tool_name, "output": {"error": "missing expression"}, "status": {"code": 400, "message": "bad_request"}, } if len(expression) > 200: return { "tool_call_id": call_id, "name": tool_name, "output": {"expression": expression, "error": "expression too long"}, "status": {"code": 422, "message": "invalid_expression"}, } try: value = _safe_eval_expr(expression) if value.is_integer(): value = int(value) return { "tool_call_id": call_id, "name": tool_name, "output": {"expression": expression, "result": value}, "status": {"code": 200, "message": "ok"}, } except Exception as exc: return { "tool_call_id": call_id, "name": tool_name, "output": {"expression": expression, "error": str(exc)}, "status": {"code": 422, "message": "invalid_expression"}, } if tool_name == "code_interpreter": code = str(args.get("code") or args.get("expression") or "").strip() if not code: return { "tool_call_id": call_id, "name": tool_name, "output": {"error": "missing code"}, "status": {"code": 400, "message": "bad_request"}, } if len(code) > 500: return { "tool_call_id": call_id, "name": tool_name, "output": {"error": "code too long"}, "status": {"code": 422, "message": "invalid_code"}, } try: result = _safe_eval_python_expr(code) return { "tool_call_id": call_id, "name": tool_name, "output": {"code": code, "result": _json_safe(result)}, "status": {"code": 200, "message": "ok"}, } except Exception as exc: return { "tool_call_id": call_id, "name": tool_name, "output": {"code": code, "error": str(exc)}, "status": {"code": 422, "message": "invalid_code"}, } if tool_name == "current_time": now = datetime.now().astimezone() return { "tool_call_id": call_id, "name": tool_name, "output": { "local_time": now.strftime("%Y-%m-%d %H:%M:%S"), "iso": now.isoformat(), "timezone": str(now.tzinfo or ""), "timestamp": int(now.timestamp()), }, "status": {"code": 200, "message": "ok"}, } if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}: resource = await fetch_tool_resource(tool_name) if resource and str(resource.get("category") or "") == "query": method = str(resource.get("http_method") or "GET").strip().upper() if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}: method = "GET" url = str(resource.get("http_url") or "").strip() headers = resource.get("http_headers") if isinstance(resource.get("http_headers"), dict) else {} timeout_ms = resource.get("http_timeout_ms") try: timeout_s = max(1.0, float(timeout_ms) / 1000.0) except Exception: timeout_s = 10.0 if not url: return { "tool_call_id": call_id, "name": tool_name, "output": {"error": "http_url not configured"}, "status": {"code": 422, "message": "invalid_tool_config"}, } request_kwargs: Dict[str, Any] = {} if method in {"GET", "DELETE"}: request_kwargs["params"] = args else: request_kwargs["json"] = args try: timeout = aiohttp.ClientTimeout(total=timeout_s) async with aiohttp.ClientSession(timeout=timeout) as session: async with session.request(method, url, headers=headers, **request_kwargs) as resp: content_type = str(resp.headers.get("Content-Type") or "").lower() if "application/json" in content_type: body: Any = await resp.json() else: body = await resp.text() status_code = int(resp.status) if 200 <= status_code < 300: return { "tool_call_id": call_id, "name": tool_name, "output": { "method": method, "url": url, "status_code": status_code, "response": _json_safe(body), }, "status": {"code": 200, "message": "ok"}, } return { "tool_call_id": call_id, "name": tool_name, "output": { "method": method, "url": url, "status_code": status_code, "response": _json_safe(body), }, "status": {"code": status_code, "message": "http_error"}, } except asyncio.TimeoutError: return { "tool_call_id": call_id, "name": tool_name, "output": {"method": method, "url": url, "error": "request timeout"}, "status": {"code": 504, "message": "http_timeout"}, } except Exception as exc: return { "tool_call_id": call_id, "name": tool_name, "output": {"method": method, "url": url, "error": str(exc)}, "status": {"code": 502, "message": "http_request_failed"}, } return { "tool_call_id": call_id, "name": tool_name or "unknown_tool", "output": {"message": "server tool not implemented"}, "status": {"code": 501, "message": "not_implemented"}, }