From 9b9fbf432f5e6afe24e5665b76dde7b7d7db1263 Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Wed, 11 Mar 2026 11:33:27 +0800 Subject: [PATCH] Fix fastgpt client tool 3 rounds bugs --- engine/app/main.py | 30 ++++++- engine/runtime/pipeline/duplex.py | 32 ++++++- engine/tests/test_session_timeout.py | 13 +++ engine/tests/test_tool_call_flow.py | 130 +++++++++++++++++++++++++++ 4 files changed, 201 insertions(+), 4 deletions(-) create mode 100644 engine/tests/test_session_timeout.py diff --git a/engine/app/main.py b/engine/app/main.py index 09ffa1d..d93875e 100644 --- a/engine/app/main.py +++ b/engine/app/main.py @@ -30,6 +30,18 @@ from runtime.events import get_event_bus, reset_event_bus _HEARTBEAT_CHECK_INTERVAL_SEC = 5 +def _inactivity_deadline( + *, + last_received_at: float, + inactivity_timeout_sec: int, + pending_client_tool_deadline: Optional[float] = None, +) -> float: + deadline = float(last_received_at) + float(inactivity_timeout_sec) + if pending_client_tool_deadline is not None: + deadline = max(deadline, float(pending_client_tool_deadline)) + return deadline + + async def heartbeat_and_timeout_task( transport: BaseTransport, session: Session, @@ -48,8 +60,22 @@ async def heartbeat_and_timeout_task( if transport.is_closed: break now = time.monotonic() - if now - last_received_at[0] > inactivity_timeout_sec: - logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing") + pending_client_tool_deadline = session.pipeline.pending_client_tool_deadline() + idle_deadline = _inactivity_deadline( + last_received_at=last_received_at[0], + inactivity_timeout_sec=inactivity_timeout_sec, + pending_client_tool_deadline=pending_client_tool_deadline, + ) + if now > idle_deadline: + if pending_client_tool_deadline is not None and pending_client_tool_deadline >= ( + last_received_at[0] + inactivity_timeout_sec + ): + logger.info( + "Session {}: no message before pending client tool deadline, closing", + session_id, + ) + else: + logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing") await session.cleanup() break if now - last_heartbeat_at[0] >= heartbeat_interval_sec: diff --git a/engine/runtime/pipeline/duplex.py b/engine/runtime/pipeline/duplex.py index 743c945..17c7dd8 100644 --- a/engine/runtime/pipeline/duplex.py +++ b/engine/runtime/pipeline/duplex.py @@ -73,6 +73,8 @@ class DuplexPipeline: _MIN_SPLIT_SPOKEN_CHARS = 6 _TOOL_WAIT_TIMEOUT_SECONDS = 60.0 _SERVER_TOOL_TIMEOUT_SECONDS = 15.0 + _MAX_LLM_ROUNDS = 3 + _MAX_PROVIDER_MANAGED_ROUNDS = 24 TRACK_AUDIO_IN = "audio_in" TRACK_AUDIO_OUT = "audio_out" TRACK_CONTROL = "control" @@ -408,6 +410,7 @@ class DuplexPipeline: self._runtime_tool_display_names: Dict[str, str] = {} self._runtime_tool_wait_for_response: Dict[str, bool] = {} self._pending_tool_waiters: Dict[str, asyncio.Future] = {} + self._pending_tool_deadlines: Dict[str, float] = {} self._early_tool_results: Dict[str, Dict[str, Any]] = {} self._completed_tool_call_ids: set[str] = set() self._pending_client_tool_call_ids: set[str] = set() @@ -2236,6 +2239,7 @@ class DuplexPipeline: future = loop.create_future() self._pending_tool_waiters[call_id] = future timeout = timeout_seconds if isinstance(timeout_seconds, (int, float)) and timeout_seconds > 0 else self._TOOL_WAIT_TIMEOUT_SECONDS + self._pending_tool_deadlines[call_id] = time.monotonic() + timeout try: return await asyncio.wait_for(future, timeout=timeout) except asyncio.TimeoutError: @@ -2247,8 +2251,14 @@ class DuplexPipeline: } finally: self._pending_tool_waiters.pop(call_id, None) + self._pending_tool_deadlines.pop(call_id, None) self._pending_client_tool_call_ids.discard(call_id) + def pending_client_tool_deadline(self) -> Optional[float]: + if not self._pending_tool_deadlines: + return None + return max(self._pending_tool_deadlines.values()) + def _normalize_stream_event(self, item: Any) -> LLMStreamEvent: if isinstance(item, LLMStreamEvent): return item @@ -2289,7 +2299,8 @@ class DuplexPipeline: messages = self.conversation.get_messages() if system_context and system_context.strip(): messages = [*messages, LLMMessage(role="system", content=system_context.strip())] - max_rounds = 3 + llm_rounds = 0 + provider_rounds_remaining = self._MAX_PROVIDER_MANAGED_ROUNDS await self.conversation.start_assistant_turn() self._is_bot_speaking = True @@ -2300,10 +2311,27 @@ class DuplexPipeline: self._pending_llm_delta = "" self._last_llm_delta_emit_ms = 0.0 pending_provider_stream = None - for _ in range(max_rounds): + while True: if self._interrupt_event.is_set(): break + if pending_provider_stream is not None: + if provider_rounds_remaining <= 0: + logger.warning( + "Provider-managed tool chain exceeded {} rounds; ending turn early", + self._MAX_PROVIDER_MANAGED_ROUNDS, + ) + break + provider_rounds_remaining -= 1 + else: + if llm_rounds >= self._MAX_LLM_ROUNDS: + logger.warning( + "LLM tool planning exceeded {} rounds; ending turn early", + self._MAX_LLM_ROUNDS, + ) + break + llm_rounds += 1 + sentence_buffer = "" pending_punctuation = "" round_response = "" diff --git a/engine/tests/test_session_timeout.py b/engine/tests/test_session_timeout.py new file mode 100644 index 0000000..54905d8 --- /dev/null +++ b/engine/tests/test_session_timeout.py @@ -0,0 +1,13 @@ +from app.main import _inactivity_deadline + + +def test_inactivity_deadline_uses_default_timeout_without_pending_tool(): + assert _inactivity_deadline(last_received_at=100.0, inactivity_timeout_sec=60) == 160.0 + + +def test_inactivity_deadline_extends_while_waiting_for_client_tool(): + assert _inactivity_deadline( + last_received_at=100.0, + inactivity_timeout_sec=60, + pending_client_tool_deadline=340.0, + ) == 340.0 diff --git a/engine/tests/test_tool_call_flow.py b/engine/tests/test_tool_call_flow.py index 3550d20..820cd8d 100644 --- a/engine/tests/test_tool_call_flow.py +++ b/engine/tests/test_tool_call_flow.py @@ -1,5 +1,6 @@ import asyncio import json +import time from typing import Any, Dict, List import pytest @@ -821,6 +822,66 @@ class _FakeResumableLLM: yield LLMStreamEvent(type="done") +class _FakeChainedResumableLLM: + def __init__(self, call_ids: List[str], *, timeout_ms: int = 300000): + self.call_ids = call_ids + self.timeout_ms = timeout_ms + self.generate_stream_calls = 0 + self.resumed_results: List[Dict[str, Any]] = [] + + def _tool_call_event(self, call_id: str) -> LLMStreamEvent: + return LLMStreamEvent( + type="tool_call", + tool_call={ + "id": call_id, + "executor": "client", + "wait_for_response": True, + "timeout_ms": self.timeout_ms, + "display_name": f"Collect {call_id}", + "type": "function", + "function": { + "name": "fastgpt.interactive", + "arguments": json.dumps( + { + "provider": "fastgpt", + "version": "fastgpt_interactive_v1", + "interaction": { + "type": "userInput", + "title": "", + "description": f"Prompt for {call_id}", + "prompt": f"Prompt for {call_id}", + "form": [{"name": "result", "label": "result", "input_type": "input"}], + "options": [], + }, + "context": {"chat_id": "fastgpt_chat_chain"}, + }, + ensure_ascii=False, + ), + }, + }, + ) + + async def generate(self, _messages, temperature=0.7, max_tokens=None): + return "" + + async def generate_stream(self, _messages, temperature=0.7, max_tokens=None): + self.generate_stream_calls += 1 + yield self._tool_call_event(self.call_ids[0]) + yield LLMStreamEvent(type="done") + + def handles_client_tool(self, tool_name: str) -> bool: + return tool_name == "fastgpt.interactive" + + async def resume_after_client_tool_result(self, tool_call_id: str, result: Dict[str, Any]): + self.resumed_results.append({"tool_call_id": tool_call_id, "result": dict(result)}) + next_index = len(self.resumed_results) + if next_index < len(self.call_ids): + yield self._tool_call_event(self.call_ids[next_index]) + else: + yield LLMStreamEvent(type="text_delta", text="completed after third interactive input.") + yield LLMStreamEvent(type="done") + + def _build_pipeline_with_custom_llm(monkeypatch, llm_service) -> tuple[DuplexPipeline, List[Dict[str, Any]]]: monkeypatch.setattr("runtime.pipeline.duplex.SileroVAD", _DummySileroVAD) monkeypatch.setattr("runtime.pipeline.duplex.VADProcessor", _DummyVADProcessor) @@ -903,3 +964,72 @@ async def test_fastgpt_provider_managed_tool_timeout_stops_without_generic_tool_ assert not finals assert llm.generate_stream_calls == 1 assert llm.resumed_results == [] + + +@pytest.mark.asyncio +async def test_fastgpt_provider_managed_tool_chain_can_continue_after_third_result(monkeypatch): + llm = _FakeChainedResumableLLM(["call_fastgpt_1", "call_fastgpt_2", "call_fastgpt_3"]) + pipeline, events = _build_pipeline_with_custom_llm(monkeypatch, llm) + pipeline.apply_runtime_overrides({"output": {"mode": "text"}}) + + task = asyncio.create_task(pipeline._handle_turn("start chained fastgpt")) + + expected_call_ids = ["call_fastgpt_1", "call_fastgpt_2", "call_fastgpt_3"] + for idx, call_id in enumerate(expected_call_ids, start=1): + for _ in range(200): + seen_call_ids = [event.get("tool_call_id") for event in events if event.get("type") == "assistant.tool_call"] + if call_id in seen_call_ids: + break + await asyncio.sleep(0.005) + + await pipeline.handle_tool_call_results( + [ + { + "tool_call_id": call_id, + "name": "fastgpt.interactive", + "output": { + "action": "submit", + "result": {"type": "userInput", "fields": {"result": f"value-{idx}"}}, + }, + "status": {"code": 200, "message": "ok"}, + } + ] + ) + + await task + + finals = [event for event in events if event.get("type") == "assistant.response.final"] + assert finals + assert "completed after third interactive input" in finals[-1].get("text", "") + assert llm.generate_stream_calls == 1 + assert len(llm.resumed_results) == 3 + + +@pytest.mark.asyncio +async def test_pending_client_tool_deadline_tracks_waiting_result(monkeypatch): + pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]]) + + waiter = asyncio.create_task(pipeline._wait_for_single_tool_result("call_deadline", timeout_seconds=30)) + for _ in range(50): + deadline = pipeline.pending_client_tool_deadline() + if deadline is not None: + break + await asyncio.sleep(0.001) + + deadline = pipeline.pending_client_tool_deadline() + assert deadline is not None + assert deadline > time.monotonic() + 25 + + await pipeline.handle_tool_call_results( + [ + { + "tool_call_id": "call_deadline", + "name": "fastgpt.interactive", + "output": {"action": "submit", "result": {"type": "userInput", "fields": {"name": "Alice"}}}, + "status": {"code": 200, "message": "ok"}, + } + ] + ) + await waiter + + assert pipeline.pending_client_tool_deadline() is None