Fix fastgpt client tool 3 rounds bugs

This commit is contained in:
Xin Wang
2026-03-11 11:33:27 +08:00
parent f3612a710d
commit 9b9fbf432f
4 changed files with 201 additions and 4 deletions

View File

@@ -30,6 +30,18 @@ from runtime.events import get_event_bus, reset_event_bus
_HEARTBEAT_CHECK_INTERVAL_SEC = 5
def _inactivity_deadline(
*,
last_received_at: float,
inactivity_timeout_sec: int,
pending_client_tool_deadline: Optional[float] = None,
) -> float:
deadline = float(last_received_at) + float(inactivity_timeout_sec)
if pending_client_tool_deadline is not None:
deadline = max(deadline, float(pending_client_tool_deadline))
return deadline
async def heartbeat_and_timeout_task(
transport: BaseTransport,
session: Session,
@@ -48,8 +60,22 @@ async def heartbeat_and_timeout_task(
if transport.is_closed:
break
now = time.monotonic()
if now - last_received_at[0] > inactivity_timeout_sec:
logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing")
pending_client_tool_deadline = session.pipeline.pending_client_tool_deadline()
idle_deadline = _inactivity_deadline(
last_received_at=last_received_at[0],
inactivity_timeout_sec=inactivity_timeout_sec,
pending_client_tool_deadline=pending_client_tool_deadline,
)
if now > idle_deadline:
if pending_client_tool_deadline is not None and pending_client_tool_deadline >= (
last_received_at[0] + inactivity_timeout_sec
):
logger.info(
"Session {}: no message before pending client tool deadline, closing",
session_id,
)
else:
logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing")
await session.cleanup()
break
if now - last_heartbeat_at[0] >= heartbeat_interval_sec:

View File

@@ -73,6 +73,8 @@ class DuplexPipeline:
_MIN_SPLIT_SPOKEN_CHARS = 6
_TOOL_WAIT_TIMEOUT_SECONDS = 60.0
_SERVER_TOOL_TIMEOUT_SECONDS = 15.0
_MAX_LLM_ROUNDS = 3
_MAX_PROVIDER_MANAGED_ROUNDS = 24
TRACK_AUDIO_IN = "audio_in"
TRACK_AUDIO_OUT = "audio_out"
TRACK_CONTROL = "control"
@@ -408,6 +410,7 @@ class DuplexPipeline:
self._runtime_tool_display_names: Dict[str, str] = {}
self._runtime_tool_wait_for_response: Dict[str, bool] = {}
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
self._pending_tool_deadlines: Dict[str, float] = {}
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
self._completed_tool_call_ids: set[str] = set()
self._pending_client_tool_call_ids: set[str] = set()
@@ -2236,6 +2239,7 @@ class DuplexPipeline:
future = loop.create_future()
self._pending_tool_waiters[call_id] = future
timeout = timeout_seconds if isinstance(timeout_seconds, (int, float)) and timeout_seconds > 0 else self._TOOL_WAIT_TIMEOUT_SECONDS
self._pending_tool_deadlines[call_id] = time.monotonic() + timeout
try:
return await asyncio.wait_for(future, timeout=timeout)
except asyncio.TimeoutError:
@@ -2247,8 +2251,14 @@ class DuplexPipeline:
}
finally:
self._pending_tool_waiters.pop(call_id, None)
self._pending_tool_deadlines.pop(call_id, None)
self._pending_client_tool_call_ids.discard(call_id)
def pending_client_tool_deadline(self) -> Optional[float]:
if not self._pending_tool_deadlines:
return None
return max(self._pending_tool_deadlines.values())
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
if isinstance(item, LLMStreamEvent):
return item
@@ -2289,7 +2299,8 @@ class DuplexPipeline:
messages = self.conversation.get_messages()
if system_context and system_context.strip():
messages = [*messages, LLMMessage(role="system", content=system_context.strip())]
max_rounds = 3
llm_rounds = 0
provider_rounds_remaining = self._MAX_PROVIDER_MANAGED_ROUNDS
await self.conversation.start_assistant_turn()
self._is_bot_speaking = True
@@ -2300,10 +2311,27 @@ class DuplexPipeline:
self._pending_llm_delta = ""
self._last_llm_delta_emit_ms = 0.0
pending_provider_stream = None
for _ in range(max_rounds):
while True:
if self._interrupt_event.is_set():
break
if pending_provider_stream is not None:
if provider_rounds_remaining <= 0:
logger.warning(
"Provider-managed tool chain exceeded {} rounds; ending turn early",
self._MAX_PROVIDER_MANAGED_ROUNDS,
)
break
provider_rounds_remaining -= 1
else:
if llm_rounds >= self._MAX_LLM_ROUNDS:
logger.warning(
"LLM tool planning exceeded {} rounds; ending turn early",
self._MAX_LLM_ROUNDS,
)
break
llm_rounds += 1
sentence_buffer = ""
pending_punctuation = ""
round_response = ""

View File

@@ -0,0 +1,13 @@
from app.main import _inactivity_deadline
def test_inactivity_deadline_uses_default_timeout_without_pending_tool():
assert _inactivity_deadline(last_received_at=100.0, inactivity_timeout_sec=60) == 160.0
def test_inactivity_deadline_extends_while_waiting_for_client_tool():
assert _inactivity_deadline(
last_received_at=100.0,
inactivity_timeout_sec=60,
pending_client_tool_deadline=340.0,
) == 340.0

View File

@@ -1,5 +1,6 @@
import asyncio
import json
import time
from typing import Any, Dict, List
import pytest
@@ -821,6 +822,66 @@ class _FakeResumableLLM:
yield LLMStreamEvent(type="done")
class _FakeChainedResumableLLM:
def __init__(self, call_ids: List[str], *, timeout_ms: int = 300000):
self.call_ids = call_ids
self.timeout_ms = timeout_ms
self.generate_stream_calls = 0
self.resumed_results: List[Dict[str, Any]] = []
def _tool_call_event(self, call_id: str) -> LLMStreamEvent:
return LLMStreamEvent(
type="tool_call",
tool_call={
"id": call_id,
"executor": "client",
"wait_for_response": True,
"timeout_ms": self.timeout_ms,
"display_name": f"Collect {call_id}",
"type": "function",
"function": {
"name": "fastgpt.interactive",
"arguments": json.dumps(
{
"provider": "fastgpt",
"version": "fastgpt_interactive_v1",
"interaction": {
"type": "userInput",
"title": "",
"description": f"Prompt for {call_id}",
"prompt": f"Prompt for {call_id}",
"form": [{"name": "result", "label": "result", "input_type": "input"}],
"options": [],
},
"context": {"chat_id": "fastgpt_chat_chain"},
},
ensure_ascii=False,
),
},
},
)
async def generate(self, _messages, temperature=0.7, max_tokens=None):
return ""
async def generate_stream(self, _messages, temperature=0.7, max_tokens=None):
self.generate_stream_calls += 1
yield self._tool_call_event(self.call_ids[0])
yield LLMStreamEvent(type="done")
def handles_client_tool(self, tool_name: str) -> bool:
return tool_name == "fastgpt.interactive"
async def resume_after_client_tool_result(self, tool_call_id: str, result: Dict[str, Any]):
self.resumed_results.append({"tool_call_id": tool_call_id, "result": dict(result)})
next_index = len(self.resumed_results)
if next_index < len(self.call_ids):
yield self._tool_call_event(self.call_ids[next_index])
else:
yield LLMStreamEvent(type="text_delta", text="completed after third interactive input.")
yield LLMStreamEvent(type="done")
def _build_pipeline_with_custom_llm(monkeypatch, llm_service) -> tuple[DuplexPipeline, List[Dict[str, Any]]]:
monkeypatch.setattr("runtime.pipeline.duplex.SileroVAD", _DummySileroVAD)
monkeypatch.setattr("runtime.pipeline.duplex.VADProcessor", _DummyVADProcessor)
@@ -903,3 +964,72 @@ async def test_fastgpt_provider_managed_tool_timeout_stops_without_generic_tool_
assert not finals
assert llm.generate_stream_calls == 1
assert llm.resumed_results == []
@pytest.mark.asyncio
async def test_fastgpt_provider_managed_tool_chain_can_continue_after_third_result(monkeypatch):
llm = _FakeChainedResumableLLM(["call_fastgpt_1", "call_fastgpt_2", "call_fastgpt_3"])
pipeline, events = _build_pipeline_with_custom_llm(monkeypatch, llm)
pipeline.apply_runtime_overrides({"output": {"mode": "text"}})
task = asyncio.create_task(pipeline._handle_turn("start chained fastgpt"))
expected_call_ids = ["call_fastgpt_1", "call_fastgpt_2", "call_fastgpt_3"]
for idx, call_id in enumerate(expected_call_ids, start=1):
for _ in range(200):
seen_call_ids = [event.get("tool_call_id") for event in events if event.get("type") == "assistant.tool_call"]
if call_id in seen_call_ids:
break
await asyncio.sleep(0.005)
await pipeline.handle_tool_call_results(
[
{
"tool_call_id": call_id,
"name": "fastgpt.interactive",
"output": {
"action": "submit",
"result": {"type": "userInput", "fields": {"result": f"value-{idx}"}},
},
"status": {"code": 200, "message": "ok"},
}
]
)
await task
finals = [event for event in events if event.get("type") == "assistant.response.final"]
assert finals
assert "completed after third interactive input" in finals[-1].get("text", "")
assert llm.generate_stream_calls == 1
assert len(llm.resumed_results) == 3
@pytest.mark.asyncio
async def test_pending_client_tool_deadline_tracks_waiting_result(monkeypatch):
pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])
waiter = asyncio.create_task(pipeline._wait_for_single_tool_result("call_deadline", timeout_seconds=30))
for _ in range(50):
deadline = pipeline.pending_client_tool_deadline()
if deadline is not None:
break
await asyncio.sleep(0.001)
deadline = pipeline.pending_client_tool_deadline()
assert deadline is not None
assert deadline > time.monotonic() + 25
await pipeline.handle_tool_call_results(
[
{
"tool_call_id": "call_deadline",
"name": "fastgpt.interactive",
"output": {"action": "submit", "result": {"type": "userInput", "fields": {"name": "Alice"}}},
"status": {"code": 200, "message": "ok"},
}
]
)
await waiter
assert pipeline.pending_client_tool_deadline() is None