Merge pull request #2574 from pipecat-ai/pk/expand-universal-llm-context-support-to-anthropic
Expand universal `LLMContext` support to Anthropic
This commit is contained in:
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- Expanded support for universal `LLMContext` to the Anthropic LLM service.
|
||||
Using the universal `LLMContext` and associated `LLMContextAggregatorPair` is
|
||||
a pre-requisite for using `LLMSwitcher` to switch between LLMs at runtime.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an `AWSBedrockLLMService` crash caused by an extra `await`.
|
||||
@@ -19,7 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
### Added
|
||||
|
||||
- Added multilingual support for AsyncAI in `AsyncAITTSService` and `AsyncAIHttpTTSService`.
|
||||
- New `languages`: `es`, `fr`, `de`, `it`.
|
||||
|
||||
- New `languages`: `es`, `fr`, `de`, `it`.
|
||||
|
||||
- Added new frames `InputTransportMessageUrgentFrame` and
|
||||
`DailyInputTransportMessageUrgentFrame` for transport messages received from
|
||||
|
||||
@@ -97,7 +97,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
llm = AnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
model="claude-3-7-sonnet-latest",
|
||||
enable_prompt_caching_beta=True,
|
||||
params=AnthropicLLMService.InputParams(enable_prompt_caching=True),
|
||||
)
|
||||
llm.register_function("get_weather", get_weather)
|
||||
llm.register_function("get_image", get_image)
|
||||
|
||||
@@ -0,0 +1,211 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import (
|
||||
create_transport,
|
||||
get_transport_client_id,
|
||||
maybe_capture_participant_camera,
|
||||
)
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.services.daily import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# Global variable to store the client ID
|
||||
client_id = ""
|
||||
|
||||
|
||||
async def get_weather(params: FunctionCallParams):
|
||||
location = params.arguments["location"]
|
||||
await params.result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
|
||||
|
||||
|
||||
async def get_image(params: FunctionCallParams):
|
||||
question = params.arguments["question"]
|
||||
logger.debug(f"Requesting image with user_id={client_id}, question={question}")
|
||||
|
||||
# Request the image frame
|
||||
await params.llm.request_image_frame(
|
||||
user_id=client_id,
|
||||
function_name=params.function_name,
|
||||
tool_call_id=params.tool_call_id,
|
||||
text_content=question,
|
||||
)
|
||||
|
||||
# Wait a short time for the frame to be processed
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Return a result to complete the function call
|
||||
await params.result_callback(
|
||||
f"I've captured an image from your camera and I'm analyzing what you asked about: {question}"
|
||||
)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_in_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_in_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = AnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
model="claude-3-7-sonnet-latest",
|
||||
params=AnthropicLLMService.InputParams(enable_prompt_caching=True),
|
||||
)
|
||||
llm.register_function("get_weather", get_weather)
|
||||
llm.register_function("get_image", get_image)
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required=["location"],
|
||||
)
|
||||
get_image_function = FunctionSchema(
|
||||
name="get_image",
|
||||
description="Get an image from the video stream.",
|
||||
properties={
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The question that the user is asking about the image.",
|
||||
}
|
||||
},
|
||||
required=["question"],
|
||||
)
|
||||
tools = ToolsSchema(standard_tools=[weather_function, get_image_function])
|
||||
|
||||
system_prompt = """\
|
||||
You are a helpful assistant who converses with a user and answers questions. Respond concisely to general questions.
|
||||
|
||||
Your response will be turned into speech so use only simple words and punctuation.
|
||||
|
||||
You have access to two tools: get_weather and get_image.
|
||||
|
||||
You can respond to questions about the weather using the get_weather tool.
|
||||
|
||||
You can answer questions about the user's video stream using the get_image tool. Some examples of phrases that \
|
||||
indicate you should use the get_image tool are:
|
||||
- What do you see?
|
||||
- What's in the video?
|
||||
- Can you describe the video?
|
||||
- Tell me about what you see.
|
||||
- Tell me something interesting about what you see.
|
||||
- What's happening in the video?
|
||||
|
||||
If you need to use a tool, simply use the tool. Do not tell the user the tool you are using. Be brief and concise.
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": "Start the conversation by introducing yourself."},
|
||||
]
|
||||
|
||||
context = LLMContext(messages, tools)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
context_aggregator.user(), # User speech to text
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
|
||||
await maybe_capture_participant_camera(transport, client)
|
||||
|
||||
global client_id
|
||||
client_id = get_transport_client_id(transport, client)
|
||||
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -39,11 +39,12 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> TLLMInvocationParams:
|
||||
def get_llm_invocation_params(self, context: LLMContext, **kwargs) -> TLLMInvocationParams:
|
||||
"""Get provider-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages, tools, etc.
|
||||
**kwargs: Additional provider-specific arguments that subclasses can use.
|
||||
|
||||
Returns:
|
||||
Provider-specific parameters for invoking the LLM.
|
||||
|
||||
@@ -6,12 +6,25 @@
|
||||
|
||||
"""Anthropic LLM adapter for Pipecat."""
|
||||
|
||||
from typing import Any, Dict, List, TypedDict
|
||||
import copy
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, TypedDict
|
||||
|
||||
from anthropic import NOT_GIVEN, NotGiven
|
||||
from anthropic.types.message_param import MessageParam
|
||||
from anthropic.types.tool_union_param import ToolUnionParam
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
LLMSpecificMessage,
|
||||
LLMStandardMessage,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicLLMInvocationParams(TypedDict):
|
||||
@@ -20,7 +33,9 @@ class AnthropicLLMInvocationParams(TypedDict):
|
||||
Holds the system instruction, messages, and tools to pass to Anthropic's API.
|
||||
"""
|
||||
|
||||
pass
|
||||
system: str | NotGiven
|
||||
messages: List[MessageParam]
|
||||
tools: List[ToolUnionParam] | NotGiven
|
||||
|
||||
|
||||
class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
@@ -30,20 +45,33 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
to the specific format required by Anthropic's Claude models for function calling.
|
||||
"""
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> AnthropicLLMInvocationParams:
|
||||
def get_llm_invocation_params(
|
||||
self, context: LLMContext, enable_prompt_caching: bool
|
||||
) -> AnthropicLLMInvocationParams:
|
||||
"""Get Anthropic-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
Converts the context's messages and tools to Anthropic's format, optionally adding prompt caching markers to the messages.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages, tools, etc.
|
||||
enable_prompt_caching: Whether prompt caching should be enabled.
|
||||
|
||||
Returns:
|
||||
Dictionary of parameters for invoking Anthropic's LLM API.
|
||||
"""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for Anthropic.")
|
||||
messages = self._from_universal_context_messages(self._get_messages(context))
|
||||
return {
|
||||
"system": messages.system,
|
||||
"messages": (
|
||||
self._with_cache_control_markers(messages.messages)
|
||||
if enable_prompt_caching
|
||||
else messages.messages
|
||||
),
|
||||
# NOTE: LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
"tools": self.from_standard_tools(context.tools),
|
||||
}
|
||||
|
||||
def get_messages_for_logging(self, context) -> List[Dict[str, Any]]:
|
||||
def get_messages_for_logging(self, context: LLMContext) -> List[Dict[str, Any]]:
|
||||
"""Get messages from a universal LLM context in a format ready for logging about Anthropic.
|
||||
|
||||
Removes or truncates sensitive data like image content for safe logging.
|
||||
@@ -56,7 +84,241 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
Returns:
|
||||
List of messages in a format ready for logging about Anthropic.
|
||||
"""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for Anthropic.")
|
||||
# Get messages in Anthropic's format
|
||||
messages = self._from_universal_context_messages(self._get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
for message in messages:
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
for item in msg["content"]:
|
||||
if item["type"] == "image":
|
||||
item["source"]["data"] = "..."
|
||||
messages_for_logging.append(msg)
|
||||
return messages_for_logging
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("anthropic")
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Anthropic-formatted messages converted from universal context."""
|
||||
|
||||
messages: List[MessageParam]
|
||||
system: str | NotGiven
|
||||
|
||||
def _from_universal_context_messages(
|
||||
self, universal_context_messages: List[LLMContextMessage]
|
||||
) -> ConvertedMessages:
|
||||
system = NOT_GIVEN
|
||||
messages = []
|
||||
|
||||
# first, map messages using self._from_universal_context_message(m)
|
||||
try:
|
||||
messages = [self._from_universal_context_message(m) for m in universal_context_messages]
|
||||
except Exception as e:
|
||||
logger.error(f"Error mapping messages: {e}")
|
||||
|
||||
# See if we should pull the system message out of our messages list.
|
||||
if messages and messages[0]["role"] == "system":
|
||||
if len(messages) == 1:
|
||||
# If we only have a system message in the list, all we can really do
|
||||
# without introducing too much magic is change the role to "user".
|
||||
messages[0]["role"] = "user"
|
||||
else:
|
||||
# If we have more than one message, we'll pull the system message out of the
|
||||
# list.
|
||||
system = messages[0]["content"]
|
||||
messages.pop(0)
|
||||
|
||||
# Convert any subsequent "system"-role messages to "user"-role
|
||||
# messages, as Anthropic doesn't support system input messages.
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
message["role"] = "user"
|
||||
|
||||
# Merge consecutive messages with the same role.
|
||||
i = 0
|
||||
while i < len(messages) - 1:
|
||||
current_message = messages[i]
|
||||
next_message = messages[i + 1]
|
||||
if current_message["role"] == next_message["role"]:
|
||||
# Convert content to list of dictionaries if it's a string
|
||||
if isinstance(current_message["content"], str):
|
||||
current_message["content"] = [
|
||||
{"type": "text", "text": current_message["content"]}
|
||||
]
|
||||
if isinstance(next_message["content"], str):
|
||||
next_message["content"] = [{"type": "text", "text": next_message["content"]}]
|
||||
# Concatenate the content
|
||||
current_message["content"].extend(next_message["content"])
|
||||
# Remove the next message from the list
|
||||
messages.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
# Avoid empty content in messages
|
||||
for message in messages:
|
||||
if isinstance(message["content"], str) and message["content"] == "":
|
||||
message["content"] = "(empty)"
|
||||
elif isinstance(message["content"], list) and len(message["content"]) == 0:
|
||||
message["content"] = [{"type": "text", "text": "(empty)"}]
|
||||
|
||||
return self.ConvertedMessages(messages=messages, system=system)
|
||||
|
||||
def _from_universal_context_message(self, message: LLMContextMessage) -> MessageParam:
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
return copy.deepcopy(message.message)
|
||||
return self._from_standard_message(message)
|
||||
|
||||
def _from_standard_message(self, message: LLMStandardMessage) -> MessageParam:
|
||||
"""Convert standard universal context message to Anthropic format.
|
||||
|
||||
Handles conversion of text content, tool calls, and tool results.
|
||||
Empty text content is converted to "(empty)".
|
||||
|
||||
Args:
|
||||
message: Message in standard universal context format.
|
||||
|
||||
Returns:
|
||||
Message in Anthropic format.
|
||||
|
||||
Examples:
|
||||
Input standard format::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "123",
|
||||
"function": {"name": "search", "arguments": '{"q": "test"}'}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Output Anthropic format::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "123",
|
||||
"name": "search",
|
||||
"input": {"q": "test"}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
message = copy.deepcopy(message)
|
||||
if message["role"] == "tool":
|
||||
return {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
},
|
||||
],
|
||||
}
|
||||
if message.get("tool_calls"):
|
||||
tc = message["tool_calls"]
|
||||
ret = {"role": "assistant", "content": []}
|
||||
for tool_call in tc:
|
||||
function = tool_call["function"]
|
||||
arguments = json.loads(function["arguments"])
|
||||
new_tool_use = {
|
||||
"type": "tool_use",
|
||||
"id": tool_call["id"],
|
||||
"name": function["name"],
|
||||
"input": arguments,
|
||||
}
|
||||
ret["content"].append(new_tool_use)
|
||||
return ret
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
# fix empty text
|
||||
if content == "":
|
||||
content = "(empty)"
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
# fix empty text
|
||||
if item["type"] == "text" and item["text"] == "":
|
||||
item["text"] = "(empty)"
|
||||
# handle image_url -> image conversion
|
||||
if item["type"] == "image_url":
|
||||
item["type"] = "image"
|
||||
item["source"] = {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": item["image_url"]["url"].split(",")[1],
|
||||
}
|
||||
del item["image_url"]
|
||||
# In the case where there's a single image in the list (like what
|
||||
# would result from a UserImageRawFrame), ensure that the image
|
||||
# comes before text, as recommended by Anthropic docs
|
||||
# (https://docs.anthropic.com/en/docs/build-with-claude/vision#example-one-image)
|
||||
image_indices = [i for i, item in enumerate(content) if item["type"] == "image"]
|
||||
text_indices = [i for i, item in enumerate(content) if item["type"] == "text"]
|
||||
if len(image_indices) == 1 and text_indices:
|
||||
img_idx = image_indices[0]
|
||||
first_txt_idx = text_indices[0]
|
||||
if img_idx > first_txt_idx:
|
||||
# Move image before the first text
|
||||
image_item = content.pop(img_idx)
|
||||
content.insert(first_txt_idx, image_item)
|
||||
|
||||
return message
|
||||
|
||||
def _with_cache_control_markers(self, messages: List[MessageParam]) -> List[MessageParam]:
|
||||
"""Add cache control markers to messages for prompt caching.
|
||||
|
||||
Args:
|
||||
messages: List of messages in Anthropic format.
|
||||
|
||||
Returns:
|
||||
List of messages with cache control markers added.
|
||||
"""
|
||||
|
||||
def add_cache_control_marker(message: MessageParam):
|
||||
if isinstance(message["content"], str):
|
||||
message["content"] = [{"type": "text", "text": message["content"]}]
|
||||
message["content"][-1]["cache_control"] = {"type": "ephemeral"}
|
||||
|
||||
try:
|
||||
# Add cache control markers to the most recent two user messages.
|
||||
# - The marker at the most recent user message tells Anthropic to
|
||||
# cache the prompt up to that point.
|
||||
# - The marker at the second-most-recent user message tells Anthropic
|
||||
# to look up the cached prompt that goes up to that point (the
|
||||
# point that *was* the last user message the previous turn).
|
||||
# If we only added the marker to the last user message, we'd only
|
||||
# ever be adding to the cache, never looking up from it.
|
||||
# Why user messages? We're assuming that we're primarily running
|
||||
# inference as soon as user turns come in. In Anthropic, turns
|
||||
# strictly alternate between user and assistant.
|
||||
|
||||
messages_with_markers = copy.deepcopy(messages)
|
||||
|
||||
# Find the most recent two user messages
|
||||
user_message_indices = []
|
||||
for i in range(len(messages_with_markers) - 1, -1, -1):
|
||||
if messages_with_markers[i]["role"] == "user":
|
||||
user_message_indices.append(i)
|
||||
if len(user_message_indices) == 2:
|
||||
break
|
||||
|
||||
# Add cache control markers to the identified user messages
|
||||
for index in user_message_indices:
|
||||
add_cache_control_marker(messages_with_markers[index])
|
||||
|
||||
return messages_with_markers
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding cache control marker: {e}")
|
||||
return messages_with_markers
|
||||
|
||||
@staticmethod
|
||||
def _to_anthropic_function_format(function: FunctionSchema) -> Dict[str, Any]:
|
||||
|
||||
@@ -67,7 +67,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
return {
|
||||
"system_instruction": messages.system_instruction,
|
||||
"messages": messages.messages,
|
||||
# NOTE; LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
# NOTE: LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
"tools": self.from_standard_tools(context.tools),
|
||||
}
|
||||
|
||||
@@ -192,14 +192,14 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
def _from_standard_message(
|
||||
self, message: LLMStandardMessage, already_have_system_instruction: bool
|
||||
) -> Content | str:
|
||||
"""Convert universal context message to Google Content object.
|
||||
"""Convert standard universal context message to Google Content object.
|
||||
|
||||
Handles conversion of text, images, and function calls to Google's
|
||||
format.
|
||||
System instructions are returned as a plain string.
|
||||
|
||||
Args:
|
||||
message: Message in universal context format.
|
||||
message: Message in standard universal context format.
|
||||
already_have_system_instruction: Whether we already have a system instruction
|
||||
|
||||
Returns:
|
||||
@@ -308,5 +308,4 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
audio_bytes = base64.b64decode(input_audio["data"])
|
||||
parts.append(Part(inline_data=Blob(mime_type="audio/wav", data=audio_bytes)))
|
||||
|
||||
message = Content(role=role, parts=parts)
|
||||
return message
|
||||
return Content(role=role, parts=parts)
|
||||
|
||||
@@ -30,25 +30,17 @@ class LLMSwitcher(ServiceSwitcher[StrategyType]):
|
||||
"""Get the currently active LLM, if any."""
|
||||
return self.strategy.active_service
|
||||
|
||||
async def run_inference(
|
||||
self, context: LLMContext, system_instruction: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
async def run_inference(self, context: LLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context, using the currently active LLM.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing conversation history.
|
||||
system_instruction: Optional system instruction to guide the LLM's
|
||||
behavior. You could also (again, optionally) provide a system
|
||||
instruction directly in the context. If both are provided, the
|
||||
one in the context takes precedence.
|
||||
|
||||
Returns:
|
||||
The LLM's response as a string, or None if no response is generated.
|
||||
"""
|
||||
if self.active_llm:
|
||||
return await self.active_llm.run_inference(
|
||||
context=context, system_instruction=system_instruction
|
||||
)
|
||||
return await self.active_llm.run_inference(context=context)
|
||||
return None
|
||||
|
||||
def register_function(
|
||||
|
||||
@@ -24,7 +24,10 @@ from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
|
||||
from pipecat.adapters.services.anthropic_adapter import (
|
||||
AnthropicLLMAdapter,
|
||||
AnthropicLLMInvocationParams,
|
||||
)
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
@@ -112,7 +115,12 @@ class AnthropicLLMService(LLMService):
|
||||
"""Input parameters for Anthropic model inference.
|
||||
|
||||
Parameters:
|
||||
enable_prompt_caching_beta: Whether to enable beta prompt caching feature.
|
||||
enable_prompt_caching: Whether to enable the prompt caching feature.
|
||||
enable_prompt_caching_beta (deprecated): Whether to enable the beta prompt caching feature.
|
||||
|
||||
.. deprecated:: 0.0.84
|
||||
Use the `enable_prompt_caching` parameter instead.
|
||||
|
||||
max_tokens: Maximum tokens to generate. Must be at least 1.
|
||||
temperature: Sampling temperature between 0.0 and 1.0.
|
||||
top_k: Top-k sampling parameter.
|
||||
@@ -120,13 +128,26 @@ class AnthropicLLMService(LLMService):
|
||||
extra: Additional parameters to pass to the API.
|
||||
"""
|
||||
|
||||
enable_prompt_caching_beta: Optional[bool] = False
|
||||
enable_prompt_caching: Optional[bool] = None
|
||||
enable_prompt_caching_beta: Optional[bool] = None
|
||||
max_tokens: Optional[int] = Field(default_factory=lambda: 4096, ge=1)
|
||||
temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
|
||||
top_k: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
|
||||
top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
|
||||
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
def model_post_init(self, __context):
|
||||
"""Post-initialization to handle deprecated parameters."""
|
||||
if self.enable_prompt_caching_beta is not None:
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"enable_prompt_caching_beta is deprecated. Use enable_prompt_caching instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -159,7 +180,15 @@ class AnthropicLLMService(LLMService):
|
||||
self._retry_on_timeout = retry_on_timeout
|
||||
self._settings = {
|
||||
"max_tokens": params.max_tokens,
|
||||
"enable_prompt_caching_beta": params.enable_prompt_caching_beta or False,
|
||||
"enable_prompt_caching": (
|
||||
params.enable_prompt_caching
|
||||
if params.enable_prompt_caching is not None
|
||||
else (
|
||||
params.enable_prompt_caching_beta
|
||||
if params.enable_prompt_caching_beta is not None
|
||||
else False
|
||||
)
|
||||
),
|
||||
"temperature": params.temperature,
|
||||
"top_k": params.top_k,
|
||||
"top_p": params.top_p,
|
||||
@@ -199,34 +228,28 @@ class AnthropicLLMService(LLMService):
|
||||
response = await api_call(**params)
|
||||
return response
|
||||
|
||||
async def run_inference(
|
||||
self, context: LLMContext | OpenAILLMContext, system_instruction: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing conversation history.
|
||||
system_instruction: Optional system instruction to guide the LLM's
|
||||
behavior. You could also (again, optionally) provide a system
|
||||
instruction directly in the context. If both are provided, the
|
||||
one in the context takes precedence.
|
||||
|
||||
Returns:
|
||||
The LLM's response as a string, or None if no response is generated.
|
||||
"""
|
||||
messages = []
|
||||
system = []
|
||||
system = NOT_GIVEN
|
||||
if isinstance(context, LLMContext):
|
||||
# Future code will be something like this:
|
||||
# adapter = self.get_llm_adapter()
|
||||
# params: AnthropicLLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
# messages = params["messages"]
|
||||
# system = params["system_instruction"]
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for Anthropic.")
|
||||
adapter: AnthropicLLMAdapter = self.get_llm_adapter()
|
||||
params = adapter.get_llm_invocation_params(
|
||||
context, enable_prompt_caching=self._settings["enable_prompt_caching"]
|
||||
)
|
||||
messages = params["messages"]
|
||||
system = params["system"]
|
||||
else:
|
||||
context = AnthropicLLMContext.upgrade_to_anthropic(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system", None) or system_instruction
|
||||
system = getattr(context, "system", NOT_GIVEN)
|
||||
|
||||
# LLM completion
|
||||
response = await self._client.messages.create(
|
||||
@@ -239,15 +262,6 @@ class AnthropicLLMService(LLMService):
|
||||
|
||||
return response.content[0].text
|
||||
|
||||
@property
|
||||
def enable_prompt_caching_beta(self) -> bool:
|
||||
"""Check if prompt caching beta feature is enabled.
|
||||
|
||||
Returns:
|
||||
True if prompt caching is enabled.
|
||||
"""
|
||||
return self._enable_prompt_caching_beta
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
@@ -277,8 +291,31 @@ class AnthropicLLMService(LLMService):
|
||||
assistant = AnthropicAssistantContextAggregator(context, params=assistant_params)
|
||||
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
def _get_llm_invocation_params(
|
||||
self, context: OpenAILLMContext | LLMContext
|
||||
) -> AnthropicLLMInvocationParams:
|
||||
# Universal LLMContext
|
||||
if isinstance(context, LLMContext):
|
||||
adapter: AnthropicLLMAdapter = self.get_llm_adapter()
|
||||
params = adapter.get_llm_invocation_params(
|
||||
context, enable_prompt_caching=self._settings["enable_prompt_caching"]
|
||||
)
|
||||
return params
|
||||
|
||||
# Anthropic-specific context
|
||||
messages = (
|
||||
context.get_messages_with_cache_control_markers()
|
||||
if self._settings["enable_prompt_caching"]
|
||||
else context.messages
|
||||
)
|
||||
return AnthropicLLMInvocationParams(
|
||||
system=context.system,
|
||||
messages=messages,
|
||||
tools=context.tools,
|
||||
)
|
||||
|
||||
@traced_llm
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
# Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
|
||||
# completion_tokens. We also estimate the completion tokens from output text
|
||||
# and use that estimate if we are interrupted, because we almost certainly won't
|
||||
@@ -294,24 +331,22 @@ class AnthropicLLMService(LLMService):
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
|
||||
params_from_context = self._get_llm_invocation_params(context)
|
||||
|
||||
if isinstance(context, LLMContext):
|
||||
adapter = self.get_llm_adapter()
|
||||
context_type_for_logging = "universal"
|
||||
messages_for_logging = adapter.get_messages_for_logging(context)
|
||||
else:
|
||||
context_type_for_logging = "LLM-specific"
|
||||
messages_for_logging = context.get_messages_for_logging()
|
||||
logger.debug(
|
||||
f"{self}: Generating chat [{context.system}] | {context.get_messages_for_logging()}"
|
||||
f"{self}: Generating chat from {context_type_for_logging} context [{params_from_context['system']}] | {messages_for_logging}"
|
||||
)
|
||||
|
||||
messages = context.messages
|
||||
if self._settings["enable_prompt_caching_beta"]:
|
||||
messages = context.get_messages_with_cache_control_markers()
|
||||
|
||||
api_call = self._client.messages.create
|
||||
if self._settings["enable_prompt_caching_beta"]:
|
||||
api_call = self._client.beta.prompt_caching.messages.create
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
params = {
|
||||
"tools": context.tools or [],
|
||||
"system": context.system,
|
||||
"messages": messages,
|
||||
"model": self.model_name,
|
||||
"max_tokens": self._settings["max_tokens"],
|
||||
"stream": True,
|
||||
@@ -320,9 +355,12 @@ class AnthropicLLMService(LLMService):
|
||||
"top_p": self._settings["top_p"],
|
||||
}
|
||||
|
||||
# Messages, system, tools
|
||||
params.update(params_from_context)
|
||||
|
||||
params.update(self._settings["extra"])
|
||||
|
||||
response = await self._create_message_stream(api_call, params)
|
||||
response = await self._create_message_stream(self._client.messages.create, params)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
@@ -405,7 +443,10 @@ class AnthropicLLMService(LLMService):
|
||||
prompt_tokens + cache_creation_input_tokens + cache_read_input_tokens
|
||||
)
|
||||
if total_input_tokens >= 1024:
|
||||
context.turns_above_cache_threshold += 1
|
||||
if hasattr(
|
||||
context, "turns_above_cache_threshold"
|
||||
): # LLMContext doesn't have this attribute
|
||||
context.turns_above_cache_threshold += 1
|
||||
|
||||
await self.run_function_calls(function_calls)
|
||||
|
||||
@@ -451,7 +492,7 @@ class AnthropicLLMService(LLMService):
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context: "AnthropicLLMContext" = AnthropicLLMContext.upgrade_to_anthropic(frame.context)
|
||||
elif isinstance(frame, LLMContextFrame):
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for Anthropic.")
|
||||
context = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
context = AnthropicLLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
@@ -464,7 +505,7 @@ class AnthropicLLMService(LLMService):
|
||||
await self._update_settings(frame.settings)
|
||||
elif isinstance(frame, LLMEnablePromptCachingFrame):
|
||||
logger.debug(f"Setting enable prompt caching to: [{frame.enable}]")
|
||||
self._settings["enable_prompt_caching_beta"] = frame.enable
|
||||
self._settings["enable_prompt_caching"] = frame.enable
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -792,17 +792,11 @@ class AWSBedrockLLMService(LLMService):
|
||||
"""
|
||||
return True
|
||||
|
||||
async def run_inference(
|
||||
self, context: LLMContext | OpenAILLMContext, system_instruction: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing conversation history.
|
||||
system_instruction: Optional system instruction to guide the LLM's
|
||||
behavior. You could also (again, optionally) provide a system
|
||||
instruction directly in the context. If both are provided, the
|
||||
one in the context takes precedence.
|
||||
|
||||
Returns:
|
||||
The LLM's response as a string, or None if no response is generated.
|
||||
@@ -815,14 +809,14 @@ class AWSBedrockLLMService(LLMService):
|
||||
# adapter = self.get_llm_adapter()
|
||||
# params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
# messages = params["messages"]
|
||||
# system = params["system_instruction"]
|
||||
# system = params["system_instruction"] # [{"text": "system message"}]
|
||||
raise NotImplementedError(
|
||||
"Universal LLMContext is not yet supported for AWS Bedrock."
|
||||
)
|
||||
else:
|
||||
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system", None) or system_instruction
|
||||
system = getattr(context, "system", None) # [{"text": "system message"}]
|
||||
|
||||
# Determine if we're using Claude or Nova based on model ID
|
||||
model_id = self.model_name
|
||||
@@ -839,7 +833,7 @@ class AWSBedrockLLMService(LLMService):
|
||||
}
|
||||
|
||||
if system:
|
||||
request_params["system"] = [{"text": system}]
|
||||
request_params["system"] = system
|
||||
|
||||
async with self._aws_session.client(
|
||||
service_name="bedrock-runtime", **self._aws_params
|
||||
|
||||
@@ -733,17 +733,11 @@ class GoogleLLMService(LLMService):
|
||||
def _create_client(self, api_key: str, http_options: Optional[HttpOptions] = None):
|
||||
self._client = genai.Client(api_key=api_key, http_options=http_options)
|
||||
|
||||
async def run_inference(
|
||||
self, context: LLMContext | OpenAILLMContext, system_instruction: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing conversation history.
|
||||
system_instruction: Optional system instruction to guide the LLM's
|
||||
behavior. You could also (again, optionally) provide a system
|
||||
instruction directly in the context. If both are provided, the
|
||||
one in the context takes precedence.
|
||||
|
||||
Returns:
|
||||
The LLM's response as a string, or None if no response is generated.
|
||||
@@ -758,7 +752,7 @@ class GoogleLLMService(LLMService):
|
||||
else:
|
||||
context = GoogleLLMContext.upgrade_to_google(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system_message", None) or system_instruction
|
||||
system = getattr(context, "system_message", None)
|
||||
|
||||
generation_config = GenerateContentConfig(system_instruction=system)
|
||||
|
||||
@@ -858,8 +852,7 @@ class GoogleLLMService(LLMService):
|
||||
self, context: OpenAILLMContext
|
||||
) -> AsyncIterator[GenerateContentResponse]:
|
||||
logger.debug(
|
||||
# f"{self}: Generating chat [{self._system_instruction}] | {context.get_messages_for_logging()}"
|
||||
f"{self}: Generating chat from OpenAI context {context.get_messages_for_logging()}"
|
||||
f"{self}: Generating chat from LLM-specific context [{context.system_message}] | {context.get_messages_for_logging()}"
|
||||
)
|
||||
|
||||
params = GeminiLLMInvocationParams(
|
||||
@@ -874,13 +867,12 @@ class GoogleLLMService(LLMService):
|
||||
self, context: LLMContext
|
||||
) -> AsyncIterator[GenerateContentResponse]:
|
||||
adapter = self.get_llm_adapter()
|
||||
logger.debug(
|
||||
# f"{self}: Generating chat [{self._system_instruction}] | {context.get_messages_for_logging()}"
|
||||
f"{self}: Generating chat from universal context {adapter.get_messages_for_logging(context)}"
|
||||
)
|
||||
|
||||
params: GeminiLLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
|
||||
logger.debug(
|
||||
f"{self}: Generating chat from universal context [{params['system_instruction']}] | {adapter.get_messages_for_logging(context)}"
|
||||
)
|
||||
|
||||
return await self._stream_content(params)
|
||||
|
||||
@traced_llm
|
||||
|
||||
@@ -195,18 +195,13 @@ class LLMService(AIService):
|
||||
"""
|
||||
return self._adapter
|
||||
|
||||
async def run_inference(
|
||||
self, context: LLMContext | OpenAILLMContext, system_instruction: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
|
||||
|
||||
Must be implemented by subclasses.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing conversation history.
|
||||
system_instruction: Optional system instruction to guide the LLM's
|
||||
behavior. You could also (again, optionally) provide a system
|
||||
instruction directly in the context.
|
||||
|
||||
Returns:
|
||||
The LLM's response as a string, or None if no response is generated.
|
||||
|
||||
@@ -245,16 +245,11 @@ class BaseOpenAILLMService(LLMService):
|
||||
params.update(self._settings["extra"])
|
||||
return params
|
||||
|
||||
async def run_inference(
|
||||
self, context: LLMContext | OpenAILLMContext, system_instruction: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing conversation history.
|
||||
system_instruction: Optional system instruction to guide the LLM's
|
||||
behavior. You could also (again, optionally) provide a system
|
||||
instruction directly in the context.
|
||||
|
||||
Returns:
|
||||
The LLM's response as a string, or None if no response is generated.
|
||||
@@ -279,7 +274,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
self, context: OpenAILLMContext
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
logger.debug(
|
||||
f"{self}: Generating chat from OpenAI context {context.get_messages_for_logging()}"
|
||||
f"{self}: Generating chat from LLM-specific context {context.get_messages_for_logging()}"
|
||||
)
|
||||
|
||||
messages: List[ChatCompletionMessageParam] = context.get_messages()
|
||||
|
||||
Reference in New Issue
Block a user