Fix FastGPT client tool three-round bugs
This commit is contained in:
@@ -30,6 +30,18 @@ from runtime.events import get_event_bus, reset_event_bus
|
||||
_HEARTBEAT_CHECK_INTERVAL_SEC = 5
|
||||
|
||||
|
||||
def _inactivity_deadline(
|
||||
*,
|
||||
last_received_at: float,
|
||||
inactivity_timeout_sec: int,
|
||||
pending_client_tool_deadline: Optional[float] = None,
|
||||
) -> float:
|
||||
deadline = float(last_received_at) + float(inactivity_timeout_sec)
|
||||
if pending_client_tool_deadline is not None:
|
||||
deadline = max(deadline, float(pending_client_tool_deadline))
|
||||
return deadline
|
||||
|
||||
|
||||
async def heartbeat_and_timeout_task(
|
||||
transport: BaseTransport,
|
||||
session: Session,
|
||||
@@ -48,8 +60,22 @@ async def heartbeat_and_timeout_task(
|
||||
if transport.is_closed:
|
||||
break
|
||||
now = time.monotonic()
|
||||
if now - last_received_at[0] > inactivity_timeout_sec:
|
||||
logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing")
|
||||
pending_client_tool_deadline = session.pipeline.pending_client_tool_deadline()
|
||||
idle_deadline = _inactivity_deadline(
|
||||
last_received_at=last_received_at[0],
|
||||
inactivity_timeout_sec=inactivity_timeout_sec,
|
||||
pending_client_tool_deadline=pending_client_tool_deadline,
|
||||
)
|
||||
if now > idle_deadline:
|
||||
if pending_client_tool_deadline is not None and pending_client_tool_deadline >= (
|
||||
last_received_at[0] + inactivity_timeout_sec
|
||||
):
|
||||
logger.info(
|
||||
"Session {}: no message before pending client tool deadline, closing",
|
||||
session_id,
|
||||
)
|
||||
else:
|
||||
logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing")
|
||||
await session.cleanup()
|
||||
break
|
||||
if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
|
||||
|
||||
@@ -73,6 +73,8 @@ class DuplexPipeline:
|
||||
_MIN_SPLIT_SPOKEN_CHARS = 6
|
||||
_TOOL_WAIT_TIMEOUT_SECONDS = 60.0
|
||||
_SERVER_TOOL_TIMEOUT_SECONDS = 15.0
|
||||
_MAX_LLM_ROUNDS = 3
|
||||
_MAX_PROVIDER_MANAGED_ROUNDS = 24
|
||||
TRACK_AUDIO_IN = "audio_in"
|
||||
TRACK_AUDIO_OUT = "audio_out"
|
||||
TRACK_CONTROL = "control"
|
||||
@@ -408,6 +410,7 @@ class DuplexPipeline:
|
||||
self._runtime_tool_display_names: Dict[str, str] = {}
|
||||
self._runtime_tool_wait_for_response: Dict[str, bool] = {}
|
||||
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
|
||||
self._pending_tool_deadlines: Dict[str, float] = {}
|
||||
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
|
||||
self._completed_tool_call_ids: set[str] = set()
|
||||
self._pending_client_tool_call_ids: set[str] = set()
|
||||
@@ -2236,6 +2239,7 @@ class DuplexPipeline:
|
||||
future = loop.create_future()
|
||||
self._pending_tool_waiters[call_id] = future
|
||||
timeout = timeout_seconds if isinstance(timeout_seconds, (int, float)) and timeout_seconds > 0 else self._TOOL_WAIT_TIMEOUT_SECONDS
|
||||
self._pending_tool_deadlines[call_id] = time.monotonic() + timeout
|
||||
try:
|
||||
return await asyncio.wait_for(future, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
@@ -2247,8 +2251,14 @@ class DuplexPipeline:
|
||||
}
|
||||
finally:
|
||||
self._pending_tool_waiters.pop(call_id, None)
|
||||
self._pending_tool_deadlines.pop(call_id, None)
|
||||
self._pending_client_tool_call_ids.discard(call_id)
|
||||
|
||||
def pending_client_tool_deadline(self) -> Optional[float]:
    """Return the latest monotonic deadline among client tool calls that are
    still awaiting a result, or ``None`` when nothing is pending."""
    deadlines = self._pending_tool_deadlines
    return max(deadlines.values()) if deadlines else None
|
||||
|
||||
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
|
||||
if isinstance(item, LLMStreamEvent):
|
||||
return item
|
||||
@@ -2289,7 +2299,8 @@ class DuplexPipeline:
|
||||
messages = self.conversation.get_messages()
|
||||
if system_context and system_context.strip():
|
||||
messages = [*messages, LLMMessage(role="system", content=system_context.strip())]
|
||||
max_rounds = 3
|
||||
llm_rounds = 0
|
||||
provider_rounds_remaining = self._MAX_PROVIDER_MANAGED_ROUNDS
|
||||
|
||||
await self.conversation.start_assistant_turn()
|
||||
self._is_bot_speaking = True
|
||||
@@ -2300,10 +2311,27 @@ class DuplexPipeline:
|
||||
self._pending_llm_delta = ""
|
||||
self._last_llm_delta_emit_ms = 0.0
|
||||
pending_provider_stream = None
|
||||
for _ in range(max_rounds):
|
||||
while True:
|
||||
if self._interrupt_event.is_set():
|
||||
break
|
||||
|
||||
if pending_provider_stream is not None:
|
||||
if provider_rounds_remaining <= 0:
|
||||
logger.warning(
|
||||
"Provider-managed tool chain exceeded {} rounds; ending turn early",
|
||||
self._MAX_PROVIDER_MANAGED_ROUNDS,
|
||||
)
|
||||
break
|
||||
provider_rounds_remaining -= 1
|
||||
else:
|
||||
if llm_rounds >= self._MAX_LLM_ROUNDS:
|
||||
logger.warning(
|
||||
"LLM tool planning exceeded {} rounds; ending turn early",
|
||||
self._MAX_LLM_ROUNDS,
|
||||
)
|
||||
break
|
||||
llm_rounds += 1
|
||||
|
||||
sentence_buffer = ""
|
||||
pending_punctuation = ""
|
||||
round_response = ""
|
||||
|
||||
13
engine/tests/test_session_timeout.py
Normal file
13
engine/tests/test_session_timeout.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from app.main import _inactivity_deadline
|
||||
|
||||
|
||||
def test_inactivity_deadline_uses_default_timeout_without_pending_tool():
    """Without a pending client tool, the deadline is last message time plus timeout."""
    deadline = _inactivity_deadline(last_received_at=100.0, inactivity_timeout_sec=60)
    assert deadline == 160.0
|
||||
|
||||
|
||||
def test_inactivity_deadline_extends_while_waiting_for_client_tool():
    """A pending client tool deadline later than the idle deadline wins."""
    deadline = _inactivity_deadline(
        last_received_at=100.0,
        inactivity_timeout_sec=60,
        pending_client_tool_deadline=340.0,
    )
    assert deadline == 340.0
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
@@ -821,6 +822,66 @@ class _FakeResumableLLM:
|
||||
yield LLMStreamEvent(type="done")
|
||||
|
||||
|
||||
class _FakeChainedResumableLLM:
    """Fake provider-managed LLM that emits a chain of client tool calls.

    The initial stream yields the first interactive tool call.  Each call to
    ``resume_after_client_tool_result`` records the result and then either
    yields the next tool call in the chain or, once every call id has been
    consumed, a final text delta followed by ``done``.
    """

    def __init__(self, call_ids: List[str], *, timeout_ms: int = 300000):
        # Ordered ids of the interactive tool calls to emit, one per round.
        self.call_ids = call_ids
        self.timeout_ms = timeout_ms
        # Counters/records inspected by tests.
        self.generate_stream_calls = 0
        self.resumed_results: List[Dict[str, Any]] = []

    def _tool_call_event(self, call_id: str) -> LLMStreamEvent:
        """Build a FastGPT interactive tool-call stream event for *call_id*."""
        interaction = {
            "type": "userInput",
            "title": "",
            "description": f"Prompt for {call_id}",
            "prompt": f"Prompt for {call_id}",
            "form": [{"name": "result", "label": "result", "input_type": "input"}],
            "options": [],
        }
        arguments = json.dumps(
            {
                "provider": "fastgpt",
                "version": "fastgpt_interactive_v1",
                "interaction": interaction,
                "context": {"chat_id": "fastgpt_chat_chain"},
            },
            ensure_ascii=False,
        )
        return LLMStreamEvent(
            type="tool_call",
            tool_call={
                "id": call_id,
                "executor": "client",
                "wait_for_response": True,
                "timeout_ms": self.timeout_ms,
                "display_name": f"Collect {call_id}",
                "type": "function",
                "function": {"name": "fastgpt.interactive", "arguments": arguments},
            },
        )

    async def generate(self, _messages, temperature=0.7, max_tokens=None):
        return ""

    async def generate_stream(self, _messages, temperature=0.7, max_tokens=None):
        self.generate_stream_calls += 1
        yield self._tool_call_event(self.call_ids[0])
        yield LLMStreamEvent(type="done")

    def handles_client_tool(self, tool_name: str) -> bool:
        return tool_name == "fastgpt.interactive"

    async def resume_after_client_tool_result(self, tool_call_id: str, result: Dict[str, Any]):
        self.resumed_results.append({"tool_call_id": tool_call_id, "result": dict(result)})
        consumed = len(self.resumed_results)
        if consumed < len(self.call_ids):
            # More calls remain in the chain: emit the next interactive step.
            yield self._tool_call_event(self.call_ids[consumed])
        else:
            yield LLMStreamEvent(type="text_delta", text="completed after third interactive input.")
        yield LLMStreamEvent(type="done")
|
||||
|
||||
|
||||
def _build_pipeline_with_custom_llm(monkeypatch, llm_service) -> tuple[DuplexPipeline, List[Dict[str, Any]]]:
|
||||
monkeypatch.setattr("runtime.pipeline.duplex.SileroVAD", _DummySileroVAD)
|
||||
monkeypatch.setattr("runtime.pipeline.duplex.VADProcessor", _DummyVADProcessor)
|
||||
@@ -903,3 +964,72 @@ async def test_fastgpt_provider_managed_tool_timeout_stops_without_generic_tool_
|
||||
assert not finals
|
||||
assert llm.generate_stream_calls == 1
|
||||
assert llm.resumed_results == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fastgpt_provider_managed_tool_chain_can_continue_after_third_result(monkeypatch):
    """Three chained client tool calls can each be answered, after which the
    provider-managed chain ends the turn with a final text response."""
    llm = _FakeChainedResumableLLM(["call_fastgpt_1", "call_fastgpt_2", "call_fastgpt_3"])
    pipeline, events = _build_pipeline_with_custom_llm(monkeypatch, llm)
    # Text-only output keeps the test free of audio/TTS machinery.
    pipeline.apply_runtime_overrides({"output": {"mode": "text"}})

    task = asyncio.create_task(pipeline._handle_turn("start chained fastgpt"))

    expected_call_ids = ["call_fastgpt_1", "call_fastgpt_2", "call_fastgpt_3"]
    for idx, call_id in enumerate(expected_call_ids, start=1):
        # Poll (up to ~1s) until the pipeline has emitted this round's tool call.
        for _ in range(200):
            seen_call_ids = [event.get("tool_call_id") for event in events if event.get("type") == "assistant.tool_call"]
            if call_id in seen_call_ids:
                break
            await asyncio.sleep(0.005)

        # Deliver the client-side result for this round's interactive call.
        await pipeline.handle_tool_call_results(
            [
                {
                    "tool_call_id": call_id,
                    "name": "fastgpt.interactive",
                    "output": {
                        "action": "submit",
                        "result": {"type": "userInput", "fields": {"result": f"value-{idx}"}},
                    },
                    "status": {"code": 200, "message": "ok"},
                }
            ]
        )

    await task

    finals = [event for event in events if event.get("type") == "assistant.response.final"]
    assert finals
    assert "completed after third interactive input" in finals[-1].get("text", "")
    # The chain must resume via resume_after_client_tool_result rather than
    # issuing fresh generate_stream rounds, and all three results must land.
    assert llm.generate_stream_calls == 1
    assert len(llm.resumed_results) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pending_client_tool_deadline_tracks_waiting_result(monkeypatch):
    """While a client tool result is awaited, pending_client_tool_deadline()
    reports a future monotonic deadline; it clears once the result arrives."""
    pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])

    waiter = asyncio.create_task(pipeline._wait_for_single_tool_result("call_deadline", timeout_seconds=30))
    # Poll briefly until the waiter task has registered its deadline.
    for _ in range(50):
        deadline = pipeline.pending_client_tool_deadline()
        if deadline is not None:
            break
        await asyncio.sleep(0.001)

    deadline = pipeline.pending_client_tool_deadline()
    assert deadline is not None
    # A 30s timeout was requested; allow generous slack for scheduling delay.
    assert deadline > time.monotonic() + 25

    # Delivering the result resolves the waiter and clears the deadline entry.
    await pipeline.handle_tool_call_results(
        [
            {
                "tool_call_id": "call_deadline",
                "name": "fastgpt.interactive",
                "output": {"action": "submit", "result": {"type": "userInput", "fields": {"name": "Alice"}}},
                "status": {"code": 200, "message": "ok"},
            }
        ]
    )
    await waiter

    assert pipeline.pending_client_tool_deadline() is None
|
||||
|
||||
Reference in New Issue
Block a user