LLMService: don't allow running functions concurrently for now

This commit is contained in:
Aleix Conchillo Flaqué
2025-04-28 11:49:29 -07:00
parent a50a407415
commit 52569bcdb2
3 changed files with 46 additions and 63 deletions

View File

@@ -675,7 +675,6 @@ class FunctionCallInProgressFrame(SystemFrame):
tool_call_id: str
arguments: Any
cancel_on_interruption: bool = False
run_concurrently: bool = False
@dataclass

View File

@@ -603,13 +603,13 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
if frame.result:
run_llm = False
if properties and properties.run_llm is not None:
# If the tool call result has a run_llm property, use it
# If the tool call result has a run_llm property, use it.
run_llm = properties.run_llm
elif frame.run_llm is not None:
# If the frame is indicating we should run the LLM, do it.
run_llm = frame.run_llm
elif in_progress.run_concurrently:
# If this was a parallel function call and there are no pending function call, run the LLM.
else:
# If this is the last function call in progress, run the LLM.
run_llm = not bool(self._function_calls_in_progress)
if run_llm:

View File

@@ -52,14 +52,12 @@ class FunctionCallItem:
function_name (Optional[str]): The name of the function.
handler (FunctionCallHandler): The handler for processing function call parameters.
cancel_on_interruption (bool): Flag indicating whether to cancel the call on interruption.
run_concurrently (bool): Flag to indicate if this function call should run concurrently or not.
"""
function_name: Optional[str]
handler: FunctionCallHandler
cancel_on_interruption: bool
run_concurrently: bool
@dataclass
@@ -76,7 +74,7 @@ class FunctionCallRunnerItem:
"""
registry_name: Optional[str]
registry_item: FunctionCallItem
function_name: str
tool_call_id: str
arguments: Mapping[str, Any]
@@ -118,7 +116,7 @@ class LLMService(AIService):
self._start_callbacks = {}
self._adapter = self.adapter_class()
self._functions: Dict[Optional[str], FunctionCallItem] = {}
self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunnerItem] = {}
self._function_call_runner_task: Optional[asyncio.Task] = None
self._register_event_handler("on_completion_timeout")
@@ -136,18 +134,17 @@ class LLMService(AIService):
async def start(self, frame: StartFrame):
await super().start(frame)
self._function_call_runner_queue = asyncio.Queue()
self._function_call_runner_task = self.create_task(self._function_call_runner_handler())
await self._create_runner_task()
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._cancel_function_call(None)
await self.cancel_task(self._function_call_runner_task)
await self._cancel_function_call()
await self._cancel_runner_task()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._cancel_function_call(None)
await self.cancel_task(self._function_call_runner_task)
await self._cancel_function_call()
await self._cancel_runner_task()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -156,9 +153,9 @@ class LLMService(AIService):
await self._handle_interruptions(frame)
async def _handle_interruptions(self, _: StartInterruptionFrame):
for function_name, entry in self._functions.items():
if entry.cancel_on_interruption:
await self._cancel_function_call(function_name)
await self._cancel_function_call()
await self._cancel_runner_task()
await self._create_runner_task()
def register_function(
self,
@@ -167,7 +164,6 @@ class LLMService(AIService):
start_callback=None,
*,
cancel_on_interruption: bool = False,
run_concurrently: bool = False,
):
# Registering a function with the function_name set to None will run
# that handler for all functions
@@ -175,7 +171,6 @@ class LLMService(AIService):
function_name=function_name,
handler=handler,
cancel_on_interruption=cancel_on_interruption,
run_concurrently=run_concurrently,
)
# Start callbacks are now deprecated.
@@ -212,27 +207,21 @@ class LLMService(AIService):
):
if function_name in self._functions.keys():
item = self._functions[function_name]
registry_name = function_name
elif None in self._functions.keys():
item = self._functions[None]
registry_name = None
else:
return
runner_item = FunctionCallRunnerItem(
registry_name=registry_name,
registry_item=item,
function_name=function_name,
tool_call_id=tool_call_id,
arguments=arguments,
context=context,
run_llm=run_llm,
)
if item.run_concurrently:
task = self.create_task(self._run_function_call(runner_item))
self._function_call_tasks[task] = runner_item
task.add_done_callback(self._function_call_task_finished)
else:
await self._function_call_runner_queue.put(runner_item)
await self._function_call_runner_queue.put(runner_item)
async def call_start_function(self, context: OpenAILLMContext, function_name: str):
if function_name in self._start_callbacks.keys():
@@ -260,11 +249,25 @@ class LLMService(AIService):
FrameDirection.UPSTREAM,
)
async def _create_runner_task(self):
if not self._function_call_runner_task:
self._current_runner: Optional[FunctionCallRunnerItem] = None
self._current_task: Optional[asyncio.Task] = None
self._function_call_runner_queue = asyncio.Queue()
self._function_call_runner_task = self.create_task(self._function_call_runner_handler())
async def _cancel_runner_task(self):
if self._function_call_runner_task:
await self.cancel_task(self._function_call_runner_task)
self._function_call_runner_task = None
async def _function_call_runner_handler(self):
while True:
runner_item = await self._function_call_runner_queue.get()
task = self.create_task(self._run_function_call(runner_item))
await self.wait_for_task(task)
self._current_runner = await self._function_call_runner_queue.get()
self._current_task = self.create_task(self._run_function_call(self._current_runner))
await self.wait_for_task(self._current_task)
self._current_runner = None
self._current_task = None
async def _run_function_call(self, runner_item: FunctionCallRunnerItem):
if runner_item.function_name in self._functions.keys():
@@ -291,14 +294,12 @@ class LLMService(AIService):
tool_call_id=runner_item.tool_call_id,
arguments=runner_item.arguments,
cancel_on_interruption=item.cancel_on_interruption,
run_concurrently=item.run_concurrently,
)
progress_frame_upstream = FunctionCallInProgressFrame(
function_name=runner_item.function_name,
tool_call_id=runner_item.tool_call_id,
arguments=runner_item.arguments,
cancel_on_interruption=item.cancel_on_interruption,
run_concurrently=item.run_concurrently,
)
# Push frame both downstream and upstream
@@ -359,37 +360,20 @@ class LLMService(AIService):
)
await item.handler(params)
async def _cancel_function_call(self, function_name: Optional[str]):
cancelled_tasks = set()
for task, runner_item in self._function_call_tasks.items():
if runner_item.registry_name == function_name:
name = runner_item.function_name
tool_call_id = runner_item.tool_call_id
async def _cancel_function_call(self):
if (
self._current_runner
and self._current_task
and self._current_runner.registry_item.cancel_on_interruption
):
name = self._current_runner.function_name
tool_call_id = self._current_runner.tool_call_id
# We remove the callback because we are going to cancel the task
# now, otherwise we will be removing it from the set while we
# are iterating.
task.remove_done_callback(self._function_call_task_finished)
logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")
logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")
await self.cancel_task(self._current_task)
await self.cancel_task(task)
frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id)
await self.push_frame(frame)
frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id)
await self.push_frame(frame)
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")
cancelled_tasks.add(task)
# Remove all cancelled tasks from our set.
for task in cancelled_tasks:
self._function_call_task_finished(task)
def _function_call_task_finished(self, task: asyncio.Task):
if task in self._function_call_tasks:
del self._function_call_tasks[task]
# The task is finished so this should exit immediately. We need to
# do this because otherwise the task manager would report a dangling
# task if we don't remove it.
asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop())
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")