diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d5bb1bb1..715f8c7c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,10 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Function calls can now be executed sequentially (in the order received in the completion) by passing `run_in_parallel=False` when creating your LLM - service. By default, function calls run in parallel, so if the LLM completion - returns 2 or more function calls they run concurrently. In both cases, - concurrently and sequentailly, a new LLM completion will run when the last - function call finishes. + service. By default, if the LLM completion returns 2 or more function calls + they run concurrently. In both cases, concurrently and sequentially, a new LLM + completion will run when the last function call finishes. - Added OpenTelemetry tracing for `GeminiMultimodalLiveLLMService` and `OpenAIRealtimeBetaLLMService`. diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 4574d624d..94dad4d1a 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -591,8 +591,6 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator): ) return - in_progress = self._function_calls_in_progress[frame.tool_call_id] - del self._function_calls_in_progress[frame.tool_call_id] properties = frame.properties diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index 2a94598fa..3f2eefc30 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -97,7 +97,7 @@ class FunctionCallRunner: tool_call_id: str arguments: Mapping[str, Any] context: OpenAILLMContext - run_llm: bool + run_llm: Optional[bool] = None @dataclass @@ -129,12 +129,14 @@ class LLMService(AIService): # However, subclasses should override this with a more specific adapter when necessary. adapter_class: Type[BaseLLMAdapter] = OpenAILLMAdapter - def __init__(self, **kwargs): + def __init__(self, run_in_parallel: bool = True, **kwargs): super().__init__(**kwargs) + self._run_in_parallel = run_in_parallel self._start_callbacks = {} self._adapter = self.adapter_class() self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {} - self._function_call_runner_task: Optional[asyncio.Task] = None + self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunner] = {} + self._sequential_runner_task: Optional[asyncio.Task] = None self._register_event_handler("on_completion_timeout") @@ -152,17 +154,18 @@ class LLMService(AIService): async def start(self, frame: StartFrame): await super().start(frame) - await self._create_runner_task() + if not self._run_in_parallel: + await self._create_sequential_runner_task() async def stop(self, frame: EndFrame): await super().stop(frame) - await self._cancel_function_call() - await self._cancel_runner_task() + if not self._run_in_parallel: + await self._cancel_sequential_runner_task() async def cancel(self, frame: CancelFrame): await super().cancel(frame) - await self._cancel_function_call() - await self._cancel_runner_task() + if not self._run_in_parallel: + await self._cancel_sequential_runner_task() async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -171,9 +174,9 @@ class LLMService(AIService): await self._handle_interruptions(frame) async def _handle_interruptions(self, _: StartInterruptionFrame): - await self._cancel_function_call() - await self._cancel_runner_task() - await self._create_runner_task() + for function_name, entry in self._functions.items(): + if entry.cancel_on_interruption: + await self._cancel_function_call(function_name) def register_function( self, @@ -227,8 +230,12 @@ class LLMService(AIService): ) continue - # Run inference on the last function call. - run_llm = index == total_function_calls - 1 + # If we are not running in parallel, run inference on the last + # function call. Otherwise, the last function call to finish is the + # one that will run the inference. + run_llm = None + if not self._run_in_parallel: + run_llm = index == total_function_calls - 1 runner_item = FunctionCallRunner( registry_item=item, @@ -239,7 +246,12 @@ class LLMService(AIService): run_llm=run_llm, ) - await self._function_call_runner_queue.put(runner_item) + if self._run_in_parallel: + task = self.create_task(self._run_function_call(runner_item)) + self._function_call_tasks[task] = runner_item + task.add_done_callback(self._function_call_task_finished) + else: + await self._sequential_runner_queue.put(runner_item) async def call_start_function(self, context: OpenAILLMContext, function_name: str): if function_name in self._start_callbacks.keys(): @@ -267,25 +279,25 @@ class LLMService(AIService): FrameDirection.UPSTREAM, ) - async def _create_runner_task(self): - if not self._function_call_runner_task: - self._current_runner: Optional[FunctionCallRunner] = None - self._current_task: Optional[asyncio.Task] = None - self._function_call_runner_queue = asyncio.Queue() - self._function_call_runner_task = self.create_task(self._function_call_runner_handler()) + async def _create_sequential_runner_task(self): + if not self._sequential_runner_task: + self._sequential_runner_queue = asyncio.Queue() + self._sequential_runner_task = self.create_task(self._sequential_runner_handler()) - async def _cancel_runner_task(self): - if self._function_call_runner_task: - await self.cancel_task(self._function_call_runner_task) - self._function_call_runner_task = None + async def _cancel_sequential_runner_task(self): + if self._sequential_runner_task: + await self.cancel_task(self._sequential_runner_task) + self._sequential_runner_task = None - async def _function_call_runner_handler(self): + async def _sequential_runner_handler(self): while True: - self._current_runner = await self._function_call_runner_queue.get() - self._current_task = self.create_task(self._run_function_call(self._current_runner)) - await self.wait_for_task(self._current_task) - self._current_runner = None - self._current_task = None + runner_item = await self._sequential_runner_queue.get() + task = self.create_task(self._run_function_call(runner_item)) + self._function_call_tasks[task] = runner_item + # Since we run tasks sequentially we don't need to call + # task.add_done_callback(self._function_call_task_finished). + await self.wait_for_task(task) + del self._function_call_tasks[task] async def _run_function_call(self, runner_item: FunctionCallRunner): if runner_item.function_name in self._functions.keys(): @@ -378,20 +390,37 @@ class LLMService(AIService): ) await item.handler(params) - async def _cancel_function_call(self): - if ( - self._current_runner - and self._current_task - and self._current_runner.registry_item.cancel_on_interruption - ): - name = self._current_runner.function_name - tool_call_id = self._current_runner.tool_call_id + async def _cancel_function_call(self, function_name: Optional[str]): + cancelled_tasks = set() + for task, runner_item in self._function_call_tasks.items(): + if runner_item.registry_item.function_name == function_name: + name = runner_item.function_name + tool_call_id = runner_item.tool_call_id - logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...") + # We remove the callback because we are going to cancel the task + # now, otherwise we will be removing it from the set while we + # are iterating. + task.remove_done_callback(self._function_call_task_finished) - await self.cancel_task(self._current_task) + logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...") - frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id) - await self.push_frame(frame) + await self.cancel_task(task) - logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled") + frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id) + await self.push_frame(frame) + + cancelled_tasks.add(task) + + logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled") + + # Remove all cancelled tasks from our set. + for task in cancelled_tasks: + self._function_call_task_finished(task) + + def _function_call_task_finished(self, task: asyncio.Task): + if task in self._function_call_tasks: + del self._function_call_tasks[task] + # The task is finished so this should exit immediately. We need to + # do this because otherwise the task manager would report a dangling + # task if we don't remove it. + asyncio.run_coroutine_threadsafe(self.wait_for_task(task), self.get_event_loop())