From a510b276e6ba1dbb678f4bd6b3b17fbdc61e409b Mon Sep 17 00:00:00 2001 From: Filipi Fuchter Date: Tue, 18 Nov 2025 11:37:57 -0300 Subject: [PATCH] Ensure that the function call results respect the previous LLM context. --- CHANGELOG.md | 5 ++ .../aggregators/llm_response_universal.py | 74 ++++++++++++------- 2 files changed, 54 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index de1f520a7..f2fa5801e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,6 +66,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed a race condition where, if the LLM received instructions to both produce + text and invoke a function call at the same time, the context would not be + updated before the function call result arrived, causing the bot to repeat + itself. + - Fixed an issue in the `Runner` where, when using `SmallWebRTCTransport`, the `request_data` was not being passed to the `SmallWebRTCRunnerArguments` body. diff --git a/src/pipecat/processors/aggregators/llm_response_universal.py b/src/pipecat/processors/aggregators/llm_response_universal.py index d4d9ad7da..98d7514b6 100644 --- a/src/pipecat/processors/aggregators/llm_response_universal.py +++ b/src/pipecat/processors/aggregators/llm_response_universal.py @@ -594,6 +594,8 @@ class LLMAssistantAggregator(LLMContextAggregator): self._started = 0 self._function_calls_in_progress: Dict[str, Optional[FunctionCallInProgressFrame]] = {} self._context_updated_tasks: Set[asyncio.Task] = set() + self._function_calls_context_messages = [] + self._function_calls_pending_context_updates_callbacks = [] @property def has_function_calls_in_progress(self) -> bool: @@ -650,21 +652,23 @@ class LLMAssistantAggregator(LLMContextAggregator): async def push_aggregation(self): """Push the current assistant aggregation with timestamp.""" - if not self._aggregation: - return + if self._aggregation: + aggregation = self.aggregation_string() + await self.reset() - aggregation = self.aggregation_string() - await self.reset() + if aggregation: + self._context.add_message({"role": "assistant", "content": aggregation}) - if aggregation: - self._context.add_message({"role": "assistant", "content": aggregation}) + # Push context frame + await self.push_context_frame() - # Push context frame - await self.push_context_frame() + # Push timestamp frame with current time + timestamp_frame = LLMContextAssistantTimestampFrame(timestamp=time_now_iso8601()) + await self.push_frame(timestamp_frame) - # Push timestamp frame with current time - timestamp_frame = LLMContextAssistantTimestampFrame(timestamp=time_now_iso8601()) - await self.push_frame(timestamp_frame) + if self._function_calls_context_messages: + self._flush_function_call_messages_to_context() + await self.push_context_frame(FrameDirection.UPSTREAM) async def _handle_llm_run(self, frame: LLMRunFrame): await self.push_context_frame(FrameDirection.UPSTREAM) @@ -684,6 +688,23 @@ class LLMAssistantAggregator(LLMContextAggregator): self._started = 0 await self.reset() + def _flush_function_call_messages_to_context(self): + """Move all function calls messages into context, then clear the list.""" + if self._function_calls_context_messages: + self._context.add_messages(self._function_calls_context_messages) + self._function_calls_context_messages.clear() + + # Call the `on_context_updated` callbacks once the function call results + # are added to the context. Run them in separate tasks to make + # sure we don't block the pipeline. + for callback, task_name in self._function_calls_pending_context_updates_callbacks: + task = self.create_task(callback(), task_name) + self._context_updated_tasks.add(task) + task.add_done_callback(self._context_updated_task_finished) + + # Clear the pending callbacks list + self._function_calls_pending_context_updates_callbacks.clear() + async def _handle_function_calls_started(self, frame: FunctionCallsStartedFrame): function_names = [f"{f.function_name}:{f.tool_call_id}" for f in frame.function_calls] logger.debug(f"{self} FunctionCallsStartedFrame: {function_names}") @@ -696,7 +717,7 @@ class LLMAssistantAggregator(LLMContextAggregator): ) # Update context with the in-progress function call - self._context.add_message( + self._function_calls_context_messages.append( { "role": "assistant", "tool_calls": [ @@ -711,7 +732,7 @@ class LLMAssistantAggregator(LLMContextAggregator): ], } ) - self._context.add_message( + self._function_calls_context_messages.append( { "role": "tool", "content": "IN_PROGRESS", @@ -742,6 +763,13 @@ class LLMAssistantAggregator(LLMContextAggregator): else: self._update_function_call_result(frame.function_name, frame.tool_call_id, "COMPLETED") + # Store the on_context_updated callback along with task name info to be invoked later + if properties and properties.on_context_updated: + task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated" + self._function_calls_pending_context_updates_callbacks.append( + (properties.on_context_updated, task_name) + ) + run_llm = False # Run inference if the function call result requires it. @@ -756,17 +784,13 @@ class LLMAssistantAggregator(LLMContextAggregator): # If this is the last function call in progress, run the LLM. run_llm = not bool(self._function_calls_in_progress) - if run_llm: - await self.push_context_frame(FrameDirection.UPSTREAM) - - # Call the `on_context_updated` callback once the function call result - # is added to the context. Also, run this in a separate task to make - # sure we don't block the pipeline. - if properties and properties.on_context_updated: - task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated" - task = self.create_task(properties.on_context_updated(), task_name) - self._context_updated_tasks.add(task) - task.add_done_callback(self._context_updated_task_finished) + # Only run if the LLM response has completed (not currently generating), + # otherwise defer execution until push_aggregation() is called + # (triggered by LLMFullResponseEndFrame or interruption). + if not self._started: + self._flush_function_call_messages_to_context() + if run_llm: + await self.push_context_frame(FrameDirection.UPSTREAM) async def _handle_function_call_cancel(self, frame: FunctionCallCancelFrame): logger.debug( @@ -781,7 +805,7 @@ class LLMAssistantAggregator(LLMContextAggregator): del self._function_calls_in_progress[frame.tool_call_id] def _update_function_call_result(self, function_name: str, tool_call_id: str, result: Any): - for message in self._context.get_messages(): + for message in self._function_calls_context_messages: if ( not isinstance(message, LLMSpecificMessage) and message["role"] == "tool"