Creating the concept of a group_id for the function calls.
This commit is contained in:
@@ -1921,6 +1921,9 @@ class FunctionCallInProgressFrame(ControlFrame, UninterruptibleFrame):
|
||||
is_async: Whether this function call runs asynchronously. When True,
|
||||
the LLM continues the conversation immediately without waiting for
|
||||
the result. The result is injected later via a developer message.
|
||||
group_id: Identifier shared by all function calls originating from the
|
||||
same LLM response batch. Used to determine when the last call in a
|
||||
group completes so the LLM can be triggered exactly once.
|
||||
"""
|
||||
|
||||
function_name: str
|
||||
@@ -1928,6 +1931,7 @@ class FunctionCallInProgressFrame(ControlFrame, UninterruptibleFrame):
|
||||
arguments: Any
|
||||
cancel_on_interruption: bool = False
|
||||
is_async: bool = False
|
||||
group_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1077,6 +1077,8 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
in_progress_frame = self._function_calls_in_progress[frame.tool_call_id]
|
||||
is_async = in_progress_frame.is_async if in_progress_frame else False
|
||||
group_id = in_progress_frame.group_id if in_progress_frame else None
|
||||
|
||||
del self._function_calls_in_progress[frame.tool_call_id]
|
||||
|
||||
properties = frame.properties
|
||||
@@ -1115,8 +1117,16 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
# If the frame is indicating we should run the LLM, do it.
|
||||
run_llm = frame.run_llm
|
||||
else:
|
||||
# If this is the last function call in progress, run the LLM.
|
||||
run_llm = not bool(self._function_calls_in_progress)
|
||||
# Run the LLM when this is the last function call in the group
|
||||
# to complete. If group_id is set, only consider sibling calls;
|
||||
# otherwise always execute as soon as we receive the result.
|
||||
if group_id:
|
||||
run_llm = not any(
|
||||
f is not None and f.group_id == group_id
|
||||
for f in self._function_calls_in_progress.values()
|
||||
)
|
||||
else:
|
||||
run_llm = True
|
||||
|
||||
if run_llm and not self._user_speaking:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
"""Base classes for Large Language Model services with function calling support."""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
@@ -151,6 +151,9 @@ class FunctionCallRunnerItem:
|
||||
arguments: The arguments for the function.
|
||||
context: The LLM context.
|
||||
run_llm: Optional flag to control LLM execution after function call.
|
||||
group_id: Shared identifier for all function calls from the same LLM
|
||||
response batch. Used to trigger the LLM exactly once when the last
|
||||
call in the group completes.
|
||||
"""
|
||||
|
||||
registry_item: FunctionCallRegistryItem
|
||||
@@ -159,6 +162,7 @@ class FunctionCallRunnerItem:
|
||||
arguments: Mapping[str, Any]
|
||||
context: OpenAILLMContext | LLMContext
|
||||
run_llm: Optional[bool] = None
|
||||
group_id: Optional[str] = None
|
||||
|
||||
|
||||
class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
@@ -695,6 +699,10 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
|
||||
await self.broadcast_frame(FunctionCallsStartedFrame, function_calls=function_calls)
|
||||
|
||||
# All function calls from the same LLM response share a group_id so the
|
||||
# aggregator can trigger the LLM exactly once when the last one completes.
|
||||
group_id = str(uuid.uuid4())
|
||||
|
||||
runner_items = []
|
||||
for function_call in function_calls:
|
||||
if function_call.function_name in self._functions.keys():
|
||||
@@ -714,6 +722,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
tool_call_id=function_call.tool_call_id,
|
||||
arguments=function_call.arguments,
|
||||
context=function_call.context,
|
||||
group_id=group_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -783,6 +792,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
arguments=runner_item.arguments,
|
||||
cancel_on_interruption=item.cancel_on_interruption,
|
||||
is_async=item.is_async,
|
||||
group_id=runner_item.group_id,
|
||||
)
|
||||
|
||||
timeout_task: Optional[asyncio.Task] = None
|
||||
|
||||
Reference in New Issue
Block a user