Add server tool timeout protection

This commit is contained in:
Xin Wang
2026-02-10 19:17:45 +08:00
parent 6cac24918d
commit 2d7fc2b700
2 changed files with 60 additions and 1 deletions

View File

@@ -58,6 +58,7 @@ class DuplexPipeline:
_SENTENCE_CLOSERS = frozenset({'"', "'", "", "", ")", "]", "}", "", "", "", "", ""}) _SENTENCE_CLOSERS = frozenset({'"', "'", "", "", ")", "]", "}", "", "", "", "", ""})
_MIN_SPLIT_SPOKEN_CHARS = 6 _MIN_SPLIT_SPOKEN_CHARS = 6
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0 _TOOL_WAIT_TIMEOUT_SECONDS = 15.0
_SERVER_TOOL_TIMEOUT_SECONDS = 15.0
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = { _DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
"search": { "search": {
"name": "search", "name": "search",
@@ -984,7 +985,18 @@ class DuplexPipeline:
tool_results.append(result) tool_results.append(result)
continue continue
result = await execute_server_tool(call) try:
result = await asyncio.wait_for(
execute_server_tool(call),
timeout=self._SERVER_TOOL_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
result = {
"tool_call_id": call_id,
"name": self._tool_name(call) or "unknown_tool",
"output": {"message": "server tool timeout"},
"status": {"code": 504, "message": "server_tool_timeout"},
}
await self._emit_tool_result(result, source="server") await self._emit_tool_result(result, source="server")
tool_results.append(result) tool_results.append(result)

View File

@@ -252,3 +252,50 @@ async def test_server_calculator_emits_tool_result(monkeypatch):
payload = tool_results[-1].get("result", {}) payload = tool_results[-1].get("result", {})
assert payload.get("status", {}).get("code") == 200 assert payload.get("status", {}).get("code") == 200
assert payload.get("output", {}).get("result") == 3 assert payload.get("output", {}).get("result") == 3
@pytest.mark.asyncio
async def test_server_tool_timeout_emits_504_and_continues(monkeypatch):
async def _slow_execute(_call):
await asyncio.sleep(0.05)
return {
"tool_call_id": "call_slow",
"name": "weather",
"output": {"ok": True},
"status": {"code": 200, "message": "ok"},
}
monkeypatch.setattr("core.duplex_pipeline.execute_server_tool", _slow_execute)
pipeline, events = _build_pipeline(
monkeypatch,
[
[
LLMStreamEvent(
type="tool_call",
tool_call={
"id": "call_slow",
"executor": "server",
"type": "function",
"function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"},
},
),
LLMStreamEvent(type="done"),
],
[
LLMStreamEvent(type="text_delta", text="timeout fallback."),
LLMStreamEvent(type="done"),
],
],
)
pipeline._SERVER_TOOL_TIMEOUT_SECONDS = 0.01
await pipeline._handle_turn("weather?")
tool_results = [e for e in events if e.get("type") == "assistant.tool_result"]
assert tool_results
payload = tool_results[-1].get("result", {})
assert payload.get("status", {}).get("code") == 504
finals = [e for e in events if e.get("type") == "assistant.response.final"]
assert finals
assert "timeout fallback" in finals[-1].get("text", "")