This commit is contained in:
Xin Wang
2026-02-27 17:38:48 +08:00
2 changed files with 133 additions and 29 deletions

View File

@@ -1673,6 +1673,13 @@ class DuplexPipeline:
def _tool_wait_for_response(self, tool_name: str) -> bool:
return bool(self._runtime_tool_wait_for_response.get(tool_name, False))
def _tool_wait_timeout_seconds(self, tool_name: str) -> float:
if tool_name == "text_msg_prompt" and self._tool_wait_for_response(tool_name):
# Keep engine wait slightly longer than UI auto-close (120s)
# to avoid race where engine times out before client emits timeout result.
return 125.0
return self._TOOL_WAIT_TIMEOUT_SECONDS
def _tool_executor(self, tool_call: Dict[str, Any]) -> str:
name = self._tool_name(tool_call)
if name and name in self._runtime_tool_executor:
@@ -1792,10 +1799,17 @@ class DuplexPipeline:
self._early_tool_results[call_id] = item
self._completed_tool_call_ids.add(call_id)
async def _wait_for_single_tool_result(self, call_id: str) -> Dict[str, Any]:
async def _wait_for_single_tool_result(
self,
call_id: str,
*,
tool_name: str = "unknown_tool",
timeout_seconds: Optional[float] = None,
) -> Dict[str, Any]:
if call_id in self._completed_tool_call_ids and call_id not in self._early_tool_results:
return {
"tool_call_id": call_id,
"name": tool_name,
"status": {"code": 208, "message": "tool_call result already handled"},
"output": "",
}
@@ -1806,12 +1820,14 @@ class DuplexPipeline:
loop = asyncio.get_running_loop()
future = loop.create_future()
self._pending_tool_waiters[call_id] = future
wait_timeout = float(timeout_seconds if isinstance(timeout_seconds, (int, float)) and timeout_seconds > 0 else self._TOOL_WAIT_TIMEOUT_SECONDS)
try:
return await asyncio.wait_for(future, timeout=self._TOOL_WAIT_TIMEOUT_SECONDS)
return await asyncio.wait_for(future, timeout=wait_timeout)
except asyncio.TimeoutError:
self._completed_tool_call_ids.add(call_id)
return {
"tool_call_id": call_id,
"name": tool_name,
"status": {"code": 504, "message": "tool_call timeout"},
"output": "",
}
@@ -1900,7 +1916,9 @@ class DuplexPipeline:
tool_id = self._tool_id_for_name(tool_name)
tool_display_name = self._tool_display_name(tool_name) or tool_name
wait_for_response = self._tool_wait_for_response(tool_name)
wait_timeout_seconds = self._tool_wait_timeout_seconds(tool_name)
enriched_tool_call["wait_for_response"] = wait_for_response
enriched_tool_call["wait_timeout_ms"] = int(wait_timeout_seconds * 1000)
call_id = str(enriched_tool_call.get("id") or "").strip()
fn_payload = (
dict(enriched_tool_call.get("function"))
@@ -1935,9 +1953,10 @@ class DuplexPipeline:
tool_id=tool_id,
tool_display_name=tool_display_name,
wait_for_response=wait_for_response,
wait_timeout_ms=int(wait_timeout_seconds * 1000),
arguments=tool_arguments,
executor=executor,
timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
timeout_ms=int(wait_timeout_seconds * 1000),
tool_call=enriched_tool_call,
)
},
@@ -2075,7 +2094,11 @@ class DuplexPipeline:
tool_id = self._tool_id_for_name(tool_name)
logger.info(f"[Tool] execute start name={tool_name} call_id={call_id} executor={executor}")
if executor == "client":
result = await self._wait_for_single_tool_result(call_id)
result = await self._wait_for_single_tool_result(
call_id,
tool_name=tool_name,
timeout_seconds=self._tool_wait_timeout_seconds(tool_name),
)
await self._emit_tool_result(result, source="client")
tool_results.append(result)
continue