From 3724ecd378eaf91550251f040cc00c1fbc820b28 Mon Sep 17 00:00:00 2001 From: filipi87 Date: Thu, 2 Apr 2026 16:58:19 -0300 Subject: [PATCH] Supporting async function calls. --- .../function-calling-anthropic.py | 21 ++++-- .../function-calling-openai.py | 15 +++- src/pipecat/frames/frames.py | 7 ++ .../aggregators/llm_response_universal.py | 75 +++++++++++++++---- src/pipecat/services/llm_service.py | 32 +++++++- .../context/llm_context_summarization.py | 64 +++++++++++++++- 6 files changed, 184 insertions(+), 30 deletions(-) diff --git a/examples/function-calling/function-calling-anthropic.py b/examples/function-calling/function-calling-anthropic.py index b8bb4eac6..09cafab8e 100644 --- a/examples/function-calling/function-calling-anthropic.py +++ b/examples/function-calling/function-calling-anthropic.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD 2-Clause License # - +import asyncio import os from dotenv import load_dotenv @@ -35,9 +35,10 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams load_dotenv(override=True) -async def get_weather(params: FunctionCallParams): - location = params.arguments["location"] - await params.result_callback(f"The weather in {location} is currently 72 degrees and sunny.") +async def fetch_weather_from_api(params: FunctionCallParams): + # Simulate a long-running API call, so we can test async function calls. + await asyncio.sleep(20) + await params.result_callback({"conditions": "nice", "temperature": "75"}) async def fetch_restaurant_recommendation(params: FunctionCallParams): @@ -80,11 +81,19 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): system_instruction="You are a helpful assistant in a voice conversation. Your responses will be spoken aloud, so avoid emojis, bullet points, or other formatting that can't be spoken. 
Respond to what the user said in a creative, helpful, and brief way.", ), ) - llm.register_function("get_weather", get_weather) + + # You can also register a function_name of None to get all functions + # sent to the same callback with an additional function_name parameter. + llm.register_function( + "get_current_weather", + fetch_weather_from_api, + cancel_on_interruption=False, + timeout_secs=30, + ) llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation) weather_function = FunctionSchema( - name="get_weather", + name="get_current_weather", description="Get the current weather", properties={ "location": { diff --git a/examples/function-calling/function-calling-openai.py b/examples/function-calling/function-calling-openai.py index 2b59d7072..b5c5b83ac 100644 --- a/examples/function-calling/function-calling-openai.py +++ b/examples/function-calling/function-calling-openai.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: BSD 2-Clause License # +import asyncio import os from dotenv import load_dotenv @@ -12,7 +13,10 @@ from loguru import logger from pipecat.adapters.schemas.function_schema import FunctionSchema from pipecat.adapters.schemas.tools_schema import ToolsSchema from pipecat.audio.vad.silero import SileroVADAnalyzer -from pipecat.frames.frames import LLMRunFrame, TTSSpeakFrame +from pipecat.frames.frames import ( + LLMRunFrame, + TTSSpeakFrame, +) from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask @@ -35,6 +39,8 @@ load_dotenv(override=True) async def fetch_weather_from_api(params: FunctionCallParams): + # Simulate a long-running API call, so we can test async function calls. 
+ await asyncio.sleep(20) await params.result_callback({"conditions": "nice", "temperature": "75"}) @@ -88,7 +94,12 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): # You can also register a function_name of None to get all functions # sent to the same callback with an additional function_name parameter. - llm.register_function("get_current_weather", fetch_weather_from_api) + llm.register_function( + "get_current_weather", + fetch_weather_from_api, + cancel_on_interruption=False, + timeout_secs=30, + ) llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation) @llm.event_handler("on_function_calls_started") diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 86a93825b..27796ff62 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -1642,12 +1642,19 @@ class FunctionCallInProgressFrame(ControlFrame, UninterruptibleFrame): tool_call_id: Unique identifier for this function call. arguments: Arguments passed to the function. cancel_on_interruption: Whether to cancel this call if interrupted. + When ``False`` the call is treated as asynchronous: the LLM + continues the conversation immediately without waiting for the + result, and the result is injected later via a developer message. + group_id: Identifier shared by all function calls originating from the + same LLM response batch. Used to determine when the last call in a + group completes so the LLM can be triggered exactly once. 
""" function_name: str tool_call_id: str arguments: Any cancel_on_interruption: bool = False + group_id: Optional[str] = None @dataclass diff --git a/src/pipecat/processors/aggregators/llm_response_universal.py b/src/pipecat/processors/aggregators/llm_response_universal.py index f8c2ffd43..b99b251d4 100644 --- a/src/pipecat/processors/aggregators/llm_response_universal.py +++ b/src/pipecat/processors/aggregators/llm_response_universal.py @@ -866,6 +866,8 @@ class LLMAssistantAggregator(LLMContextAggregator): self._function_calls_image_results: Dict[str, UserImageRawFrame] = {} self._context_updated_tasks: Set[asyncio.Task] = set() + self._user_speaking: bool = False + self._assistant_turn_start_timestamp = "" self._thought_append_to_context = False @@ -968,6 +970,12 @@ class LLMAssistantAggregator(LLMContextAggregator): await self._handle_user_image_frame(frame) elif isinstance(frame, AssistantImageRawFrame): await self._handle_assistant_image_frame(frame) + elif isinstance(frame, UserStartedSpeakingFrame): + self._user_speaking = True + await self.push_frame(frame, direction) + elif isinstance(frame, UserStoppedSpeakingFrame): + self._user_speaking = False + await self.push_frame(frame, direction) else: await self.push_frame(frame, direction) @@ -1047,13 +1055,24 @@ class LLMAssistantAggregator(LLMContextAggregator): ], } ) - self._context.add_message( - { - "role": "tool", - "content": "IN_PROGRESS", - "tool_call_id": frame.tool_call_id, - } - ) + + is_async = not frame.cancel_on_interruption + if is_async: + self._context.add_message( + { + "role": "tool", + "content": json.dumps({"type": "async_tool", "status": "started"}), + "tool_call_id": frame.tool_call_id, + } + ) + else: + self._context.add_message( + { + "role": "tool", + "content": "IN_PROGRESS", + "tool_call_id": frame.tool_call_id, + } + ) self._function_calls_in_progress[frame.tool_call_id] = frame @@ -1067,16 +1086,34 @@ class LLMAssistantAggregator(LLMContextAggregator): ) return + 
in_progress_frame = self._function_calls_in_progress[frame.tool_call_id] + is_async = not in_progress_frame.cancel_on_interruption if in_progress_frame else False + group_id = in_progress_frame.group_id if in_progress_frame else None + del self._function_calls_in_progress[frame.tool_call_id] properties = frame.properties - # Update context with the function call result - if frame.result: - result = json.dumps(frame.result, ensure_ascii=False) - self._update_function_call_result(frame.function_name, frame.tool_call_id, result) + result = json.dumps(frame.result, ensure_ascii=False) if frame.result else "COMPLETED" + + if is_async: + # For async function calls, instead of updating the existing "started" tool message, we inject + # a new developer message so the LLM is notified of the completed result. + self._context.add_message( + { + "role": "developer", + "content": json.dumps( + { + "type": "async_tool", + "tool_call_id": frame.tool_call_id, + "status": "finished", + "result": result, + } + ), + } + ) else: - self._update_function_call_result(frame.function_name, frame.tool_call_id, "COMPLETED") + self._update_function_call_result(frame.function_name, frame.tool_call_id, result) run_llm = False @@ -1098,10 +1135,18 @@ class LLMAssistantAggregator(LLMContextAggregator): # If the frame is indicating we should run the LLM, do it. run_llm = frame.run_llm else: - # If this is the last function call in progress, run the LLM. - run_llm = not bool(self._function_calls_in_progress) + # Run the LLM when this is the last function call in the group + # to complete. If group_id is set, only consider sibling calls; + # otherwise always execute as soon as we receive the result.
+ if group_id: + run_llm = not any( + f is not None and f.group_id == group_id + for f in self._function_calls_in_progress.values() + ) + else: + run_llm = True - if run_llm: + if run_llm and not self._user_speaking: await self.push_context_frame(FrameDirection.UPSTREAM) # Call the `on_context_updated` callback once the function call result diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index 37d7d094a..1430aceac 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -7,8 +7,8 @@ """Base classes for Large Language Model services with function calling support.""" import asyncio -import inspect import json +import uuid import warnings from dataclasses import dataclass from typing import ( @@ -119,6 +119,9 @@ class FunctionCallRegistryItem: function_name: The name of the function (None for catch-all handler). handler: The handler for processing function call parameters. cancel_on_interruption: Whether to cancel the call on interruption. + When ``False`` the call is treated as asynchronous: the LLM + continues the conversation immediately without waiting for the + result, and the result is injected later via a developer message. timeout_secs: Optional per-tool timeout in seconds. Overrides the global ``function_call_timeout_secs`` for this specific function. """ @@ -142,6 +145,9 @@ class FunctionCallRunnerItem: arguments: The arguments for the function. context: The LLM context. run_llm: Optional flag to control LLM execution after function call. + group_id: Shared identifier for all function calls from the same LLM + response batch. Used to trigger the LLM exactly once when the last + call in the group completes. 
""" registry_item: FunctionCallRegistryItem @@ -150,6 +156,7 @@ class FunctionCallRunnerItem: arguments: Mapping[str, Any] context: LLMContext run_llm: Optional[bool] = None + group_id: Optional[str] = None class LLMService(UserTurnCompletionLLMServiceMixin, AIService): @@ -185,6 +192,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): def __init__( self, run_in_parallel: bool = True, + group_parallel_tools: bool = True, function_call_timeout_secs: Optional[float] = None, settings: Optional[LLMSettings] = None, **kwargs, @@ -194,6 +202,10 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): Args: run_in_parallel: Whether to run function calls in parallel or sequentially. Defaults to True. + group_parallel_tools: Whether to group parallel function calls so the LLM + is triggered exactly once after all calls in the batch complete. When + False, each function call result triggers the LLM independently as it + arrives. Defaults to True. function_call_timeout_secs: Optional timeout in seconds for deferred function calls. settings: The runtime-updatable settings for the LLM service. @@ -208,6 +220,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): **kwargs, ) self._run_in_parallel = run_in_parallel + self._group_parallel_tools = group_parallel_tools self._function_call_timeout_secs = function_call_timeout_secs self._filter_incomplete_user_turns: bool = False self._base_system_instruction: Optional[str] = None @@ -548,7 +561,10 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): handler: The function handler. Should accept a single FunctionCallParams parameter. cancel_on_interruption: Whether to cancel this function call when an - interruption occurs. Defaults to True. + interruption occurs. When ``False`` the call is treated as + asynchronous: the LLM continues the conversation immediately + without waiting for the result, and the result is injected later + via a developer message. Defaults to True. 
timeout_secs: Optional per-tool timeout in seconds. Overrides the global ``function_call_timeout_secs`` for this specific function. Defaults to None, which uses the global timeout. @@ -578,7 +594,10 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): Args: handler: The direct function to register. Must follow DirectFunction protocol. cancel_on_interruption: Whether to cancel this function call when an - interruption occurs. Defaults to True. + interruption occurs. When ``False`` the call is treated as + asynchronous: the LLM continues the conversation immediately + without waiting for the result, and the result is injected later + via a developer message. Defaults to True. timeout_secs: Optional per-tool timeout in seconds. Overrides the global ``function_call_timeout_secs`` for this specific function. Defaults to None, which uses the global timeout. @@ -639,6 +658,11 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): await self.broadcast_frame(FunctionCallsStartedFrame, function_calls=function_calls) + # When group_parallel_tools is True all calls share a group_id so the + # aggregator triggers the LLM exactly once after the last one completes. + # When False, group_id is None and each result triggers inference independently. 
+ group_id = str(uuid.uuid4()) if self._group_parallel_tools else None + runner_items = [] for function_call in function_calls: if function_call.function_name in self._functions.keys(): @@ -658,6 +682,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): tool_call_id=function_call.tool_call_id, arguments=function_call.arguments, context=function_call.context, + group_id=group_id, ) ) @@ -726,6 +751,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): tool_call_id=runner_item.tool_call_id, arguments=runner_item.arguments, cancel_on_interruption=item.cancel_on_interruption, + group_id=runner_item.group_id, ) timeout_task: Optional[asyncio.Task] = None diff --git a/src/pipecat/utils/context/llm_context_summarization.py b/src/pipecat/utils/context/llm_context_summarization.py index 259d0bae5..afb0ecd1f 100644 --- a/src/pipecat/utils/context/llm_context_summarization.py +++ b/src/pipecat/utils/context/llm_context_summarization.py @@ -10,6 +10,7 @@ This module provides reusable functionality for automatically compressing conver context when token limits are reached, enabling efficient long-running conversations. """ +import json import warnings from dataclasses import dataclass, field from typing import TYPE_CHECKING, List, Optional @@ -381,6 +382,35 @@ class LLMContextSummarizationUtil: return total + @staticmethod + def _is_tool_message_pending(content: str) -> bool: + """Return True if a tool message content represents an unresolved call. + + A tool message is considered pending (unresolved) when its content is + the synchronous ``"IN_PROGRESS"`` sentinel or the async + ``{"type": "async_tool", "status": "started"}`` marker — both indicate + that the actual result has not yet been written back to the context. + + Args: + content: The ``content`` field of a tool-role context message. + + Returns: + True if the tool call should be treated as still in progress. 
+ """ + if content == "IN_PROGRESS": + return True + try: + parsed = json.loads(content) + if ( + isinstance(parsed, dict) + and parsed.get("type") == "async_tool" + and parsed.get("status") == "started" + ): + return True + except (json.JSONDecodeError, ValueError): + pass + return False + @staticmethod def _get_earliest_function_call_not_resolved_in_range( messages: List[dict], start_idx: int, summary_end: int @@ -389,9 +419,13 @@ class LLMContextSummarizationUtil: Scans messages from ``start_idx`` up to (but not including) ``summary_end`` to identify tool calls whose responses either don't - exist yet or fall in the kept portion of the context (>= summary_end). + exist yet, fall in the kept portion of the context (>= summary_end), + or are still marked as pending: the synchronous ``IN_PROGRESS`` sentinel or + the async ``started`` marker (results have not yet arrived). + This prevents summarizing tool call requests when their responses would - remain in the kept context as orphans, which the OpenAI API rejects. + remain in the kept context as orphans, which the OpenAI API rejects, + and avoids summarizing async function calls before their results arrive. Args: messages: List of messages to check. @@ -428,11 +462,33 @@ class LLMContextSummarizationUtil: if tool_call_id: pending_tool_calls[tool_call_id] = i - # Check for tool results + # Check for tool results — treat IN_PROGRESS and async "started" + # messages as still pending so they are not summarized away before + # their results arrive. if role == "tool": tool_call_id = msg.get("tool_call_id") if tool_call_id and tool_call_id in pending_tool_calls: - pending_tool_calls.pop(tool_call_id) + if not LLMContextSummarizationUtil._is_tool_message_pending( + msg.get("content", "") + ): + pending_tool_calls.pop(tool_call_id) + + # Check for async tool completion — a developer message with + # {"type": "async_tool", "status": "finished"} signals that the + # async result has arrived and the call is now resolved.
+ if role == "developer": + try: + parsed = json.loads(msg.get("content", "")) + if ( + isinstance(parsed, dict) + and parsed.get("type") == "async_tool" + and parsed.get("status") == "finished" + ): + tool_call_id = parsed.get("tool_call_id") + if tool_call_id and tool_call_id in pending_tool_calls: + pending_tool_calls.pop(tool_call_id) + except (json.JSONDecodeError, ValueError): + pass # If we have pending tool calls, return the earliest index if pending_tool_calls: