Fix fastgpt client tool 3 rounds bugs
This commit is contained in:
@@ -30,6 +30,18 @@ from runtime.events import get_event_bus, reset_event_bus
|
|||||||
_HEARTBEAT_CHECK_INTERVAL_SEC = 5
|
_HEARTBEAT_CHECK_INTERVAL_SEC = 5
|
||||||
|
|
||||||
|
|
||||||
|
def _inactivity_deadline(
|
||||||
|
*,
|
||||||
|
last_received_at: float,
|
||||||
|
inactivity_timeout_sec: int,
|
||||||
|
pending_client_tool_deadline: Optional[float] = None,
|
||||||
|
) -> float:
|
||||||
|
deadline = float(last_received_at) + float(inactivity_timeout_sec)
|
||||||
|
if pending_client_tool_deadline is not None:
|
||||||
|
deadline = max(deadline, float(pending_client_tool_deadline))
|
||||||
|
return deadline
|
||||||
|
|
||||||
|
|
||||||
async def heartbeat_and_timeout_task(
|
async def heartbeat_and_timeout_task(
|
||||||
transport: BaseTransport,
|
transport: BaseTransport,
|
||||||
session: Session,
|
session: Session,
|
||||||
@@ -48,8 +60,22 @@ async def heartbeat_and_timeout_task(
|
|||||||
if transport.is_closed:
|
if transport.is_closed:
|
||||||
break
|
break
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
if now - last_received_at[0] > inactivity_timeout_sec:
|
pending_client_tool_deadline = session.pipeline.pending_client_tool_deadline()
|
||||||
logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing")
|
idle_deadline = _inactivity_deadline(
|
||||||
|
last_received_at=last_received_at[0],
|
||||||
|
inactivity_timeout_sec=inactivity_timeout_sec,
|
||||||
|
pending_client_tool_deadline=pending_client_tool_deadline,
|
||||||
|
)
|
||||||
|
if now > idle_deadline:
|
||||||
|
if pending_client_tool_deadline is not None and pending_client_tool_deadline >= (
|
||||||
|
last_received_at[0] + inactivity_timeout_sec
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"Session {}: no message before pending client tool deadline, closing",
|
||||||
|
session_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing")
|
||||||
await session.cleanup()
|
await session.cleanup()
|
||||||
break
|
break
|
||||||
if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
|
if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
|
||||||
|
|||||||
@@ -73,6 +73,8 @@ class DuplexPipeline:
|
|||||||
_MIN_SPLIT_SPOKEN_CHARS = 6
|
_MIN_SPLIT_SPOKEN_CHARS = 6
|
||||||
_TOOL_WAIT_TIMEOUT_SECONDS = 60.0
|
_TOOL_WAIT_TIMEOUT_SECONDS = 60.0
|
||||||
_SERVER_TOOL_TIMEOUT_SECONDS = 15.0
|
_SERVER_TOOL_TIMEOUT_SECONDS = 15.0
|
||||||
|
_MAX_LLM_ROUNDS = 3
|
||||||
|
_MAX_PROVIDER_MANAGED_ROUNDS = 24
|
||||||
TRACK_AUDIO_IN = "audio_in"
|
TRACK_AUDIO_IN = "audio_in"
|
||||||
TRACK_AUDIO_OUT = "audio_out"
|
TRACK_AUDIO_OUT = "audio_out"
|
||||||
TRACK_CONTROL = "control"
|
TRACK_CONTROL = "control"
|
||||||
@@ -408,6 +410,7 @@ class DuplexPipeline:
|
|||||||
self._runtime_tool_display_names: Dict[str, str] = {}
|
self._runtime_tool_display_names: Dict[str, str] = {}
|
||||||
self._runtime_tool_wait_for_response: Dict[str, bool] = {}
|
self._runtime_tool_wait_for_response: Dict[str, bool] = {}
|
||||||
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
|
self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
|
||||||
|
self._pending_tool_deadlines: Dict[str, float] = {}
|
||||||
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
|
self._early_tool_results: Dict[str, Dict[str, Any]] = {}
|
||||||
self._completed_tool_call_ids: set[str] = set()
|
self._completed_tool_call_ids: set[str] = set()
|
||||||
self._pending_client_tool_call_ids: set[str] = set()
|
self._pending_client_tool_call_ids: set[str] = set()
|
||||||
@@ -2236,6 +2239,7 @@ class DuplexPipeline:
|
|||||||
future = loop.create_future()
|
future = loop.create_future()
|
||||||
self._pending_tool_waiters[call_id] = future
|
self._pending_tool_waiters[call_id] = future
|
||||||
timeout = timeout_seconds if isinstance(timeout_seconds, (int, float)) and timeout_seconds > 0 else self._TOOL_WAIT_TIMEOUT_SECONDS
|
timeout = timeout_seconds if isinstance(timeout_seconds, (int, float)) and timeout_seconds > 0 else self._TOOL_WAIT_TIMEOUT_SECONDS
|
||||||
|
self._pending_tool_deadlines[call_id] = time.monotonic() + timeout
|
||||||
try:
|
try:
|
||||||
return await asyncio.wait_for(future, timeout=timeout)
|
return await asyncio.wait_for(future, timeout=timeout)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
@@ -2247,8 +2251,14 @@ class DuplexPipeline:
|
|||||||
}
|
}
|
||||||
finally:
|
finally:
|
||||||
self._pending_tool_waiters.pop(call_id, None)
|
self._pending_tool_waiters.pop(call_id, None)
|
||||||
|
self._pending_tool_deadlines.pop(call_id, None)
|
||||||
self._pending_client_tool_call_ids.discard(call_id)
|
self._pending_client_tool_call_ids.discard(call_id)
|
||||||
|
|
||||||
|
def pending_client_tool_deadline(self) -> Optional[float]:
|
||||||
|
if not self._pending_tool_deadlines:
|
||||||
|
return None
|
||||||
|
return max(self._pending_tool_deadlines.values())
|
||||||
|
|
||||||
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
|
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
|
||||||
if isinstance(item, LLMStreamEvent):
|
if isinstance(item, LLMStreamEvent):
|
||||||
return item
|
return item
|
||||||
@@ -2289,7 +2299,8 @@ class DuplexPipeline:
|
|||||||
messages = self.conversation.get_messages()
|
messages = self.conversation.get_messages()
|
||||||
if system_context and system_context.strip():
|
if system_context and system_context.strip():
|
||||||
messages = [*messages, LLMMessage(role="system", content=system_context.strip())]
|
messages = [*messages, LLMMessage(role="system", content=system_context.strip())]
|
||||||
max_rounds = 3
|
llm_rounds = 0
|
||||||
|
provider_rounds_remaining = self._MAX_PROVIDER_MANAGED_ROUNDS
|
||||||
|
|
||||||
await self.conversation.start_assistant_turn()
|
await self.conversation.start_assistant_turn()
|
||||||
self._is_bot_speaking = True
|
self._is_bot_speaking = True
|
||||||
@@ -2300,10 +2311,27 @@ class DuplexPipeline:
|
|||||||
self._pending_llm_delta = ""
|
self._pending_llm_delta = ""
|
||||||
self._last_llm_delta_emit_ms = 0.0
|
self._last_llm_delta_emit_ms = 0.0
|
||||||
pending_provider_stream = None
|
pending_provider_stream = None
|
||||||
for _ in range(max_rounds):
|
while True:
|
||||||
if self._interrupt_event.is_set():
|
if self._interrupt_event.is_set():
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if pending_provider_stream is not None:
|
||||||
|
if provider_rounds_remaining <= 0:
|
||||||
|
logger.warning(
|
||||||
|
"Provider-managed tool chain exceeded {} rounds; ending turn early",
|
||||||
|
self._MAX_PROVIDER_MANAGED_ROUNDS,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
provider_rounds_remaining -= 1
|
||||||
|
else:
|
||||||
|
if llm_rounds >= self._MAX_LLM_ROUNDS:
|
||||||
|
logger.warning(
|
||||||
|
"LLM tool planning exceeded {} rounds; ending turn early",
|
||||||
|
self._MAX_LLM_ROUNDS,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
llm_rounds += 1
|
||||||
|
|
||||||
sentence_buffer = ""
|
sentence_buffer = ""
|
||||||
pending_punctuation = ""
|
pending_punctuation = ""
|
||||||
round_response = ""
|
round_response = ""
|
||||||
|
|||||||
13
engine/tests/test_session_timeout.py
Normal file
13
engine/tests/test_session_timeout.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from app.main import _inactivity_deadline
|
||||||
|
|
||||||
|
|
||||||
|
def test_inactivity_deadline_uses_default_timeout_without_pending_tool():
|
||||||
|
assert _inactivity_deadline(last_received_at=100.0, inactivity_timeout_sec=60) == 160.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_inactivity_deadline_extends_while_waiting_for_client_tool():
|
||||||
|
assert _inactivity_deadline(
|
||||||
|
last_received_at=100.0,
|
||||||
|
inactivity_timeout_sec=60,
|
||||||
|
pending_client_tool_deadline=340.0,
|
||||||
|
) == 340.0
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -821,6 +822,66 @@ class _FakeResumableLLM:
|
|||||||
yield LLMStreamEvent(type="done")
|
yield LLMStreamEvent(type="done")
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeChainedResumableLLM:
|
||||||
|
def __init__(self, call_ids: List[str], *, timeout_ms: int = 300000):
|
||||||
|
self.call_ids = call_ids
|
||||||
|
self.timeout_ms = timeout_ms
|
||||||
|
self.generate_stream_calls = 0
|
||||||
|
self.resumed_results: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
def _tool_call_event(self, call_id: str) -> LLMStreamEvent:
|
||||||
|
return LLMStreamEvent(
|
||||||
|
type="tool_call",
|
||||||
|
tool_call={
|
||||||
|
"id": call_id,
|
||||||
|
"executor": "client",
|
||||||
|
"wait_for_response": True,
|
||||||
|
"timeout_ms": self.timeout_ms,
|
||||||
|
"display_name": f"Collect {call_id}",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "fastgpt.interactive",
|
||||||
|
"arguments": json.dumps(
|
||||||
|
{
|
||||||
|
"provider": "fastgpt",
|
||||||
|
"version": "fastgpt_interactive_v1",
|
||||||
|
"interaction": {
|
||||||
|
"type": "userInput",
|
||||||
|
"title": "",
|
||||||
|
"description": f"Prompt for {call_id}",
|
||||||
|
"prompt": f"Prompt for {call_id}",
|
||||||
|
"form": [{"name": "result", "label": "result", "input_type": "input"}],
|
||||||
|
"options": [],
|
||||||
|
},
|
||||||
|
"context": {"chat_id": "fastgpt_chat_chain"},
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def generate(self, _messages, temperature=0.7, max_tokens=None):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
async def generate_stream(self, _messages, temperature=0.7, max_tokens=None):
|
||||||
|
self.generate_stream_calls += 1
|
||||||
|
yield self._tool_call_event(self.call_ids[0])
|
||||||
|
yield LLMStreamEvent(type="done")
|
||||||
|
|
||||||
|
def handles_client_tool(self, tool_name: str) -> bool:
|
||||||
|
return tool_name == "fastgpt.interactive"
|
||||||
|
|
||||||
|
async def resume_after_client_tool_result(self, tool_call_id: str, result: Dict[str, Any]):
|
||||||
|
self.resumed_results.append({"tool_call_id": tool_call_id, "result": dict(result)})
|
||||||
|
next_index = len(self.resumed_results)
|
||||||
|
if next_index < len(self.call_ids):
|
||||||
|
yield self._tool_call_event(self.call_ids[next_index])
|
||||||
|
else:
|
||||||
|
yield LLMStreamEvent(type="text_delta", text="completed after third interactive input.")
|
||||||
|
yield LLMStreamEvent(type="done")
|
||||||
|
|
||||||
|
|
||||||
def _build_pipeline_with_custom_llm(monkeypatch, llm_service) -> tuple[DuplexPipeline, List[Dict[str, Any]]]:
|
def _build_pipeline_with_custom_llm(monkeypatch, llm_service) -> tuple[DuplexPipeline, List[Dict[str, Any]]]:
|
||||||
monkeypatch.setattr("runtime.pipeline.duplex.SileroVAD", _DummySileroVAD)
|
monkeypatch.setattr("runtime.pipeline.duplex.SileroVAD", _DummySileroVAD)
|
||||||
monkeypatch.setattr("runtime.pipeline.duplex.VADProcessor", _DummyVADProcessor)
|
monkeypatch.setattr("runtime.pipeline.duplex.VADProcessor", _DummyVADProcessor)
|
||||||
@@ -903,3 +964,72 @@ async def test_fastgpt_provider_managed_tool_timeout_stops_without_generic_tool_
|
|||||||
assert not finals
|
assert not finals
|
||||||
assert llm.generate_stream_calls == 1
|
assert llm.generate_stream_calls == 1
|
||||||
assert llm.resumed_results == []
|
assert llm.resumed_results == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fastgpt_provider_managed_tool_chain_can_continue_after_third_result(monkeypatch):
|
||||||
|
llm = _FakeChainedResumableLLM(["call_fastgpt_1", "call_fastgpt_2", "call_fastgpt_3"])
|
||||||
|
pipeline, events = _build_pipeline_with_custom_llm(monkeypatch, llm)
|
||||||
|
pipeline.apply_runtime_overrides({"output": {"mode": "text"}})
|
||||||
|
|
||||||
|
task = asyncio.create_task(pipeline._handle_turn("start chained fastgpt"))
|
||||||
|
|
||||||
|
expected_call_ids = ["call_fastgpt_1", "call_fastgpt_2", "call_fastgpt_3"]
|
||||||
|
for idx, call_id in enumerate(expected_call_ids, start=1):
|
||||||
|
for _ in range(200):
|
||||||
|
seen_call_ids = [event.get("tool_call_id") for event in events if event.get("type") == "assistant.tool_call"]
|
||||||
|
if call_id in seen_call_ids:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0.005)
|
||||||
|
|
||||||
|
await pipeline.handle_tool_call_results(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"tool_call_id": call_id,
|
||||||
|
"name": "fastgpt.interactive",
|
||||||
|
"output": {
|
||||||
|
"action": "submit",
|
||||||
|
"result": {"type": "userInput", "fields": {"result": f"value-{idx}"}},
|
||||||
|
},
|
||||||
|
"status": {"code": 200, "message": "ok"},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
await task
|
||||||
|
|
||||||
|
finals = [event for event in events if event.get("type") == "assistant.response.final"]
|
||||||
|
assert finals
|
||||||
|
assert "completed after third interactive input" in finals[-1].get("text", "")
|
||||||
|
assert llm.generate_stream_calls == 1
|
||||||
|
assert len(llm.resumed_results) == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pending_client_tool_deadline_tracks_waiting_result(monkeypatch):
|
||||||
|
pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])
|
||||||
|
|
||||||
|
waiter = asyncio.create_task(pipeline._wait_for_single_tool_result("call_deadline", timeout_seconds=30))
|
||||||
|
for _ in range(50):
|
||||||
|
deadline = pipeline.pending_client_tool_deadline()
|
||||||
|
if deadline is not None:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0.001)
|
||||||
|
|
||||||
|
deadline = pipeline.pending_client_tool_deadline()
|
||||||
|
assert deadline is not None
|
||||||
|
assert deadline > time.monotonic() + 25
|
||||||
|
|
||||||
|
await pipeline.handle_tool_call_results(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"tool_call_id": "call_deadline",
|
||||||
|
"name": "fastgpt.interactive",
|
||||||
|
"output": {"action": "submit", "result": {"type": "userInput", "fields": {"name": "Alice"}}},
|
||||||
|
"status": {"code": 200, "message": "ok"},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
await waiter
|
||||||
|
|
||||||
|
assert pipeline.pending_client_tool_deadline() is None
|
||||||
|
|||||||
Reference in New Issue
Block a user