Add server tool timeout protection
This commit is contained in:
@@ -58,6 +58,7 @@ class DuplexPipeline:
|
|||||||
_SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
|
_SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
|
||||||
_MIN_SPLIT_SPOKEN_CHARS = 6
|
_MIN_SPLIT_SPOKEN_CHARS = 6
|
||||||
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0
|
_TOOL_WAIT_TIMEOUT_SECONDS = 15.0
|
||||||
|
_SERVER_TOOL_TIMEOUT_SECONDS = 15.0
|
||||||
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
|
_DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
|
||||||
"search": {
|
"search": {
|
||||||
"name": "search",
|
"name": "search",
|
||||||
@@ -984,7 +985,18 @@ class DuplexPipeline:
|
|||||||
tool_results.append(result)
|
tool_results.append(result)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
result = await execute_server_tool(call)
|
try:
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
execute_server_tool(call),
|
||||||
|
timeout=self._SERVER_TOOL_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
result = {
|
||||||
|
"tool_call_id": call_id,
|
||||||
|
"name": self._tool_name(call) or "unknown_tool",
|
||||||
|
"output": {"message": "server tool timeout"},
|
||||||
|
"status": {"code": 504, "message": "server_tool_timeout"},
|
||||||
|
}
|
||||||
await self._emit_tool_result(result, source="server")
|
await self._emit_tool_result(result, source="server")
|
||||||
tool_results.append(result)
|
tool_results.append(result)
|
||||||
|
|
||||||
|
|||||||
@@ -252,3 +252,50 @@ async def test_server_calculator_emits_tool_result(monkeypatch):
|
|||||||
payload = tool_results[-1].get("result", {})
|
payload = tool_results[-1].get("result", {})
|
||||||
assert payload.get("status", {}).get("code") == 200
|
assert payload.get("status", {}).get("code") == 200
|
||||||
assert payload.get("output", {}).get("result") == 3
|
assert payload.get("output", {}).get("result") == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_server_tool_timeout_emits_504_and_continues(monkeypatch):
|
||||||
|
async def _slow_execute(_call):
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
return {
|
||||||
|
"tool_call_id": "call_slow",
|
||||||
|
"name": "weather",
|
||||||
|
"output": {"ok": True},
|
||||||
|
"status": {"code": 200, "message": "ok"},
|
||||||
|
}
|
||||||
|
|
||||||
|
monkeypatch.setattr("core.duplex_pipeline.execute_server_tool", _slow_execute)
|
||||||
|
|
||||||
|
pipeline, events = _build_pipeline(
|
||||||
|
monkeypatch,
|
||||||
|
[
|
||||||
|
[
|
||||||
|
LLMStreamEvent(
|
||||||
|
type="tool_call",
|
||||||
|
tool_call={
|
||||||
|
"id": "call_slow",
|
||||||
|
"executor": "server",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
LLMStreamEvent(type="done"),
|
||||||
|
],
|
||||||
|
[
|
||||||
|
LLMStreamEvent(type="text_delta", text="timeout fallback."),
|
||||||
|
LLMStreamEvent(type="done"),
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
pipeline._SERVER_TOOL_TIMEOUT_SECONDS = 0.01
|
||||||
|
|
||||||
|
await pipeline._handle_turn("weather?")
|
||||||
|
|
||||||
|
tool_results = [e for e in events if e.get("type") == "assistant.tool_result"]
|
||||||
|
assert tool_results
|
||||||
|
payload = tool_results[-1].get("result", {})
|
||||||
|
assert payload.get("status", {}).get("code") == 504
|
||||||
|
finals = [e for e in events if e.get("type") == "assistant.response.final"]
|
||||||
|
assert finals
|
||||||
|
assert "timeout fallback" in finals[-1].get("text", "")
|
||||||
|
|||||||
Reference in New Issue
Block a user