LLMService: pass LLM function calls all at once

This commit is contained in:
Aleix Conchillo Flaqué
2025-04-28 13:43:28 -07:00
parent 52569bcdb2
commit 1eb50ad88f
8 changed files with 134 additions and 127 deletions

View File

@@ -45,7 +45,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import LLMService
from pipecat.services.llm_service import FunctionCallLLM, LLMService
from pipecat.utils.tracing.service_decorators import traced_llm
try:
@@ -202,15 +202,8 @@ class AnthropicLLMService(LLMService):
tool_use_block = None
json_accumulator = ""
total_func_calls = 0
function_calls = []
async for event in response:
if event.type == "content_block_start" and event.content_block.type == "tool_use":
total_func_calls += 1
current_func_call = 0
async for event in response:
# logger.debug(f"Anthropic LLM event: {event}")
# Aggregate streaming content, create frames, trigger events
if event.type == "content_block_delta":
@@ -232,15 +225,15 @@ class AnthropicLLMService(LLMService):
and event.delta.stop_reason == "tool_use"
):
if tool_use_block:
run_llm = current_func_call == total_func_calls - 1
await self.call_function(
context=context,
tool_call_id=tool_use_block.id,
function_name=tool_use_block.name,
arguments=json.loads(json_accumulator) if json_accumulator else dict(),
run_llm=run_llm,
args = json.loads(json_accumulator) if json_accumulator else {}
function_calls.append(
FunctionCallLLM(
context=context,
tool_call_id=tool_use_block.id,
function_name=tool_use_block.name,
arguments=args,
)
)
current_func_call += 1
# Calculate usage. Do this here in its own if statement, because there may be usage
# data embedded in messages that we do other processing for, above.
@@ -286,6 +279,8 @@ class AnthropicLLMService(LLMService):
if total_input_tokens >= 1024:
context.turns_above_cache_threshold += 1
await self.run_function_calls(function_calls)
except asyncio.CancelledError:
# If we're interrupted, we won't get a complete usage report. So set our flag to use the
# token estimate. Then reraise the exception so all the processors running in this task

View File

@@ -21,6 +21,7 @@ from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
from pipecat.frames.frames import (
Frame,
FunctionCallCancelFrame,
FunctionCallFromLLM,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
LLMFullResponseEndFrame,
@@ -708,6 +709,7 @@ class AWSBedrockLLMService(LLMService):
tool_use_block = None
json_accumulator = ""
function_calls = []
for event in response["stream"]:
# Handle text content
if "contentBlockDelta" in event:
@@ -740,11 +742,13 @@ class AWSBedrockLLMService(LLMService):
# Only call function if it's not the no_operation tool
if not using_noop_tool:
await self.call_function(
context=context,
tool_call_id=tool_use_block["id"],
function_name=tool_use_block["name"],
arguments=arguments,
function_calls.append(
FunctionCallFromLLM(
context=context,
tool_call_id=tool_use_block["id"],
function_name=tool_use_block["name"],
arguments=arguments,
)
)
else:
logger.debug("Ignoring no_operation tool call")
@@ -758,7 +762,7 @@ class AWSBedrockLLMService(LLMService):
completion_tokens += usage.get("outputTokens", 0)
cache_read_input_tokens += usage.get("cacheReadInputTokens", 0)
cache_creation_input_tokens += usage.get("cacheWriteInputTokens", 0)
await self.run_function_calls(function_calls)
except asyncio.CancelledError:
# If we're interrupted, we won't get a complete usage report. So set our flag to use the
# token estimate. Then reraise the exception so all the processors running in this task

View File

@@ -52,7 +52,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import LLMService
from pipecat.services.llm_service import FunctionCallLLM, LLMService
from pipecat.services.openai.llm import (
OpenAIAssistantContextAggregator,
OpenAIUserContextAggregator,
@@ -891,16 +891,18 @@ class GeminiMultimodalLiveLLMService(LLMService):
return
if not self._context:
logger.error("Function calls are not supported without a context object.")
total_items = len(function_calls)
for index, call in enumerate(function_calls):
run_llm = index == total_items - 1
await self.call_function(
function_calls_llm = [
FunctionCallLLM(
context=self._context,
tool_call_id=call.id,
function_name=call.name,
arguments=call.args,
run_llm=run_llm,
tool_call_id=f.id,
function_name=f.name,
arguments=f.args,
)
for f in function_calls
]
await self.run_function_calls(function_calls_llm)
@traced_gemini_live(operation="llm_response")
async def _handle_evt_turn_complete(self, evt):

View File

@@ -42,7 +42,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.google.frames import LLMSearchResponseFrame
from pipecat.services.llm_service import LLMService
from pipecat.services.llm_service import FunctionCallLLM, LLMService
from pipecat.services.openai.llm import (
OpenAIAssistantContextAggregator,
OpenAIUserContextAggregator,
@@ -557,6 +557,7 @@ class GoogleLLMService(LLMService):
)
await self.stop_ttfb_metrics()
function_calls = []
async for chunk in response:
if chunk.usage_metadata:
prompt_tokens += chunk.usage_metadata.prompt_token_count or 0
@@ -576,11 +577,13 @@ class GoogleLLMService(LLMService):
function_call = part.function_call
id = function_call.id or str(uuid.uuid4())
logger.debug(f"Function call: {function_call.name}:{id}")
await self.call_function(
context=context,
tool_call_id=id,
function_name=function_call.name,
arguments=function_call.args or {},
function_calls.append(
FunctionCallLLM(
context=context,
tool_call_id=id,
function_name=function_call.name,
arguments=function_call.args or {},
)
)
if (
@@ -621,6 +624,8 @@ class GoogleLLMService(LLMService):
"rendered_content": rendered_content,
"origins": origins,
}
await self.run_function_calls(function_calls)
except DeadlineExceeded:
await self._call_event_handler("on_completion_timeout")
except Exception as e:

View File

@@ -10,6 +10,8 @@ import os
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk
from pipecat.services.llm_service import FunctionCallLLM
# Suppress gRPC fork warnings
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
@@ -18,7 +20,6 @@ from loguru import logger
from pipecat.frames.frames import LLMTextFrame
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai.base_llm import OpenAIUnhandledFunctionException
from pipecat.services.openai.llm import OpenAILLMService
@@ -113,26 +114,25 @@ class GoogleLLMOpenAIBetaService(OpenAILLMService):
f"Function list: {functions_list}, Arguments list: {arguments_list}, Tool ID list: {tool_id_list}"
)
total_func_calls = len(functions_list)
for index, (function_name, arguments, tool_id) in enumerate(
zip(functions_list, arguments_list, tool_id_list)
function_calls = []
for function_name, arguments, tool_id in zip(
functions_list, arguments_list, tool_id_list
):
if function_name == "":
# TODO: Remove the _process_context method once Google resolves the bug
# where the index is incorrectly set to None instead of returning the actual index,
# which currently results in an empty function name ('').
continue
if self.has_function(function_name):
arguments = json.loads(arguments)
run_llm = index == total_func_calls - 1
await self.call_function(
arguments = json.loads(arguments)
function_calls.append(
FunctionCallLLM(
context=context,
tool_call_id=tool_id,
function_name=function_name,
arguments=arguments,
tool_call_id=tool_id,
run_llm=run_llm,
)
else:
raise OpenAIUnhandledFunctionException(
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
)
)
await self.run_function_calls(function_calls)

View File

@@ -7,7 +7,7 @@
import asyncio
import inspect
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional, Protocol, Type
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional, Protocol, Sequence, Type
from loguru import logger
@@ -45,7 +45,7 @@ class FunctionCallResultCallback(Protocol):
@dataclass
class FunctionCallItem:
class FunctionCallRegistryItem:
"""Represents an entry of our function call registry.
Attributes:
@@ -61,9 +61,27 @@ class FunctionCallItem:
@dataclass
class FunctionCallRunnerItem:
"""Represents a function call entry for our function call runner. The runner
executes function calls in order.
class FunctionCallLLM:
"""Represents a function call returned by the LLM to be registered for execution.
Attributes:
function_name (str): The name of the function.
tool_call_id (str): A unique identifier for the function call.
arguments (Mapping[str, Any]): The arguments for the function.
context (OpenAILLMContext): The LLM context.
"""
function_name: str
tool_call_id: str
arguments: Mapping[str, Any]
context: OpenAILLMContext
@dataclass
class FunctionCallRunner:
"""Represents an internal function call entry to our function call
runner. The runner executes function calls in order.
Attributes:
registry_item (FunctionCallRegistryItem): The registry entry that handles this call (its registered name could be None).
@@ -74,7 +92,7 @@ class FunctionCallRunnerItem:
"""
registry_item: FunctionCallItem
registry_item: FunctionCallRegistryItem
function_name: str
tool_call_id: str
arguments: Mapping[str, Any]
@@ -115,7 +133,7 @@ class LLMService(AIService):
super().__init__(**kwargs)
self._start_callbacks = {}
self._adapter = self.adapter_class()
self._functions: Dict[Optional[str], FunctionCallItem] = {}
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
self._function_call_runner_task: Optional[asyncio.Task] = None
self._register_event_handler("on_completion_timeout")
@@ -167,7 +185,7 @@ class LLMService(AIService):
):
# Registering a function with the function_name set to None will run
# that handler for all functions
self._functions[function_name] = FunctionCallItem(
self._functions[function_name] = FunctionCallRegistryItem(
function_name=function_name,
handler=handler,
cancel_on_interruption=cancel_on_interruption,
@@ -196,32 +214,32 @@ class LLMService(AIService):
return True
return function_name in self._functions.keys()
async def call_function(
self,
*,
context: OpenAILLMContext,
tool_call_id: str,
function_name: str,
arguments: Mapping[str, Any],
run_llm: bool = True,
):
if function_name in self._functions.keys():
item = self._functions[function_name]
elif None in self._functions.keys():
item = self._functions[None]
else:
return
async def run_function_calls(self, function_calls: Sequence[FunctionCallLLM]):
total_function_calls = len(function_calls)
for index, function_call in enumerate(function_calls):
if function_call.function_name in self._functions.keys():
item = self._functions[function_call.function_name]
elif None in self._functions.keys():
item = self._functions[None]
else:
logger.warning(
f"{self} is calling '{function_call.function_name}', but it's not registered."
)
continue
runner_item = FunctionCallRunnerItem(
registry_item=item,
function_name=function_name,
tool_call_id=tool_call_id,
arguments=arguments,
context=context,
run_llm=run_llm,
)
# Run inference on the last function call.
run_llm = index == total_function_calls - 1
await self._function_call_runner_queue.put(runner_item)
runner_item = FunctionCallRunner(
registry_item=item,
function_name=function_call.function_name,
tool_call_id=function_call.tool_call_id,
arguments=function_call.arguments,
context=function_call.context,
run_llm=run_llm,
)
await self._function_call_runner_queue.put(runner_item)
async def call_start_function(self, context: OpenAILLMContext, function_name: str):
if function_name in self._start_callbacks.keys():
@@ -251,7 +269,7 @@ class LLMService(AIService):
async def _create_runner_task(self):
if not self._function_call_runner_task:
self._current_runner: Optional[FunctionCallRunnerItem] = None
self._current_runner: Optional[FunctionCallRunner] = None
self._current_task: Optional[asyncio.Task] = None
self._function_call_runner_queue = asyncio.Queue()
self._function_call_runner_task = self.create_task(self._function_call_runner_handler())
@@ -269,7 +287,7 @@ class LLMService(AIService):
self._current_runner = None
self._current_task = None
async def _run_function_call(self, runner_item: FunctionCallRunnerItem):
async def _run_function_call(self, runner_item: FunctionCallRunner):
if runner_item.function_name in self._functions.keys():
item = self._functions[runner_item.function_name]
elif None in self._functions.keys():

View File

@@ -34,14 +34,10 @@ from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import LLMService
from pipecat.services.llm_service import FunctionCallLLM, LLMService
from pipecat.utils.tracing.service_decorators import traced_llm
class OpenAIUnhandledFunctionException(Exception):
pass
class BaseOpenAILLMService(LLMService):
"""This is the base for all services that use the AsyncOpenAI client.
@@ -260,24 +256,22 @@ class BaseOpenAILLMService(LLMService):
arguments_list.append(arguments)
tool_id_list.append(tool_call_id)
total_func_calls = len(functions_list)
for index, (function_name, arguments, tool_id) in enumerate(
zip(functions_list, arguments_list, tool_id_list)
function_calls = []
for function_name, arguments, tool_id in zip(
functions_list, arguments_list, tool_id_list
):
if self.has_function(function_name):
run_llm = index == total_func_calls - 1
arguments = json.loads(arguments)
await self.call_function(
arguments = json.loads(arguments)
function_calls.append(
FunctionCallLLM(
context=context,
tool_call_id=tool_id,
function_name=function_name,
arguments=arguments,
tool_call_id=tool_id,
run_llm=run_llm,
)
else:
raise OpenAIUnhandledFunctionException(
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
)
)
await self.run_function_calls(function_calls)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

View File

@@ -48,7 +48,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import LLMService
from pipecat.services.llm_service import FunctionCallLLM, LLMService
from pipecat.services.openai.llm import OpenAIContextAggregatorPair
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
@@ -78,10 +78,6 @@ class CurrentAudioResponse:
total_size: int = 0
class OpenAIUnhandledFunctionException(Exception):
pass
class OpenAIRealtimeBetaLLMService(LLMService):
# Overriding the default adapter to use the OpenAIRealtimeLLMAdapter one.
adapter_class = OpenAIRealtimeLLMAdapter
@@ -587,25 +583,18 @@ class OpenAIRealtimeBetaLLMService(LLMService):
await self._handle_function_call_items(function_calls)
async def _handle_function_call_items(self, items):
total_items = len(items)
for index, item in enumerate(items):
function_name = item.name
tool_id = item.call_id
arguments = json.loads(item.arguments)
if self.has_function(function_name):
run_llm = index == total_items - 1
if function_name in self._functions.keys() or None in self._functions.keys():
await self.call_function(
context=self._context,
tool_call_id=tool_id,
function_name=function_name,
arguments=arguments,
run_llm=run_llm,
)
else:
raise OpenAIUnhandledFunctionException(
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
function_calls = []
for item in items:
args = json.loads(item.arguments)
function_calls.append(
FunctionCallLLM(
context=self._context,
tool_call_id=item.call_id,
function_name=item.name,
arguments=args,
)
)
await self.run_function_calls(function_calls)
#
# state and client events for the current conversation