Ensure that the function call results respect the previous LLM context.

This commit is contained in:
Filipi Fuchter
2025-11-18 11:37:57 -03:00
parent 6481094638
commit a510b276e6
2 changed files with 54 additions and 25 deletions

View File

@@ -66,6 +66,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Fixed a race condition where, if the LLM received instructions to both produce
text and invoke a function call at the same time, the context would not be
updated with the assistant's text before the function call result arrived,
causing the bot to repeat itself.
- Fixed an issue in the `Runner` where, when using `SmallWebRTCTransport`, the
`request_data` was not being passed to the `SmallWebRTCRunnerArguments` body.

View File

@@ -594,6 +594,8 @@ class LLMAssistantAggregator(LLMContextAggregator):
self._started = 0
self._function_calls_in_progress: Dict[str, Optional[FunctionCallInProgressFrame]] = {}
self._context_updated_tasks: Set[asyncio.Task] = set()
self._function_calls_context_messages = []
self._function_calls_pending_context_updates_callbacks = []
@property
def has_function_calls_in_progress(self) -> bool:
@@ -650,21 +652,23 @@ class LLMAssistantAggregator(LLMContextAggregator):
async def push_aggregation(self):
"""Push the current assistant aggregation with timestamp."""
if not self._aggregation:
return
if self._aggregation:
aggregation = self.aggregation_string()
await self.reset()
aggregation = self.aggregation_string()
await self.reset()
if aggregation:
self._context.add_message({"role": "assistant", "content": aggregation})
if aggregation:
self._context.add_message({"role": "assistant", "content": aggregation})
# Push context frame
await self.push_context_frame()
# Push context frame
await self.push_context_frame()
# Push timestamp frame with current time
timestamp_frame = LLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
await self.push_frame(timestamp_frame)
# Push timestamp frame with current time
timestamp_frame = LLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
await self.push_frame(timestamp_frame)
if self._function_calls_context_messages:
self._flush_function_call_messages_to_context()
await self.push_context_frame(FrameDirection.UPSTREAM)
async def _handle_llm_run(self, frame: LLMRunFrame):
await self.push_context_frame(FrameDirection.UPSTREAM)
@@ -684,6 +688,23 @@ class LLMAssistantAggregator(LLMContextAggregator):
self._started = 0
await self.reset()
def _flush_function_call_messages_to_context(self):
"""Move the buffered function call messages into the context.

Passes the messages held in `self._function_calls_context_messages` to
`self._context.add_messages()` and empties that buffer, then schedules
every pending `on_context_updated` callback as its own task and empties
the pending callbacks list.

NOTE(review): indentation is not preserved in this view, so it cannot be
determined here whether the callback scheduling is nested under the
buffer check below. Confirm against the source file.
"""
# Nothing is added to the context when the buffer is empty.
if self._function_calls_context_messages:
self._context.add_messages(self._function_calls_context_messages)
self._function_calls_context_messages.clear()
# Call the `on_context_updated` callbacks once the function call results
# have been added to the context. Each one runs in a separate task so
# that we don't block the pipeline.
for callback, task_name in self._function_calls_pending_context_updates_callbacks:
task = self.create_task(callback(), task_name)
# Track the task and register the handler that runs when it finishes.
self._context_updated_tasks.add(task)
task.add_done_callback(self._context_updated_task_finished)
# Clear the pending callbacks list so that each callback is scheduled only once.
self._function_calls_pending_context_updates_callbacks.clear()
async def _handle_function_calls_started(self, frame: FunctionCallsStartedFrame):
function_names = [f"{f.function_name}:{f.tool_call_id}" for f in frame.function_calls]
logger.debug(f"{self} FunctionCallsStartedFrame: {function_names}")
@@ -696,7 +717,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
)
# Update context with the in-progress function call
self._context.add_message(
self._function_calls_context_messages.append(
{
"role": "assistant",
"tool_calls": [
@@ -711,7 +732,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
],
}
)
self._context.add_message(
self._function_calls_context_messages.append(
{
"role": "tool",
"content": "IN_PROGRESS",
@@ -742,6 +763,13 @@ class LLMAssistantAggregator(LLMContextAggregator):
else:
self._update_function_call_result(frame.function_name, frame.tool_call_id, "COMPLETED")
# Store the on_context_updated callback along with task name info to be invoked later
if properties and properties.on_context_updated:
task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated"
self._function_calls_pending_context_updates_callbacks.append(
(properties.on_context_updated, task_name)
)
run_llm = False
# Run inference if the function call result requires it.
@@ -756,17 +784,13 @@ class LLMAssistantAggregator(LLMContextAggregator):
# If this is the last function call in progress, run the LLM.
run_llm = not bool(self._function_calls_in_progress)
if run_llm:
await self.push_context_frame(FrameDirection.UPSTREAM)
# Call the `on_context_updated` callback once the function call result
# is added to the context. Also, run this in a separate task to make
# sure we don't block the pipeline.
if properties and properties.on_context_updated:
task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated"
task = self.create_task(properties.on_context_updated(), task_name)
self._context_updated_tasks.add(task)
task.add_done_callback(self._context_updated_task_finished)
# Only flush the function call messages into the context if the LLM response
# has completed (i.e. it is not currently generating). Otherwise, defer the
# flush until push_aggregation() is called (triggered by LLMFullResponseEndFrame or interruption).
if not self._started:
self._flush_function_call_messages_to_context()
if run_llm:
await self.push_context_frame(FrameDirection.UPSTREAM)
async def _handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
logger.debug(
@@ -781,7 +805,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
del self._function_calls_in_progress[frame.tool_call_id]
def _update_function_call_result(self, function_name: str, tool_call_id: str, result: Any):
for message in self._context.get_messages():
for message in self._function_calls_context_messages:
if (
not isinstance(message, LLMSpecificMessage)
and message["role"] == "tool"