From 6cac24918d504d98c4a91f0f81d182e459e2bd28 Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Tue, 10 Feb 2026 19:13:54 +0800 Subject: [PATCH] Now we have server tool and client tool --- engine/core/duplex_pipeline.py | 77 +++++++++++++++++- engine/core/tool_executor.py | 118 ++++++++++++++++++++++++++++ engine/docs/ws_v1_schema.md | 4 +- engine/tests/test_tool_call_flow.py | 35 +++++++++ web/pages/Assistants.tsx | 105 +++++++------------------ 5 files changed, 257 insertions(+), 82 deletions(-) create mode 100644 engine/core/tool_executor.py diff --git a/engine/core/duplex_pipeline.py b/engine/core/duplex_pipeline.py index 3b683ab..51f76de 100644 --- a/engine/core/duplex_pipeline.py +++ b/engine/core/duplex_pipeline.py @@ -22,6 +22,7 @@ from loguru import logger from app.config import settings from core.conversation import ConversationManager, ConversationState from core.events import get_event_bus +from core.tool_executor import execute_server_tool from core.transports import BaseTransport from models.ws_v1 import ev from processors.eou import EouDetector @@ -214,6 +215,7 @@ class DuplexPipeline: self._runtime_knowledge: Dict[str, Any] = {} self._runtime_knowledge_base_id: Optional[str] = None self._runtime_tools: List[Any] = [] + self._runtime_tool_executor: Dict[str, str] = {} self._pending_tool_waiters: Dict[str, asyncio.Future] = {} self._early_tool_results: Dict[str, Dict[str, Any]] = {} self._completed_tool_call_ids: set[str] = set() @@ -270,6 +272,10 @@ class DuplexPipeline: tools_payload = metadata.get("tools") if isinstance(tools_payload, list): self._runtime_tools = tools_payload + self._runtime_tool_executor = self._resolved_tool_executor_map() + elif "tools" in metadata: + self._runtime_tools = [] + self._runtime_tool_executor = {} if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"): self.llm_service.set_knowledge_config(self._resolved_knowledge_config()) @@ -675,6 +681,10 @@ class DuplexPipeline: fn = item.get("function") if isinstance(fn, dict) and fn.get("name"): + fn_name = str(fn.get("name")) + executor = str(item.get("executor") or item.get("run_on") or "").strip().lower() + if executor in {"client", "server"}: + self._runtime_tool_executor[fn_name] = executor schemas.append( { "type": "function", @@ -688,6 +698,10 @@ class DuplexPipeline: continue if item.get("name"): + fn_name = str(item.get("name")) + executor = str(item.get("executor") or item.get("run_on") or "").strip().lower() + if executor in {"client", "server"}: + self._runtime_tool_executor[fn_name] = executor schemas.append( { "type": "function", @@ -700,6 +714,49 @@ class DuplexPipeline: ) return schemas + def _resolved_tool_executor_map(self) -> Dict[str, str]: + result: Dict[str, str] = {} + for item in self._runtime_tools: + if not isinstance(item, dict): + continue + fn = item.get("function") + if isinstance(fn, dict) and fn.get("name"): + name = str(fn.get("name")) + else: + name = str(item.get("name") or "").strip() + if not name: + continue + executor = str(item.get("executor") or item.get("run_on") or "").strip().lower() + if executor in {"client", "server"}: + result[name] = executor + return result + + def _tool_name(self, tool_call: Dict[str, Any]) -> str: + fn = tool_call.get("function") + if isinstance(fn, dict): + return str(fn.get("name") or "").strip() + return "" + + def _tool_executor(self, tool_call: Dict[str, Any]) -> str: + name = self._tool_name(tool_call) + if name and name in self._runtime_tool_executor: + return self._runtime_tool_executor[name] + # Default to server execution unless explicitly marked as client. + return "server" + + async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None: + await self._send_event( + { + **ev( + "assistant.tool_result", + trackId=self.session_id, + source=source, + result=result, + ) + }, + priority=22, + ) + async def handle_tool_call_results(self, results: List[Dict[str, Any]]) -> None: """Handle client tool execution results.""" if not isinstance(results, list): @@ -807,13 +864,16 @@ class DuplexPipeline: if not tool_call: continue allow_text_output = False - tool_calls.append(tool_call) + executor = self._tool_executor(tool_call) + enriched_tool_call = dict(tool_call) + enriched_tool_call["executor"] = executor + tool_calls.append(enriched_tool_call) await self._send_event( { **ev( "assistant.tool_call", trackId=self.session_id, - tool_call=tool_call, + tool_call=enriched_tool_call, ) }, priority=22, @@ -917,7 +977,16 @@ class DuplexPipeline: call_id = str(call.get("id") or "").strip() if not call_id: continue - tool_results.append(await self._wait_for_single_tool_result(call_id)) + executor = str(call.get("executor") or "server").strip().lower() + if executor == "client": + result = await self._wait_for_single_tool_result(call_id) + await self._emit_tool_result(result, source="client") + tool_results.append(result) + continue + + result = await execute_server_tool(call) + await self._emit_tool_result(result, source="server") + tool_results.append(result) messages = [ *messages, @@ -928,7 +997,7 @@ class DuplexPipeline: LLMMessage( role="system", content=( - "Tool execution results were returned by the client. " + "Tool execution results are available. " "Continue answering the user naturally using these results. " "Do not request the same tool again in this turn.\n" f"tool_calls={json.dumps(tool_calls, ensure_ascii=False)}\n" diff --git a/engine/core/tool_executor.py b/engine/core/tool_executor.py new file mode 100644 index 0000000..e80c73a --- /dev/null +++ b/engine/core/tool_executor.py @@ -0,0 +1,118 @@ +"""Server-side tool execution helpers.""" + +import ast +import operator +from typing import Any, Dict + + +_BIN_OPS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.Mod: operator.mod, +} + +_UNARY_OPS = { + ast.UAdd: operator.pos, + ast.USub: operator.neg, +} + + +def _safe_eval_expr(expression: str) -> float: + tree = ast.parse(expression, mode="eval") + + def _eval(node: ast.AST) -> float: + if isinstance(node, ast.Expression): + return _eval(node.body) + if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)): + return float(node.value) + if isinstance(node, ast.BinOp): + op = _BIN_OPS.get(type(node.op)) + if not op: + raise ValueError("unsupported operator") + return float(op(_eval(node.left), _eval(node.right))) + if isinstance(node, ast.UnaryOp): + op = _UNARY_OPS.get(type(node.op)) + if not op: + raise ValueError("unsupported unary operator") + return float(op(_eval(node.operand))) + raise ValueError("unsupported expression") + + return _eval(tree) + + +def _extract_tool_name(tool_call: Dict[str, Any]) -> str: + function_payload = tool_call.get("function") + if isinstance(function_payload, dict): + return str(function_payload.get("name") or "").strip() + return "" + + +def _extract_tool_args(tool_call: Dict[str, Any]) -> Dict[str, Any]: + function_payload = tool_call.get("function") + if not isinstance(function_payload, dict): + return {} + raw = function_payload.get("arguments") + if isinstance(raw, dict): + return raw + if not isinstance(raw, str): + return {} + text = raw.strip() + if not text: + return {} + try: + import json + + parsed = json.loads(text) + return parsed if isinstance(parsed, dict) else {} + except Exception: + return {} + + +async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]: + """Execute a server-side tool and return normalized result payload.""" + call_id = str(tool_call.get("id") or "").strip() + tool_name = _extract_tool_name(tool_call) + args = _extract_tool_args(tool_call) + + if tool_name == "calculator": + expression = str(args.get("expression") or "").strip() + if not expression: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"error": "missing expression"}, + "status": {"code": 400, "message": "bad_request"}, + } + if len(expression) > 200: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"expression": expression, "error": "expression too long"}, + "status": {"code": 422, "message": "invalid_expression"}, + } + try: + value = _safe_eval_expr(expression) + if value.is_integer(): + value = int(value) + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"expression": expression, "result": value}, + "status": {"code": 200, "message": "ok"}, + } + except Exception as exc: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"expression": expression, "error": str(exc)}, + "status": {"code": 422, "message": "invalid_expression"}, + } + + return { + "tool_call_id": call_id, + "name": tool_name or "unknown_tool", + "output": {"message": "server tool not implemented"}, + "status": {"code": 501, "message": "not_implemented"}, + } diff --git a/engine/docs/ws_v1_schema.md b/engine/docs/ws_v1_schema.md index c84ce5a..0d21432 100644 --- a/engine/docs/ws_v1_schema.md +++ b/engine/docs/ws_v1_schema.md @@ -161,7 +161,9 @@ Common events: - `assistant.response.final` - Fields: `trackId`, `text` - `assistant.tool_call` - - Fields: `trackId`, `tool_call` + - Fields: `trackId`, `tool_call` (`tool_call.executor` is `client` or `server`) +- `assistant.tool_result` + - Fields: `trackId`, `source`, `result` - `output.audio.start` - Fields: `trackId` - `output.audio.end` diff --git a/engine/tests/test_tool_call_flow.py b/engine/tests/test_tool_call_flow.py index bdf2889..6f70828 100644 --- a/engine/tests/test_tool_call_flow.py +++ b/engine/tests/test_tool_call_flow.py @@ -136,6 +136,7 @@ async def test_turn_with_tool_call_then_results(monkeypatch): type="tool_call", tool_call={ "id": "call_ok", + "executor": "client", "type": "function", "function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"}, }, @@ -183,6 +184,7 @@ async def test_turn_with_tool_call_timeout(monkeypatch): type="tool_call", tool_call={ "id": "call_timeout", + "executor": "client", "type": "function", "function": {"name": "search", "arguments": "{\"query\":\"x\"}"}, }, @@ -217,3 +219,36 @@ async def test_duplicate_tool_results_are_ignored(monkeypatch): result = await pipeline._wait_for_single_tool_result("call_dup") assert result.get("output", {}).get("value") == 1 + + +@pytest.mark.asyncio +async def test_server_calculator_emits_tool_result(monkeypatch): + pipeline, events = _build_pipeline( + monkeypatch, + [ + [ + LLMStreamEvent( + type="tool_call", + tool_call={ + "id": "call_calc", + "executor": "server", + "type": "function", + "function": {"name": "calculator", "arguments": "{\"expression\":\"1+2\"}"}, + }, + ), + LLMStreamEvent(type="done"), + ], + [ + LLMStreamEvent(type="text_delta", text="done."), + LLMStreamEvent(type="done"), + ], + ], + ) + + await pipeline._handle_turn("calc") + + tool_results = [e for e in events if e.get("type") == "assistant.tool_result"] + assert tool_results + payload = tool_results[-1].get("result", {}) + assert payload.get("status", {}).get("code") == 200 + assert payload.get("output", {}).get("result") == 3 diff --git a/web/pages/Assistants.tsx b/web/pages/Assistants.tsx index 5ea3311..87646a7 100644 --- a/web/pages/Assistants.tsx +++ b/web/pages/Assistants.tsx @@ -1000,37 +1000,6 @@ const TOOL_PARAMETER_HINTS: Record = { const getDefaultToolParameters = (toolId: string) => TOOL_PARAMETER_HINTS[toolId] || { type: 'object', properties: {} }; -const parseToolArgs = (raw: unknown): Record => { - if (raw && typeof raw === 'object') return raw as Record; - if (typeof raw !== 'string') return {}; - const text = raw.trim(); - if (!text) return {}; - try { - const parsed = JSON.parse(text); - return parsed && typeof parsed === 'object' ? parsed : {}; - } catch { - return {}; - } -}; - -const evaluateCalculatorExpression = (expression: string): { ok: true; result: number } | { ok: false; error: string } => { - const expr = String(expression || '').trim(); - if (!expr) return { ok: false, error: 'empty expression' }; - if (expr.length > 200) return { ok: false, error: 'expression too long' }; - if (!/^[0-9+\-*/().%\s]+$/.test(expr)) return { ok: false, error: 'invalid characters' }; - if (expr.includes('**') || expr.includes('//')) return { ok: false, error: 'unsupported operator' }; - - try { - const value = Function(`"use strict"; return (${expr});`)(); - if (typeof value !== 'number' || !Number.isFinite(value)) { - return { ok: false, error: 'expression is not finite number' }; - } - return { ok: true, result: value }; - } catch { - return { ok: false, error: 'invalid expression' }; - } -}; - // Stable transcription log so the scroll container is not recreated on every render (avoids scroll jumping) const TranscriptionLog: React.FC<{ scrollRef: React.RefObject; @@ -1165,12 +1134,18 @@ export const DebugDrawer: React.FC<{ const byId = new Map(tools.map((t) => [t.id, t])); return ids.map((id) => { const item = byId.get(id); + const toolId = item?.id || id; + const isClientTool = + toolId.startsWith('client_') || + toolId.startsWith('browser_') || + ['camera', 'microphone', 'page_control', 'local_file'].includes(toolId); return { type: 'function', + executor: isClientTool ? 'client' : 'server', function: { - name: item?.id || id, + name: toolId, description: item?.description || item?.name || id, - parameters: getDefaultToolParameters(item?.id || id), + parameters: getDefaultToolParameters(toolId), }, }; }); @@ -1760,64 +1735,25 @@ export const DebugDrawer: React.FC<{ const toolCall = payload?.tool_call || {}; const toolCallId = String(toolCall?.id || '').trim(); const toolName = String(toolCall?.function?.name || toolCall?.name || 'unknown_tool'); + const executor = String(toolCall?.executor || 'server').toLowerCase(); const rawArgs = String(toolCall?.function?.arguments || ''); const argText = rawArgs.length > 160 ? `${rawArgs.slice(0, 160)}...` : rawArgs; setMessages((prev) => [ ...prev, { role: 'tool', - text: `call ${toolName}${argText ? ` args=${argText}` : ''}`, + text: `call ${toolName} executor=${executor}${argText ? ` args=${argText}` : ''}`, }, ]); - if (toolCallId && ws.readyState === WebSocket.OPEN) { - const argsObj = parseToolArgs(toolCall?.function?.arguments); - const normalizedToolName = toolName.trim().toLowerCase(); - let resultPayload: any = { + if (executor === 'client' && toolCallId && ws.readyState === WebSocket.OPEN) { + const resultPayload: any = { tool_call_id: toolCallId, name: toolName, output: { - message: 'Tool execution is not implemented in debug web client', + message: 'Client tool execution is not implemented in debug web client', }, status: { code: 501, message: 'not_implemented' }, }; - - const isCalculatorTool = - normalizedToolName === 'calculator' || - normalizedToolName.endsWith('.calculator') || - normalizedToolName.includes('calculator'); - - if (isCalculatorTool) { - const expression = String(argsObj.expression || argsObj.input || '').trim(); - if (!expression) { - resultPayload = { - tool_call_id: toolCallId, - name: toolName, - output: { error: 'missing expression' }, - status: { code: 400, message: 'bad_request' }, - }; - } else { - const calc = evaluateCalculatorExpression(expression); - if (calc.ok) { - resultPayload = { - tool_call_id: toolCallId, - name: toolName, - output: { - expression, - result: calc.result, - }, - status: { code: 200, message: 'ok' }, - }; - } else { - resultPayload = { - tool_call_id: toolCallId, - name: toolName, - output: { expression, error: calc.error }, - status: { code: 422, message: 'invalid_expression' }, - }; - } - } - } - ws.send( JSON.stringify({ type: 'tool_call.results', @@ -1841,6 +1777,21 @@ export const DebugDrawer: React.FC<{ return; } + if (type === 'assistant.tool_result') { + const result = payload?.result || {}; + const toolName = String(result?.name || 'unknown_tool'); + const statusCode = Number(result?.status?.code || 500); + const statusMessage = String(result?.status?.message || 'error'); + const source = String(payload?.source || 'server'); + const output = result?.output; + const resultText = + statusCode === 200 + ? `result ${toolName} source=${source} ${JSON.stringify(output)}` + : `result ${toolName} source=${source} status=${statusCode} ${statusMessage}`; + setMessages((prev) => [...prev, { role: 'tool', text: resultText }]); + return; + } + if (type === 'session.started') { wsReadyRef.current = true; setWsStatus('ready');