From 52569bcdb2acc19f393945b34b0fa61b10e050b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Mon, 28 Apr 2025 11:49:29 -0700 Subject: [PATCH] LLMService: don't allow running functions concurrently for now --- src/pipecat/frames/frames.py | 1 - .../processors/aggregators/llm_response.py | 6 +- src/pipecat/services/llm_service.py | 102 ++++++++---------- 3 files changed, 46 insertions(+), 63 deletions(-) diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 0c2cec5bd..bb6143c61 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -675,7 +675,6 @@ class FunctionCallInProgressFrame(SystemFrame): tool_call_id: str arguments: Any cancel_on_interruption: bool = False - run_concurrently: bool = False @dataclass diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index f4a7dcb9c..4574d624d 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -603,13 +603,13 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator): if frame.result: run_llm = False if properties and properties.run_llm is not None: - # If the tool call result has a run_llm property, use it + # If the tool call result has a run_llm property, use it. run_llm = properties.run_llm elif frame.run_llm is not None: # If the frame is indicating we should run the LLM, do it. run_llm = frame.run_llm - elif in_progress.run_concurrently: - # If this was a parallel function call and there are no pending function call, run the LLM. + else: + # If this is the last function call in progress, run the LLM. run_llm = not bool(self._function_calls_in_progress) if run_llm: diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index 767208026..915f69a47 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -52,14 +52,12 @@ class FunctionCallItem: function_name (Optional[str]): The name of the function. handler (FunctionCallHandler): The handler for processing function call parameters. cancel_on_interruption (bool): Flag indicating whether to cancel the call on interruption. - run_concurrently (bool): Flag to indicate if this function call should run concurrently or not. """ function_name: Optional[str] handler: FunctionCallHandler cancel_on_interruption: bool - run_concurrently: bool @dataclass @@ -76,7 +74,7 @@ class FunctionCallRunnerItem: """ - registry_name: Optional[str] + registry_item: FunctionCallItem function_name: str tool_call_id: str arguments: Mapping[str, Any] @@ -118,7 +116,7 @@ class LLMService(AIService): self._start_callbacks = {} self._adapter = self.adapter_class() self._functions: Dict[Optional[str], FunctionCallItem] = {} - self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunnerItem] = {} + self._function_call_runner_task: Optional[asyncio.Task] = None self._register_event_handler("on_completion_timeout") @@ -136,18 +134,17 @@ class LLMService(AIService): async def start(self, frame: StartFrame): await super().start(frame) - self._function_call_runner_queue = asyncio.Queue() - self._function_call_runner_task = self.create_task(self._function_call_runner_handler()) + await self._create_runner_task() async def stop(self, frame: EndFrame): await super().stop(frame) - await self._cancel_function_call(None) - await self.cancel_task(self._function_call_runner_task) + await self._cancel_function_call() + await self._cancel_runner_task() async def cancel(self, frame: CancelFrame): await super().cancel(frame) - await self._cancel_function_call(None) - await self.cancel_task(self._function_call_runner_task) + await self._cancel_function_call() + await self._cancel_runner_task() async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -156,9 +153,9 @@ class LLMService(AIService): await self._handle_interruptions(frame) async def _handle_interruptions(self, _: StartInterruptionFrame): - for function_name, entry in self._functions.items(): - if entry.cancel_on_interruption: - await self._cancel_function_call(function_name) + await self._cancel_function_call() + await self._cancel_runner_task() + await self._create_runner_task() def register_function( self, @@ -167,7 +164,6 @@ class LLMService(AIService): start_callback=None, *, cancel_on_interruption: bool = False, - run_concurrently: bool = False, ): # Registering a function with the function_name set to None will run # that handler for all functions @@ -175,7 +171,6 @@ class LLMService(AIService): function_name=function_name, handler=handler, cancel_on_interruption=cancel_on_interruption, - run_concurrently=run_concurrently, ) # Start callbacks are now deprecated. @@ -212,27 +207,21 @@ class LLMService(AIService): ): if function_name in self._functions.keys(): item = self._functions[function_name] - registry_name = function_name elif None in self._functions.keys(): item = self._functions[None] - registry_name = None else: return runner_item = FunctionCallRunnerItem( - registry_name=registry_name, + registry_item=item, function_name=function_name, tool_call_id=tool_call_id, arguments=arguments, context=context, run_llm=run_llm, ) - if item.run_concurrently: - task = self.create_task(self._run_function_call(runner_item)) - self._function_call_tasks[task] = runner_item - task.add_done_callback(self._function_call_task_finished) - else: - await self._function_call_runner_queue.put(runner_item) + + await self._function_call_runner_queue.put(runner_item) async def call_start_function(self, context: OpenAILLMContext, function_name: str): if function_name in self._start_callbacks.keys(): @@ -260,11 +249,25 @@ class LLMService(AIService): FrameDirection.UPSTREAM, ) + async def _create_runner_task(self): + if not self._function_call_runner_task: + self._current_runner: Optional[FunctionCallRunnerItem] = None + self._current_task: Optional[asyncio.Task] = None + self._function_call_runner_queue = asyncio.Queue() + self._function_call_runner_task = self.create_task(self._function_call_runner_handler()) + + async def _cancel_runner_task(self): + if self._function_call_runner_task: + await self.cancel_task(self._function_call_runner_task) + self._function_call_runner_task = None + async def _function_call_runner_handler(self): while True: - runner_item = await self._function_call_runner_queue.get() - task = self.create_task(self._run_function_call(runner_item)) - await self.wait_for_task(task) + self._current_runner = await self._function_call_runner_queue.get() + self._current_task = self.create_task(self._run_function_call(self._current_runner)) + await self.wait_for_task(self._current_task) + self._current_runner = None + self._current_task = None async def _run_function_call(self, runner_item: FunctionCallRunnerItem): if runner_item.function_name in self._functions.keys(): @@ -291,14 +294,12 @@ class LLMService(AIService): tool_call_id=runner_item.tool_call_id, arguments=runner_item.arguments, cancel_on_interruption=item.cancel_on_interruption, - run_concurrently=item.run_concurrently, ) progress_frame_upstream = FunctionCallInProgressFrame( function_name=runner_item.function_name, tool_call_id=runner_item.tool_call_id, arguments=runner_item.arguments, cancel_on_interruption=item.cancel_on_interruption, - run_concurrently=item.run_concurrently, ) # Push frame both downstream and upstream @@ -359,37 +360,20 @@ class LLMService(AIService): ) await item.handler(params) - async def _cancel_function_call(self, function_name: Optional[str]): - cancelled_tasks = set() - for task, runner_item in self._function_call_tasks.items(): - if runner_item.registry_name == function_name: - name = runner_item.function_name - tool_call_id = runner_item.tool_call_id + async def _cancel_function_call(self): + if ( + self._current_runner + and self._current_task + and self._current_runner.registry_item.cancel_on_interruption + ): + name = self._current_runner.function_name + tool_call_id = self._current_runner.tool_call_id - # We remove the callback because we are going to cancel the task - # now, otherwise we will be removing it from the set while we - # are iterating. - task.remove_done_callback(self._function_call_task_finished) + logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...") - logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...") + await self.cancel_task(self._current_task) - await self.cancel_task(task) + frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id) + await self.push_frame(frame) - frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id) - await self.push_frame(frame) - - logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled") - - cancelled_tasks.add(task) - - # Remove all cancelled tasks from our set. - for task in cancelled_tasks: - self._function_call_task_finished(task) - - def _function_call_task_finished(self, task: asyncio.Task): - if task in self._function_call_tasks: - del self._function_call_tasks[task] - # The task is finished so this should exit immediately. We need to - # do this because otherwise the task manager would report a dangling - # task if we don't remove it. - asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop()) + logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")