Fix context summarization leaving orphaned tool responses in kept context.
This commit is contained in:
@@ -382,25 +382,33 @@ class LLMContextSummarizationUtil:
|
||||
return total
|
||||
|
||||
@staticmethod
|
||||
def _get_function_calls_in_progress_index(messages: List[dict], start_idx: int) -> int:
|
||||
def _get_function_calls_in_progress_index(
|
||||
messages: List[dict], start_idx: int, summary_end: int
|
||||
) -> int:
|
||||
"""Find the earliest message index with incomplete function calls.
|
||||
|
||||
Scans messages to identify function/tool calls that haven't received
|
||||
their results yet. This prevents summarizing incomplete tool interactions
|
||||
which would break the request-response pairing.
|
||||
Scans messages from ``start_idx`` up to (but not including)
|
||||
``summary_end`` to identify tool calls whose responses either don't
|
||||
exist yet or fall in the kept portion of the context (>= summary_end).
|
||||
This prevents summarizing tool call requests when their responses would
|
||||
remain in the kept context as orphans, which the OpenAI API rejects.
|
||||
|
||||
Args:
|
||||
messages: List of messages to check.
|
||||
start_idx: Index to start checking from.
|
||||
summary_end: Exclusive upper bound for the scan (the first kept
|
||||
message index). Only tool responses within this range count as
|
||||
completing a call; responses beyond it are treated as absent,
|
||||
leaving the call "in progress".
|
||||
|
||||
Returns:
|
||||
Index of first message with function call in progress, or -1 if all
|
||||
function calls are complete.
|
||||
function calls are complete within the scanned range.
|
||||
"""
|
||||
# Track tool call IDs mapped to their message index
|
||||
pending_tool_calls: dict[str, int] = {}
|
||||
|
||||
for i in range(start_idx, len(messages)):
|
||||
for i in range(start_idx, summary_end):
|
||||
msg = messages[i]
|
||||
# LLMSpecificMessage instances (e.g. thinking blocks) never carry tool_call or
|
||||
# tool_call_id fields, so they cannot affect the pending-call tracking. Skipping
|
||||
@@ -484,7 +492,7 @@ class LLMContextSummarizationUtil:
|
||||
|
||||
# Check for function calls in progress in the range we want to summarize
|
||||
function_call_start = LLMContextSummarizationUtil._get_function_calls_in_progress_index(
|
||||
messages, summary_start
|
||||
messages, summary_start, summary_end
|
||||
)
|
||||
if function_call_start >= 0 and function_call_start < summary_end:
|
||||
# Stop summarization before the function call
|
||||
|
||||
@@ -954,6 +954,152 @@ class TestDedicatedLLMSummarization(unittest.IsolatedAsyncioTestCase):
|
||||
await summarizer.cleanup()
|
||||
|
||||
|
||||
class TestOrphanedToolResponseDetection(unittest.TestCase):
|
||||
"""Tests that tool responses in the kept range are treated as orphans.
|
||||
|
||||
The scan in _get_function_calls_in_progress_index is bounded by summary_end,
|
||||
so a tool response that falls in the kept portion (>= summary_end) never
|
||||
resolves its matching tool call. This ensures the assistant+tool_calls
|
||||
message and all its responses stay together in the kept range.
|
||||
"""
|
||||
|
||||
def test_tool_response_in_kept_range_is_treated_as_orphan(self):
|
||||
"""Tool response in the kept range causes the tool call to be kept too."""
|
||||
context = LLMContext()
|
||||
context.add_message({"role": "system", "content": "System prompt"}) # idx 0
|
||||
context.add_message({"role": "user", "content": "Hello"}) # idx 1
|
||||
context.add_message( # idx 2: assistant with tool_call
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "fn", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
context.add_message(
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "result"}
|
||||
) # idx 3 (kept)
|
||||
context.add_message({"role": "user", "content": "Thanks"}) # idx 4 (kept)
|
||||
|
||||
# Keep 2: summary_end=3. The tool response at idx 3 is outside the scan
|
||||
# range → call_1 stays pending → boundary moves back to idx 2.
|
||||
result = LLMContextSummarizationUtil.get_messages_to_summarize(context, 2)
|
||||
self.assertEqual(result.last_summarized_index, 1)
|
||||
self.assertEqual(result.messages[-1]["content"], "Hello")
|
||||
|
||||
def test_tool_response_in_summarized_range_is_not_orphan(self):
|
||||
"""Tool response within the summarized range correctly resolves its call."""
|
||||
context = LLMContext()
|
||||
context.add_message({"role": "system", "content": "System prompt"}) # idx 0
|
||||
context.add_message({"role": "user", "content": "Hello"}) # idx 1
|
||||
context.add_message( # idx 2: assistant with tool_call
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "fn", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
context.add_message(
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "result"}
|
||||
) # idx 3
|
||||
context.add_message({"role": "assistant", "content": "Done"}) # idx 4
|
||||
context.add_message({"role": "user", "content": "Thanks"}) # idx 5 (kept)
|
||||
|
||||
# Keep 1: summary_end=5. Both the tool call (idx 2) and its response
|
||||
# (idx 3) are within the scan range → resolved → no adjustment.
|
||||
result = LLMContextSummarizationUtil.get_messages_to_summarize(context, 1)
|
||||
self.assertEqual(result.last_summarized_index, 4)
|
||||
self.assertEqual(len(result.messages), 4)
|
||||
|
||||
def test_partial_responses_in_kept_range_moves_back(self):
|
||||
"""When only some tool responses are in the kept range the whole group is kept."""
|
||||
context = LLMContext()
|
||||
context.add_message({"role": "system", "content": "System prompt"}) # idx 0
|
||||
context.add_message({"role": "user", "content": "Hello"}) # idx 1
|
||||
context.add_message( # idx 2: assistant with two tool_calls
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_a",
|
||||
"type": "function",
|
||||
"function": {"name": "fn_a", "arguments": "{}"},
|
||||
},
|
||||
{
|
||||
"id": "call_b",
|
||||
"type": "function",
|
||||
"function": {"name": "fn_b", "arguments": "{}"},
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
context.add_message(
|
||||
{"role": "tool", "tool_call_id": "call_a", "content": "result_a"}
|
||||
) # idx 3
|
||||
context.add_message(
|
||||
{"role": "tool", "tool_call_id": "call_b", "content": "result_b"}
|
||||
) # idx 4 (kept)
|
||||
context.add_message({"role": "user", "content": "Thanks"}) # idx 5 (kept)
|
||||
|
||||
# Keep 2: summary_end=4. call_a is resolved (idx 3 is in scan range) but
|
||||
# call_b's response (idx 4) is outside → call_b stays pending →
|
||||
# function_call_start=2 → boundary moves back to idx 2.
|
||||
result = LLMContextSummarizationUtil.get_messages_to_summarize(context, 2)
|
||||
self.assertEqual(result.last_summarized_index, 1)
|
||||
self.assertEqual(result.messages[-1]["content"], "Hello")
|
||||
|
||||
def test_non_adjacent_orphan_in_kept_range_moves_back(self):
|
||||
"""Orphaned tool response deeper in the kept range (not at the boundary) is detected."""
|
||||
context = LLMContext()
|
||||
context.add_message({"role": "system", "content": "System prompt"}) # idx 0
|
||||
context.add_message({"role": "user", "content": "Hello"}) # idx 1
|
||||
context.add_message( # idx 2: assistant with two tool_calls
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_a",
|
||||
"type": "function",
|
||||
"function": {"name": "fn_a", "arguments": "{}"},
|
||||
},
|
||||
{
|
||||
"id": "call_b",
|
||||
"type": "function",
|
||||
"function": {"name": "fn_b", "arguments": "{}"},
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
context.add_message(
|
||||
{"role": "tool", "tool_call_id": "call_a", "content": "result_a"}
|
||||
) # idx 3
|
||||
context.add_message({"role": "user", "content": "Intermediate"}) # idx 4 (kept)
|
||||
context.add_message(
|
||||
{"role": "tool", "tool_call_id": "call_b", "content": "result_b"}
|
||||
) # idx 5 (kept) — NOT adjacent to the boundary
|
||||
context.add_message({"role": "user", "content": "Latest"}) # idx 6 (kept)
|
||||
|
||||
# Keep 3: summary_end=4. call_b's response is at idx 5, two hops into
|
||||
# the kept range. The scan stops at idx 4, so call_b is never resolved →
|
||||
# function_call_start=2 → boundary moves back to idx 2.
|
||||
result = LLMContextSummarizationUtil.get_messages_to_summarize(context, 3)
|
||||
self.assertEqual(result.last_summarized_index, 1)
|
||||
self.assertEqual(result.messages[-1]["content"], "Hello")
|
||||
|
||||
|
||||
class TestLLMSpecificMessageHandling(unittest.TestCase):
|
||||
"""Tests that LLMSpecificMessage objects are correctly skipped in summarization."""
|
||||
|
||||
@@ -1022,7 +1168,9 @@ class TestLLMSpecificMessageHandling(unittest.TestCase):
|
||||
{"role": "tool", "tool_call_id": "call_123", "content": '{"time": "10:30 AM"}'},
|
||||
]
|
||||
|
||||
result = LLMContextSummarizationUtil._get_function_calls_in_progress_index(messages, 0)
|
||||
result = LLMContextSummarizationUtil._get_function_calls_in_progress_index(
|
||||
messages, 0, len(messages)
|
||||
)
|
||||
self.assertEqual(result, -1)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user