diff --git a/src/pipecat/utils/context/llm_context_summarization.py b/src/pipecat/utils/context/llm_context_summarization.py index 707d0c32d..f7536cbaf 100644 --- a/src/pipecat/utils/context/llm_context_summarization.py +++ b/src/pipecat/utils/context/llm_context_summarization.py @@ -382,25 +382,33 @@ class LLMContextSummarizationUtil: return total @staticmethod - def _get_function_calls_in_progress_index(messages: List[dict], start_idx: int) -> int: + def _get_function_calls_in_progress_index( + messages: List[dict], start_idx: int, summary_end: int + ) -> int: """Find the earliest message index with incomplete function calls. - Scans messages to identify function/tool calls that haven't received - their results yet. This prevents summarizing incomplete tool interactions - which would break the request-response pairing. + Scans messages from ``start_idx`` up to (but not including) + ``summary_end`` to identify tool calls whose responses either don't + exist yet or fall in the kept portion of the context (>= summary_end). + This prevents summarizing tool call requests when their responses would + remain in the kept context as orphans, which the OpenAI API rejects. Args: messages: List of messages to check. start_idx: Index to start checking from. + summary_end: Exclusive upper bound for the scan (the first kept + message index). Only tool responses within this range count as + completing a call; responses beyond it are treated as absent, + leaving the call "in progress". Returns: Index of first message with function call in progress, or -1 if all - function calls are complete. + function calls are complete within the scanned range. """ # Track tool call IDs mapped to their message index pending_tool_calls: dict[str, int] = {} - for i in range(start_idx, len(messages)): + for i in range(start_idx, summary_end): msg = messages[i] # LLMSpecificMessage instances (e.g. thinking blocks) never carry tool_call or # tool_call_id fields, so they cannot affect the pending-call tracking. Skipping @@ -484,7 +492,7 @@ class LLMContextSummarizationUtil: # Check for function calls in progress in the range we want to summarize function_call_start = LLMContextSummarizationUtil._get_function_calls_in_progress_index( - messages, summary_start + messages, summary_start, summary_end ) if function_call_start >= 0 and function_call_start < summary_end: # Stop summarization before the function call diff --git a/tests/test_context_summarization.py b/tests/test_context_summarization.py index e2666e7fa..454535bc0 100644 --- a/tests/test_context_summarization.py +++ b/tests/test_context_summarization.py @@ -954,6 +954,152 @@ class TestDedicatedLLMSummarization(unittest.IsolatedAsyncioTestCase): await summarizer.cleanup() +class TestOrphanedToolResponseDetection(unittest.TestCase): + """Tests that tool responses in the kept range are treated as orphans. + + The scan in _get_function_calls_in_progress_index is bounded by summary_end, + so a tool response that falls in the kept portion (>= summary_end) never + resolves its matching tool call. This ensures the assistant+tool_calls + message and all its responses stay together in the kept range. + """ + + def test_tool_response_in_kept_range_is_treated_as_orphan(self): + """Tool response in the kept range causes the tool call to be kept too.""" + context = LLMContext() + context.add_message({"role": "system", "content": "System prompt"}) # idx 0 + context.add_message({"role": "user", "content": "Hello"}) # idx 1 + context.add_message( # idx 2: assistant with tool_call + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "fn", "arguments": "{}"}, + } + ], + } + ) + context.add_message( + {"role": "tool", "tool_call_id": "call_1", "content": "result"} + ) # idx 3 (kept) + context.add_message({"role": "user", "content": "Thanks"}) # idx 4 (kept) + + # Keep 2: summary_end=3. The tool response at idx 3 is outside the scan + # range → call_1 stays pending → boundary moves back to idx 2. + result = LLMContextSummarizationUtil.get_messages_to_summarize(context, 2) + self.assertEqual(result.last_summarized_index, 1) + self.assertEqual(result.messages[-1]["content"], "Hello") + + def test_tool_response_in_summarized_range_is_not_orphan(self): + """Tool response within the summarized range correctly resolves its call.""" + context = LLMContext() + context.add_message({"role": "system", "content": "System prompt"}) # idx 0 + context.add_message({"role": "user", "content": "Hello"}) # idx 1 + context.add_message( # idx 2: assistant with tool_call + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "fn", "arguments": "{}"}, + } + ], + } + ) + context.add_message( + {"role": "tool", "tool_call_id": "call_1", "content": "result"} + ) # idx 3 + context.add_message({"role": "assistant", "content": "Done"}) # idx 4 + context.add_message({"role": "user", "content": "Thanks"}) # idx 5 (kept) + + # Keep 1: summary_end=5. Both the tool call (idx 2) and its response + # (idx 3) are within the scan range → resolved → no adjustment. + result = LLMContextSummarizationUtil.get_messages_to_summarize(context, 1) + self.assertEqual(result.last_summarized_index, 4) + self.assertEqual(len(result.messages), 4) + + def test_partial_responses_in_kept_range_moves_back(self): + """When only some tool responses are in the kept range the whole group is kept.""" + context = LLMContext() + context.add_message({"role": "system", "content": "System prompt"}) # idx 0 + context.add_message({"role": "user", "content": "Hello"}) # idx 1 + context.add_message( # idx 2: assistant with two tool_calls + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_a", + "type": "function", + "function": {"name": "fn_a", "arguments": "{}"}, + }, + { + "id": "call_b", + "type": "function", + "function": {"name": "fn_b", "arguments": "{}"}, + }, + ], + } + ) + context.add_message( + {"role": "tool", "tool_call_id": "call_a", "content": "result_a"} + ) # idx 3 + context.add_message( + {"role": "tool", "tool_call_id": "call_b", "content": "result_b"} + ) # idx 4 (kept) + context.add_message({"role": "user", "content": "Thanks"}) # idx 5 (kept) + + # Keep 2: summary_end=4. call_a is resolved (idx 3 is in scan range) but + # call_b's response (idx 4) is outside → call_b stays pending → + # function_call_start=2 → boundary moves back to idx 2. + result = LLMContextSummarizationUtil.get_messages_to_summarize(context, 2) + self.assertEqual(result.last_summarized_index, 1) + self.assertEqual(result.messages[-1]["content"], "Hello") + + def test_non_adjacent_orphan_in_kept_range_moves_back(self): + """Orphaned tool response deeper in the kept range (not at the boundary) is detected.""" + context = LLMContext() + context.add_message({"role": "system", "content": "System prompt"}) # idx 0 + context.add_message({"role": "user", "content": "Hello"}) # idx 1 + context.add_message( # idx 2: assistant with two tool_calls + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_a", + "type": "function", + "function": {"name": "fn_a", "arguments": "{}"}, + }, + { + "id": "call_b", + "type": "function", + "function": {"name": "fn_b", "arguments": "{}"}, + }, + ], + } + ) + context.add_message( + {"role": "tool", "tool_call_id": "call_a", "content": "result_a"} + ) # idx 3 + context.add_message({"role": "user", "content": "Intermediate"}) # idx 4 (kept) + context.add_message( + {"role": "tool", "tool_call_id": "call_b", "content": "result_b"} + ) # idx 5 (kept) — NOT adjacent to the boundary + context.add_message({"role": "user", "content": "Latest"}) # idx 6 (kept) + + # Keep 3: summary_end=4. call_b's response is at idx 5, two hops into + # the kept range. The scan stops at idx 4, so call_b is never resolved → + # function_call_start=2 → boundary moves back to idx 2. + result = LLMContextSummarizationUtil.get_messages_to_summarize(context, 3) + self.assertEqual(result.last_summarized_index, 1) + self.assertEqual(result.messages[-1]["content"], "Hello") + + class TestLLMSpecificMessageHandling(unittest.TestCase): """Tests that LLMSpecificMessage objects are correctly skipped in summarization.""" @@ -1022,7 +1168,9 @@ class TestLLMSpecificMessageHandling(unittest.TestCase): {"role": "tool", "tool_call_id": "call_123", "content": '{"time": "10:30 AM"}'}, ] - result = LLMContextSummarizationUtil._get_function_calls_in_progress_index(messages, 0) + result = LLMContextSummarizationUtil._get_function_calls_in_progress_index( + messages, 0, len(messages) + ) self.assertEqual(result, -1)