Ensure that the function call results respect the previous LLM context.
This commit is contained in:
@@ -66,6 +66,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a race condition where, if the LLM received instructions to both produce
|
||||
text and invoke a function call at the same time, the context would not be
|
||||
updated before the function call result arrived, causing the bot to repeat
|
||||
itself.
|
||||
|
||||
- Fixed an issue in the `Runner` where, when using `SmallWebRTCTransport`, the
|
||||
`request_data` was not being passed to the `SmallWebRTCRunnerArguments` body.
|
||||
|
||||
|
||||
@@ -594,6 +594,8 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
self._started = 0
|
||||
self._function_calls_in_progress: Dict[str, Optional[FunctionCallInProgressFrame]] = {}
|
||||
self._context_updated_tasks: Set[asyncio.Task] = set()
|
||||
self._function_calls_context_messages = []
|
||||
self._function_calls_pending_context_updates_callbacks = []
|
||||
|
||||
@property
|
||||
def has_function_calls_in_progress(self) -> bool:
|
||||
@@ -650,21 +652,23 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
async def push_aggregation(self):
|
||||
"""Push the current assistant aggregation with timestamp."""
|
||||
if not self._aggregation:
|
||||
return
|
||||
if self._aggregation:
|
||||
aggregation = self.aggregation_string()
|
||||
await self.reset()
|
||||
|
||||
aggregation = self.aggregation_string()
|
||||
await self.reset()
|
||||
if aggregation:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if aggregation:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
# Push context frame
|
||||
await self.push_context_frame()
|
||||
|
||||
# Push context frame
|
||||
await self.push_context_frame()
|
||||
# Push timestamp frame with current time
|
||||
timestamp_frame = LLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
|
||||
await self.push_frame(timestamp_frame)
|
||||
|
||||
# Push timestamp frame with current time
|
||||
timestamp_frame = LLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
|
||||
await self.push_frame(timestamp_frame)
|
||||
if self._function_calls_context_messages:
|
||||
self._flush_function_call_messages_to_context()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_llm_run(self, frame: LLMRunFrame):
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
@@ -684,6 +688,23 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
self._started = 0
|
||||
await self.reset()
|
||||
|
||||
def _flush_function_call_messages_to_context(self):
|
||||
"""Move all function calls messages into context, then clear the list."""
|
||||
if self._function_calls_context_messages:
|
||||
self._context.add_messages(self._function_calls_context_messages)
|
||||
self._function_calls_context_messages.clear()
|
||||
|
||||
# Call the `on_context_updated` callbacks once the function call results
|
||||
# are added to the context. Run them in separate tasks to make
|
||||
# sure we don't block the pipeline.
|
||||
for callback, task_name in self._function_calls_pending_context_updates_callbacks:
|
||||
task = self.create_task(callback(), task_name)
|
||||
self._context_updated_tasks.add(task)
|
||||
task.add_done_callback(self._context_updated_task_finished)
|
||||
|
||||
# Clear the pending callbacks list
|
||||
self._function_calls_pending_context_updates_callbacks.clear()
|
||||
|
||||
async def _handle_function_calls_started(self, frame: FunctionCallsStartedFrame):
|
||||
function_names = [f"{f.function_name}:{f.tool_call_id}" for f in frame.function_calls]
|
||||
logger.debug(f"{self} FunctionCallsStartedFrame: {function_names}")
|
||||
@@ -696,7 +717,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
)
|
||||
|
||||
# Update context with the in-progress function call
|
||||
self._context.add_message(
|
||||
self._function_calls_context_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
@@ -711,7 +732,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
],
|
||||
}
|
||||
)
|
||||
self._context.add_message(
|
||||
self._function_calls_context_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "IN_PROGRESS",
|
||||
@@ -742,6 +763,13 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
else:
|
||||
self._update_function_call_result(frame.function_name, frame.tool_call_id, "COMPLETED")
|
||||
|
||||
# Store the on_context_updated callback, together with its task name, so it can be invoked later
|
||||
if properties and properties.on_context_updated:
|
||||
task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated"
|
||||
self._function_calls_pending_context_updates_callbacks.append(
|
||||
(properties.on_context_updated, task_name)
|
||||
)
|
||||
|
||||
run_llm = False
|
||||
|
||||
# Run inference if the function call result requires it.
|
||||
@@ -756,17 +784,13 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
# If this is the last function call in progress, run the LLM.
|
||||
run_llm = not bool(self._function_calls_in_progress)
|
||||
|
||||
if run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
# Call the `on_context_updated` callback once the function call result
|
||||
# is added to the context. Also, run this in a separate task to make
|
||||
# sure we don't block the pipeline.
|
||||
if properties and properties.on_context_updated:
|
||||
task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated"
|
||||
task = self.create_task(properties.on_context_updated(), task_name)
|
||||
self._context_updated_tasks.add(task)
|
||||
task.add_done_callback(self._context_updated_task_finished)
|
||||
# Only run if the LLM response has completed (not currently generating),
|
||||
# otherwise defer execution until push_aggregation() is called
|
||||
# (triggered by LLMFullResponseEndFrame or interruption).
|
||||
if not self._started:
|
||||
self._flush_function_call_messages_to_context()
|
||||
if run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
|
||||
logger.debug(
|
||||
@@ -781,7 +805,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
del self._function_calls_in_progress[frame.tool_call_id]
|
||||
|
||||
def _update_function_call_result(self, function_name: str, tool_call_id: str, result: Any):
|
||||
for message in self._context.get_messages():
|
||||
for message in self._function_calls_context_messages:
|
||||
if (
|
||||
not isinstance(message, LLMSpecificMessage)
|
||||
and message["role"] == "tool"
|
||||
|
||||
Reference in New Issue
Block a user