LLMService: pass LLM function calls all at once
This commit is contained in:
@@ -45,7 +45,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.services.llm_service import FunctionCallLLM, LLMService
|
||||
from pipecat.utils.tracing.service_decorators import traced_llm
|
||||
|
||||
try:
|
||||
@@ -202,15 +202,8 @@ class AnthropicLLMService(LLMService):
|
||||
tool_use_block = None
|
||||
json_accumulator = ""
|
||||
|
||||
total_func_calls = 0
|
||||
function_calls = []
|
||||
async for event in response:
|
||||
if event.type == "content_block_start" and event.content_block.type == "tool_use":
|
||||
total_func_calls += 1
|
||||
|
||||
current_func_call = 0
|
||||
async for event in response:
|
||||
# logger.debug(f"Anthropic LLM event: {event}")
|
||||
|
||||
# Aggregate streaming content, create frames, trigger events
|
||||
|
||||
if event.type == "content_block_delta":
|
||||
@@ -232,15 +225,15 @@ class AnthropicLLMService(LLMService):
|
||||
and event.delta.stop_reason == "tool_use"
|
||||
):
|
||||
if tool_use_block:
|
||||
run_llm = current_func_call == total_func_calls - 1
|
||||
await self.call_function(
|
||||
context=context,
|
||||
tool_call_id=tool_use_block.id,
|
||||
function_name=tool_use_block.name,
|
||||
arguments=json.loads(json_accumulator) if json_accumulator else dict(),
|
||||
run_llm=run_llm,
|
||||
args = json.loads(json_accumulator) if json_accumulator else {}
|
||||
function_calls.append(
|
||||
FunctionCallLLM(
|
||||
context=context,
|
||||
tool_call_id=tool_use_block.id,
|
||||
function_name=tool_use_block.name,
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
current_func_call += 1
|
||||
|
||||
# Calculate usage. Do this here in its own if statement, because there may be usage
|
||||
# data embedded in messages that we do other processing for, above.
|
||||
@@ -286,6 +279,8 @@ class AnthropicLLMService(LLMService):
|
||||
if total_input_tokens >= 1024:
|
||||
context.turns_above_cache_threshold += 1
|
||||
|
||||
await self.run_function_calls(function_calls)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# If we're interrupted, we won't get a complete usage report. So set our flag to use the
|
||||
# token estimate. Then re-raise the exception so all the processors running in this task
|
||||
|
||||
@@ -21,6 +21,7 @@ from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallCancelFrame,
|
||||
FunctionCallFromLLM,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -708,6 +709,7 @@ class AWSBedrockLLMService(LLMService):
|
||||
tool_use_block = None
|
||||
json_accumulator = ""
|
||||
|
||||
function_calls = []
|
||||
for event in response["stream"]:
|
||||
# Handle text content
|
||||
if "contentBlockDelta" in event:
|
||||
@@ -740,11 +742,13 @@ class AWSBedrockLLMService(LLMService):
|
||||
|
||||
# Only call function if it's not the no_operation tool
|
||||
if not using_noop_tool:
|
||||
await self.call_function(
|
||||
context=context,
|
||||
tool_call_id=tool_use_block["id"],
|
||||
function_name=tool_use_block["name"],
|
||||
arguments=arguments,
|
||||
function_calls.append(
|
||||
FunctionCallFromLLM(
|
||||
context=context,
|
||||
tool_call_id=tool_use_block["id"],
|
||||
function_name=tool_use_block["name"],
|
||||
arguments=arguments,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.debug("Ignoring no_operation tool call")
|
||||
@@ -758,7 +762,7 @@ class AWSBedrockLLMService(LLMService):
|
||||
completion_tokens += usage.get("outputTokens", 0)
|
||||
cache_read_input_tokens += usage.get("cacheReadInputTokens", 0)
|
||||
cache_creation_input_tokens += usage.get("cacheWriteInputTokens", 0)
|
||||
|
||||
await self.run_function_calls(function_calls)
|
||||
except asyncio.CancelledError:
|
||||
# If we're interrupted, we won't get a complete usage report. So set our flag to use the
|
||||
# token estimate. Then re-raise the exception so all the processors running in this task
|
||||
|
||||
@@ -52,7 +52,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.services.llm_service import FunctionCallLLM, LLMService
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
@@ -891,16 +891,18 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
return
|
||||
if not self._context:
|
||||
logger.error("Function calls are not supported without a context object.")
|
||||
total_items = len(function_calls)
|
||||
for index, call in enumerate(function_calls):
|
||||
run_llm = index == total_items - 1
|
||||
await self.call_function(
|
||||
|
||||
function_calls_llm = [
|
||||
FunctionCallLLM(
|
||||
context=self._context,
|
||||
tool_call_id=call.id,
|
||||
function_name=call.name,
|
||||
arguments=call.args,
|
||||
run_llm=run_llm,
|
||||
tool_call_id=f.id,
|
||||
function_name=f.name,
|
||||
arguments=f.args,
|
||||
)
|
||||
for f in function_calls
|
||||
]
|
||||
|
||||
await self.run_function_calls(function_calls_llm)
|
||||
|
||||
@traced_gemini_live(operation="llm_response")
|
||||
async def _handle_evt_turn_complete(self, evt):
|
||||
|
||||
@@ -42,7 +42,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.google.frames import LLMSearchResponseFrame
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.services.llm_service import FunctionCallLLM, LLMService
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
@@ -557,6 +557,7 @@ class GoogleLLMService(LLMService):
|
||||
)
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
function_calls = []
|
||||
async for chunk in response:
|
||||
if chunk.usage_metadata:
|
||||
prompt_tokens += chunk.usage_metadata.prompt_token_count or 0
|
||||
@@ -576,11 +577,13 @@ class GoogleLLMService(LLMService):
|
||||
function_call = part.function_call
|
||||
id = function_call.id or str(uuid.uuid4())
|
||||
logger.debug(f"Function call: {function_call.name}:{id}")
|
||||
await self.call_function(
|
||||
context=context,
|
||||
tool_call_id=id,
|
||||
function_name=function_call.name,
|
||||
arguments=function_call.args or {},
|
||||
function_calls.append(
|
||||
FunctionCallLLM(
|
||||
context=context,
|
||||
tool_call_id=id,
|
||||
function_name=function_call.name,
|
||||
arguments=function_call.args or {},
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
@@ -621,6 +624,8 @@ class GoogleLLMService(LLMService):
|
||||
"rendered_content": rendered_content,
|
||||
"origins": origins,
|
||||
}
|
||||
|
||||
await self.run_function_calls(function_calls)
|
||||
except DeadlineExceeded:
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
except Exception as e:
|
||||
|
||||
@@ -10,6 +10,8 @@ import os
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
|
||||
from pipecat.services.llm_service import FunctionCallLLM
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
@@ -18,7 +20,6 @@ from loguru import logger
|
||||
from pipecat.frames.frames import LLMTextFrame
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.base_llm import OpenAIUnhandledFunctionException
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
@@ -113,26 +114,25 @@ class GoogleLLMOpenAIBetaService(OpenAILLMService):
|
||||
f"Function list: {functions_list}, Arguments list: {arguments_list}, Tool ID list: {tool_id_list}"
|
||||
)
|
||||
|
||||
total_func_calls = len(functions_list)
|
||||
for index, (function_name, arguments, tool_id) in enumerate(
|
||||
zip(functions_list, arguments_list, tool_id_list)
|
||||
function_calls = []
|
||||
for function_name, arguments, tool_id in zip(
|
||||
functions_list, arguments_list, tool_id_list
|
||||
):
|
||||
if function_name == "":
|
||||
# TODO: Remove the _process_context method once Google resolves the bug
|
||||
# where the index is incorrectly set to None instead of returning the actual index,
|
||||
# which currently results in an empty function name ('').
|
||||
continue
|
||||
if self.has_function(function_name):
|
||||
arguments = json.loads(arguments)
|
||||
run_llm = index == total_func_calls - 1
|
||||
await self.call_function(
|
||||
|
||||
arguments = json.loads(arguments)
|
||||
|
||||
function_calls.append(
|
||||
FunctionCallLLM(
|
||||
context=context,
|
||||
tool_call_id=tool_id,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
tool_call_id=tool_id,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
else:
|
||||
raise OpenAIUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
)
|
||||
|
||||
await self.run_function_calls(function_calls)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional, Protocol, Type
|
||||
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional, Protocol, Sequence, Type
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -45,7 +45,7 @@ class FunctionCallResultCallback(Protocol):
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallItem:
|
||||
class FunctionCallRegistryItem:
|
||||
"""Represents an entry of our function call registry.
|
||||
|
||||
Attributes:
|
||||
@@ -61,9 +61,27 @@ class FunctionCallItem:
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallRunnerItem:
|
||||
"""Represents a function call entry for our function call runner. The runner
|
||||
executes function calls in order.
|
||||
class FunctionCallLLM:
|
||||
"""Represents a function call returned by the LLM to be registered for execution.
|
||||
|
||||
Attributes:
|
||||
function_name (str): The name of the function.
|
||||
tool_call_id (str): A unique identifier for the function call.
|
||||
arguments (Mapping[str, Any]): The arguments for the function.
|
||||
context (OpenAILLMContext): The LLM context.
|
||||
|
||||
"""
|
||||
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: Mapping[str, Any]
|
||||
context: OpenAILLMContext
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallRunner:
|
||||
"""Represents an internal function call entry to our function call
|
||||
runner. The runner executes function calls in order.
|
||||
|
||||
Attributes:
|
||||
registry_name (Optional[str]): The function call name registration (could be None).
|
||||
@@ -74,7 +92,7 @@ class FunctionCallRunnerItem:
|
||||
|
||||
"""
|
||||
|
||||
registry_item: FunctionCallItem
|
||||
registry_item: FunctionCallRegistryItem
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: Mapping[str, Any]
|
||||
@@ -115,7 +133,7 @@ class LLMService(AIService):
|
||||
super().__init__(**kwargs)
|
||||
self._start_callbacks = {}
|
||||
self._adapter = self.adapter_class()
|
||||
self._functions: Dict[Optional[str], FunctionCallItem] = {}
|
||||
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
|
||||
self._function_call_runner_task: Optional[asyncio.Task] = None
|
||||
|
||||
self._register_event_handler("on_completion_timeout")
|
||||
@@ -167,7 +185,7 @@ class LLMService(AIService):
|
||||
):
|
||||
# Registering a function with the function_name set to None will run
|
||||
# that handler for all functions
|
||||
self._functions[function_name] = FunctionCallItem(
|
||||
self._functions[function_name] = FunctionCallRegistryItem(
|
||||
function_name=function_name,
|
||||
handler=handler,
|
||||
cancel_on_interruption=cancel_on_interruption,
|
||||
@@ -196,32 +214,32 @@ class LLMService(AIService):
|
||||
return True
|
||||
return function_name in self._functions.keys()
|
||||
|
||||
async def call_function(
|
||||
self,
|
||||
*,
|
||||
context: OpenAILLMContext,
|
||||
tool_call_id: str,
|
||||
function_name: str,
|
||||
arguments: Mapping[str, Any],
|
||||
run_llm: bool = True,
|
||||
):
|
||||
if function_name in self._functions.keys():
|
||||
item = self._functions[function_name]
|
||||
elif None in self._functions.keys():
|
||||
item = self._functions[None]
|
||||
else:
|
||||
return
|
||||
async def run_function_calls(self, function_calls: Sequence[FunctionCallLLM]):
|
||||
total_function_calls = len(function_calls)
|
||||
for index, function_call in enumerate(function_calls):
|
||||
if function_call.function_name in self._functions.keys():
|
||||
item = self._functions[function_call.function_name]
|
||||
elif None in self._functions.keys():
|
||||
item = self._functions[None]
|
||||
else:
|
||||
logger.warning(
|
||||
f"{self} is calling '{function_call.function_name}', but it's not registered."
|
||||
)
|
||||
continue
|
||||
|
||||
runner_item = FunctionCallRunnerItem(
|
||||
registry_item=item,
|
||||
function_name=function_name,
|
||||
tool_call_id=tool_call_id,
|
||||
arguments=arguments,
|
||||
context=context,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
# Run inference on the last function call.
|
||||
run_llm = index == total_function_calls - 1
|
||||
|
||||
await self._function_call_runner_queue.put(runner_item)
|
||||
runner_item = FunctionCallRunner(
|
||||
registry_item=item,
|
||||
function_name=function_call.function_name,
|
||||
tool_call_id=function_call.tool_call_id,
|
||||
arguments=function_call.arguments,
|
||||
context=function_call.context,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
|
||||
await self._function_call_runner_queue.put(runner_item)
|
||||
|
||||
async def call_start_function(self, context: OpenAILLMContext, function_name: str):
|
||||
if function_name in self._start_callbacks.keys():
|
||||
@@ -251,7 +269,7 @@ class LLMService(AIService):
|
||||
|
||||
async def _create_runner_task(self):
|
||||
if not self._function_call_runner_task:
|
||||
self._current_runner: Optional[FunctionCallRunnerItem] = None
|
||||
self._current_runner: Optional[FunctionCallRunner] = None
|
||||
self._current_task: Optional[asyncio.Task] = None
|
||||
self._function_call_runner_queue = asyncio.Queue()
|
||||
self._function_call_runner_task = self.create_task(self._function_call_runner_handler())
|
||||
@@ -269,7 +287,7 @@ class LLMService(AIService):
|
||||
self._current_runner = None
|
||||
self._current_task = None
|
||||
|
||||
async def _run_function_call(self, runner_item: FunctionCallRunnerItem):
|
||||
async def _run_function_call(self, runner_item: FunctionCallRunner):
|
||||
if runner_item.function_name in self._functions.keys():
|
||||
item = self._functions[runner_item.function_name]
|
||||
elif None in self._functions.keys():
|
||||
|
||||
@@ -34,14 +34,10 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.services.llm_service import FunctionCallLLM, LLMService
|
||||
from pipecat.utils.tracing.service_decorators import traced_llm
|
||||
|
||||
|
||||
class OpenAIUnhandledFunctionException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseOpenAILLMService(LLMService):
|
||||
"""This is the base for all services that use the AsyncOpenAI client.
|
||||
|
||||
@@ -260,24 +256,22 @@ class BaseOpenAILLMService(LLMService):
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
|
||||
total_func_calls = len(functions_list)
|
||||
for index, (function_name, arguments, tool_id) in enumerate(
|
||||
zip(functions_list, arguments_list, tool_id_list)
|
||||
function_calls = []
|
||||
|
||||
for function_name, arguments, tool_id in zip(
|
||||
functions_list, arguments_list, tool_id_list
|
||||
):
|
||||
if self.has_function(function_name):
|
||||
run_llm = index == total_func_calls - 1
|
||||
arguments = json.loads(arguments)
|
||||
await self.call_function(
|
||||
arguments = json.loads(arguments)
|
||||
function_calls.append(
|
||||
FunctionCallLLM(
|
||||
context=context,
|
||||
tool_call_id=tool_id,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
tool_call_id=tool_id,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
else:
|
||||
raise OpenAIUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
)
|
||||
|
||||
await self.run_function_calls(function_calls)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -48,7 +48,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.services.llm_service import FunctionCallLLM, LLMService
|
||||
from pipecat.services.openai.llm import OpenAIContextAggregatorPair
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -78,10 +78,6 @@ class CurrentAudioResponse:
|
||||
total_size: int = 0
|
||||
|
||||
|
||||
class OpenAIUnhandledFunctionException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
# Overriding the default adapter to use the OpenAIRealtimeLLMAdapter one.
|
||||
adapter_class = OpenAIRealtimeLLMAdapter
|
||||
@@ -587,25 +583,18 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
await self._handle_function_call_items(function_calls)
|
||||
|
||||
async def _handle_function_call_items(self, items):
|
||||
total_items = len(items)
|
||||
for index, item in enumerate(items):
|
||||
function_name = item.name
|
||||
tool_id = item.call_id
|
||||
arguments = json.loads(item.arguments)
|
||||
if self.has_function(function_name):
|
||||
run_llm = index == total_items - 1
|
||||
if function_name in self._functions.keys() or None in self._functions.keys():
|
||||
await self.call_function(
|
||||
context=self._context,
|
||||
tool_call_id=tool_id,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
else:
|
||||
raise OpenAIUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
function_calls = []
|
||||
for item in items:
|
||||
args = json.loads(item.arguments)
|
||||
function_calls.append(
|
||||
FunctionCallLLM(
|
||||
context=self._context,
|
||||
tool_call_id=item.call_id,
|
||||
function_name=item.name,
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
await self.run_function_calls(function_calls)
|
||||
|
||||
#
|
||||
# state and client events for the current conversation
|
||||
|
||||
Reference in New Issue
Block a user