Merge pull request #2574 from pipecat-ai/pk/expand-universal-llm-context-support-to-anthropic

Expand universal `LLMContext` support to Anthropic
This commit is contained in:
kompfner
2025-09-04 13:09:44 -04:00
committed by GitHub
12 changed files with 597 additions and 108 deletions

View File

@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
- Expanded support for universal `LLMContext` to the Anthropic LLM service.
Using the universal `LLMContext` and associated `LLMContextAggregatorPair` is
a pre-requisite for using `LLMSwitcher` to switch between LLMs at runtime.
### Fixed
- Fixed an `AWSBedrockLLMService` crash caused by an extra `await`.
@@ -19,7 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added multilingual support for AsyncAI in `AsyncAITTSService` and `AsyncAIHttpTTSService`.
- New `languages`: `es`, `fr`, `de`, `it`.
- New `languages`: `es`, `fr`, `de`, `it`.
- Added new frames `InputTransportMessageUrgentFrame` and
`DailyInputTransportMessageUrgentFrame` for transport messages received from

View File

@@ -97,7 +97,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
llm = AnthropicLLMService(
api_key=os.getenv("ANTHROPIC_API_KEY"),
model="claude-3-7-sonnet-latest",
enable_prompt_caching_beta=True,
params=AnthropicLLMService.InputParams(enable_prompt_caching=True),
)
llm.register_function("get_weather", get_weather)
llm.register_function("get_image", get_image)

View File

@@ -0,0 +1,211 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
from dotenv import load_dotenv
from loguru import logger
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import (
create_transport,
get_transport_client_id,
maybe_capture_participant_camera,
)
from pipecat.services.anthropic.llm import AnthropicLLMService
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.services.daily import DailyParams
# Load settings from a .env file. override=True is python-dotenv's switch for
# letting values from the file replace variables already present in the environment.
load_dotenv(override=True)
# Global variable to store the client ID.
# It is assigned in run_bot's on_client_connected handler and read by get_image,
# which passes it as the user_id of the image request.
client_id = ""
async def get_weather(params: FunctionCallParams):
    """Answer a ``get_weather`` tool call with a canned weather report.

    Args:
        params: Function call parameters. ``params.arguments["location"]`` is
            read, and the report is delivered through ``params.result_callback``.
    """
    city = params.arguments["location"]
    report = f"The weather in {city} is currently 72 degrees and sunny."
    await params.result_callback(report)
async def get_image(params: FunctionCallParams):
    """Handle a ``get_image`` tool call by requesting a camera frame for the current client.

    Args:
        params: Function call parameters. ``params.arguments["question"]`` is
            the user's question; ``params.llm`` is asked for the image frame and
            ``params.result_callback`` completes the call.
    """
    user_question = params.arguments["question"]
    logger.debug(f"Requesting image with user_id={client_id}, question={user_question}")

    # Ask the LLM service to request a frame for the connected client (module-level client_id).
    await params.llm.request_image_frame(
        user_id=client_id,
        function_name=params.function_name,
        tool_call_id=params.tool_call_id,
        text_content=user_question,
    )

    # Give the frame a short moment to be processed before completing the call.
    await asyncio.sleep(0.5)

    # Complete the function call with an acknowledgement that echoes the question.
    acknowledgement = (
        "I've captured an image from your camera and I'm analyzing what you asked about: "
        f"{user_question}"
    )
    await params.result_callback(acknowledgement)
# Each entry maps a transport name to a zero-argument factory rather than to a
# ready-made params object, so objects such as SileroVADAnalyzer are only
# instantiated once the desired transport has actually been selected.
def _daily_params():
    """Build the parameters used when the "daily" transport is selected."""
    return DailyParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        video_in_enabled=True,
        vad_analyzer=SileroVADAnalyzer(),
    )


def _webrtc_params():
    """Build the parameters used when the "webrtc" transport is selected."""
    return TransportParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        video_in_enabled=True,
        vad_analyzer=SileroVADAnalyzer(),
    )


transport_params = {
    "daily": _daily_params,
    "webrtc": _webrtc_params,
}
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    """Build the voice pipeline around the Anthropic LLM and run it.

    Creates the STT, TTS and LLM services, registers the two tool handlers,
    wraps the conversation in a universal ``LLMContext`` with its aggregator
    pair, wires everything into a pipeline, and runs the resulting task.

    Args:
        transport: Transport whose input/output processors bracket the pipeline
            and on which the connect/disconnect handlers are registered.
        runner_args: Runner settings; supplies the pipeline idle timeout and
            the SIGINT-handling flag.
    """
    logger.info("Starting bot")

    # Speech-to-text and text-to-speech services.
    stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
    tts = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY"),
        voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",  # British Reading Lady
    )

    # LLM service, with prompt caching enabled through its input params.
    llm = AnthropicLLMService(
        api_key=os.getenv("ANTHROPIC_API_KEY"),
        model="claude-3-7-sonnet-latest",
        params=AnthropicLLMService.InputParams(enable_prompt_caching=True),
    )
    for tool_name, handler in (("get_weather", get_weather), ("get_image", get_image)):
        llm.register_function(tool_name, handler)

    # Tool schemas advertised to the LLM.
    tools = ToolsSchema(
        standard_tools=[
            FunctionSchema(
                name="get_weather",
                description="Get the current weather",
                properties={
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                },
                required=["location"],
            ),
            FunctionSchema(
                name="get_image",
                description="Get an image from the video stream.",
                properties={
                    "question": {
                        "type": "string",
                        "description": "The question that the user is asking about the image.",
                    }
                },
                required=["question"],
            ),
        ]
    )

    system_prompt = """\
You are a helpful assistant who converses with a user and answers questions. Respond concisely to general questions.
Your response will be turned into speech so use only simple words and punctuation.
You have access to two tools: get_weather and get_image.
You can respond to questions about the weather using the get_weather tool.
You can answer questions about the user's video stream using the get_image tool. Some examples of phrases that \
indicate you should use the get_image tool are:
- What do you see?
- What's in the video?
- Can you describe the video?
- Tell me about what you see.
- Tell me something interesting about what you see.
- What's happening in the video?
If you need to use a tool, simply use the tool. Do not tell the user the tool you are using. Be brief and concise.
"""

    # Universal context plus the aggregator pair that maintains it.
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": "Start the conversation by introducing yourself."},
    ]
    context = LLMContext(conversation, tools)
    context_aggregator = LLMContextAggregatorPair(context)

    stages = [
        transport.input(),  # Transport user input
        stt,  # STT
        context_aggregator.user(),  # User speech to text
        llm,  # LLM
        tts,  # TTS
        transport.output(),  # Transport bot output
        context_aggregator.assistant(),  # Assistant spoken responses and tool context
    ]
    pipeline = Pipeline(stages)

    metrics_params = PipelineParams(
        enable_metrics=True,
        enable_usage_metrics=True,
    )
    task = PipelineTask(
        pipeline,
        params=metrics_params,
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        global client_id
        logger.info(f"Client connected: {client}")
        await maybe_capture_participant_camera(transport, client)
        # Store the client's ID in the module global that get_image reads.
        client_id = get_transport_client_id(transport, client)
        # Kick off the conversation.
        await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")
        await task.cancel()

    runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
    await runner.run(task)
async def bot(runner_args: RunnerArguments):
    """Main bot entry point compatible with Pipecat Cloud.

    Creates the transport from ``runner_args`` and the module-level
    ``transport_params`` factories, then runs the bot on it.

    Args:
        runner_args: Runner arguments forwarded to both ``create_transport``
            and ``run_bot``.
    """
    await run_bot(await create_transport(runner_args, transport_params), runner_args)
# Script entry point: only runs when this file is executed directly.
if __name__ == "__main__":
# The runner's main() is imported here rather than at module top, so it is only loaded for script runs.
from pipecat.runner.run import main
main()

View File

@@ -39,11 +39,12 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
"""
@abstractmethod
def get_llm_invocation_params(self, context: LLMContext) -> TLLMInvocationParams:
def get_llm_invocation_params(self, context: LLMContext, **kwargs) -> TLLMInvocationParams:
"""Get provider-specific LLM invocation parameters from a universal LLM context.
Args:
context: The LLM context containing messages, tools, etc.
**kwargs: Additional provider-specific arguments that subclasses can use.
Returns:
Provider-specific parameters for invoking the LLM.

View File

@@ -6,12 +6,25 @@
"""Anthropic LLM adapter for Pipecat."""
from typing import Any, Dict, List, TypedDict
import copy
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, TypedDict
from anthropic import NOT_GIVEN, NotGiven
from anthropic.types.message_param import MessageParam
from anthropic.types.tool_union_param import ToolUnionParam
from loguru import logger
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_context import (
LLMContext,
LLMContextMessage,
LLMSpecificMessage,
LLMStandardMessage,
)
class AnthropicLLMInvocationParams(TypedDict):
@@ -20,7 +33,9 @@ class AnthropicLLMInvocationParams(TypedDict):
This is a placeholder until support for universal LLMContext machinery is added for Anthropic.
"""
pass
system: str | NotGiven
messages: List[MessageParam]
tools: List[ToolUnionParam] | NotGiven
class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
@@ -30,20 +45,33 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
to the specific format required by Anthropic's Claude models for function calling.
"""
def get_llm_invocation_params(self, context: LLMContext) -> AnthropicLLMInvocationParams:
def get_llm_invocation_params(
self, context: LLMContext, enable_prompt_caching: bool
) -> AnthropicLLMInvocationParams:
"""Get Anthropic-specific LLM invocation parameters from a universal LLM context.
This is a placeholder until support for universal LLMContext machinery is added for Anthropic.
Args:
context: The LLM context containing messages, tools, etc.
enable_prompt_caching: Whether prompt caching should be enabled.
Returns:
Dictionary of parameters for invoking Anthropic's LLM API.
"""
raise NotImplementedError("Universal LLMContext is not yet supported for Anthropic.")
messages = self._from_universal_context_messages(self._get_messages(context))
return {
"system": messages.system,
"messages": (
self._with_cache_control_markers(messages.messages)
if enable_prompt_caching
else messages.messages
),
# NOTE: LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
"tools": self.from_standard_tools(context.tools),
}
def get_messages_for_logging(self, context) -> List[Dict[str, Any]]:
def get_messages_for_logging(self, context: LLMContext) -> List[Dict[str, Any]]:
"""Get messages from a universal LLM context in a format ready for logging about Anthropic.
Removes or truncates sensitive data like image content for safe logging.
@@ -56,7 +84,241 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
Returns:
List of messages in a format ready for logging about Anthropic.
"""
raise NotImplementedError("Universal LLMContext is not yet supported for Anthropic.")
# Get messages in Anthropic's format
messages = self._from_universal_context_messages(self._get_messages(context)).messages
# Sanitize messages for logging
messages_for_logging = []
for message in messages:
msg = copy.deepcopy(message)
if "content" in msg:
if isinstance(msg["content"], list):
for item in msg["content"]:
if item["type"] == "image":
item["source"]["data"] = "..."
messages_for_logging.append(msg)
return messages_for_logging
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
    """Fetch the context's messages using the ``"anthropic"`` selector.

    Args:
        context: The universal LLM context to read messages from.

    Returns:
        Whatever ``context.get_messages("anthropic")`` returns.
    """
    llm_key = "anthropic"
    return context.get_messages(llm_key)
@dataclass
class ConvertedMessages:
"""Container for Anthropic-formatted messages converted from universal context."""
# Conversation messages in Anthropic format, as produced by
# _from_universal_context_messages (system roles rewritten, same-role neighbours merged).
messages: List[MessageParam]
# Content of a leading system message when one was pulled out of the list;
# NOT_GIVEN when nothing was extracted.
system: str | NotGiven
def _from_universal_context_messages(
    self, universal_context_messages: List[LLMContextMessage]
) -> ConvertedMessages:
    """Convert universal context messages into Anthropic messages plus a system prompt.

    Steps, in order: map every message to Anthropic format, pull a leading
    system message out of the list, turn remaining system messages into user
    messages, merge neighbouring messages that share a role, and replace
    empty content with an "(empty)" placeholder.

    Args:
        universal_context_messages: Messages from the universal context.

    Returns:
        A ``ConvertedMessages`` holding the resulting messages and the
        extracted system content (``NOT_GIVEN`` when none was extracted).
    """
    system_prompt = NOT_GIVEN
    converted = []

    # Map each message individually. If any mapping fails, the error is logged
    # and processing continues with an empty message list.
    try:
        converted = list(map(self._from_universal_context_message, universal_context_messages))
    except Exception as e:
        logger.error(f"Error mapping messages: {e}")

    # Decide what to do with a leading system message.
    if converted and converted[0]["role"] == "system":
        if len(converted) > 1:
            # More than one message: lift the system message out of the list.
            system_prompt = converted.pop(0)["content"]
        else:
            # If we only have a system message in the list, all we can really do
            # without introducing too much magic is change the role to "user".
            converted[0]["role"] = "user"

    # Any later "system"-role message becomes a "user"-role message, as
    # Anthropic doesn't support system input messages.
    for entry in converted:
        if entry["role"] == "system":
            entry["role"] = "user"

    # Fold consecutive messages with the same role into one message.
    merged = []
    for entry in converted:
        previous = merged[-1] if merged else None
        if previous is None or previous["role"] != entry["role"]:
            merged.append(entry)
            continue
        # Both sides must be lists of content blocks before concatenating.
        if isinstance(previous["content"], str):
            previous["content"] = [{"type": "text", "text": previous["content"]}]
        if isinstance(entry["content"], str):
            entry["content"] = [{"type": "text", "text": entry["content"]}]
        previous["content"].extend(entry["content"])

    # Avoid empty content in messages.
    for entry in merged:
        body = entry["content"]
        if isinstance(body, str):
            if body == "":
                entry["content"] = "(empty)"
        elif isinstance(body, list) and len(body) == 0:
            entry["content"] = [{"type": "text", "text": "(empty)"}]

    return self.ConvertedMessages(messages=merged, system=system_prompt)
def _from_universal_context_message(self, message: LLMContextMessage) -> MessageParam:
    """Convert one universal context message to Anthropic format.

    Args:
        message: Either an ``LLMSpecificMessage`` or a standard message.

    Returns:
        A deep copy of the wrapped payload for an ``LLMSpecificMessage``;
        otherwise the result of ``_from_standard_message``.
    """
    if not isinstance(message, LLMSpecificMessage):
        return self._from_standard_message(message)
    # LLM-specific messages are passed through as a copy of their payload.
    return copy.deepcopy(message.message)
def _from_standard_message(self, message: LLMStandardMessage) -> MessageParam:
    """Convert standard universal context message to Anthropic format.

    Handles conversion of text content, tool calls, and tool results.
    Empty text content is converted to "(empty)". ``image_url`` items are
    converted to base64 ``image`` items whose media type is taken from the
    data URL (falling back to "image/jpeg" when the URL declares no image
    type). The input message is never modified; a deep copy is converted.

    Args:
        message: Message in standard universal context format.

    Returns:
        Message in Anthropic format. A message with ``tool_calls`` yields an
        assistant message containing only ``tool_use`` items.

    Raises:
        IndexError: If an ``image_url`` URL contains no "," separator.

    Examples:
        Input standard format::

            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "123",
                        "function": {"name": "search", "arguments": '{"q": "test"}'}
                    }
                ]
            }

        Output Anthropic format::

            {
                "role": "assistant",
                "content": [
                    {
                        "type": "tool_use",
                        "id": "123",
                        "name": "search",
                        "input": {"q": "test"}
                    }
                ]
            }
    """
    message = copy.deepcopy(message)

    # Tool results become a user message carrying a tool_result block.
    if message["role"] == "tool":
        return {
            "role": "user",
            "content": [
                {
                    "type": "tool_result",
                    "tool_use_id": message["tool_call_id"],
                    "content": message["content"],
                },
            ],
        }

    # Tool calls become an assistant message of tool_use blocks.
    if message.get("tool_calls"):
        ret = {"role": "assistant", "content": []}
        for tool_call in message["tool_calls"]:
            function = tool_call["function"]
            ret["content"].append(
                {
                    "type": "tool_use",
                    "id": tool_call["id"],
                    "name": function["name"],
                    # Arguments arrive as a JSON string; Anthropic expects an object.
                    "input": json.loads(function["arguments"]),
                }
            )
        return ret

    content = message.get("content")
    if isinstance(content, str):
        # fix empty text: write the placeholder back into the message, since
        # rebinding only the local name would leave the empty string in place
        if content == "":
            message["content"] = "(empty)"
    elif isinstance(content, list):
        for item in content:
            # fix empty text
            if item["type"] == "text" and item["text"] == "":
                item["text"] = "(empty)"
            # handle image_url -> image conversion
            if item["type"] == "image_url":
                url_parts = item["image_url"]["url"].split(",", 1)
                header = url_parts[0]
                # Honour the media type declared by a data URL
                # ("data:image/png;base64,..."); default to JPEG otherwise.
                media_type = "image/jpeg"
                if header.startswith("data:"):
                    declared = header[len("data:") :].split(";", 1)[0].strip()
                    if declared.startswith("image/"):
                        media_type = declared
                item["type"] = "image"
                item["source"] = {
                    "type": "base64",
                    "media_type": media_type,
                    "data": url_parts[1],
                }
                del item["image_url"]
        # In the case where there's a single image in the list (like what
        # would result from a UserImageRawFrame), ensure that the image
        # comes before text, as recommended by Anthropic docs
        # (https://docs.anthropic.com/en/docs/build-with-claude/vision#example-one-image)
        image_indices = [i for i, item in enumerate(content) if item["type"] == "image"]
        text_indices = [i for i, item in enumerate(content) if item["type"] == "text"]
        if len(image_indices) == 1 and text_indices:
            img_idx = image_indices[0]
            first_txt_idx = text_indices[0]
            if img_idx > first_txt_idx:
                # Move image before the first text
                image_item = content.pop(img_idx)
                content.insert(first_txt_idx, image_item)

    return message
def _with_cache_control_markers(self, messages: List[MessageParam]) -> List[MessageParam]:
    """Add cache control markers to messages for prompt caching.

    Works on a deep copy, so the caller's messages are not modified. If an
    error occurs it is logged and the best available result is returned:
    the (possibly partially marked) copy, or the original list when the copy
    itself could not be made.

    Args:
        messages: List of messages in Anthropic format.

    Returns:
        List of messages with cache control markers added.
    """

    def add_cache_control_marker(message: MessageParam):
        # A marker lives on a content block, so plain-string content is
        # first wrapped in a single text block.
        if isinstance(message["content"], str):
            message["content"] = [{"type": "text", "text": message["content"]}]
        message["content"][-1]["cache_control"] = {"type": "ephemeral"}

    # Bind the result before entering the try block so the error path always
    # has something to return, even if the deep copy itself raises.
    messages_with_markers = messages
    try:
        # Add cache control markers to the most recent two user messages.
        # - The marker at the most recent user message tells Anthropic to
        #   cache the prompt up to that point.
        # - The marker at the second-most-recent user message tells Anthropic
        #   to look up the cached prompt that goes up to that point (the
        #   point that *was* the last user message the previous turn).
        # If we only added the marker to the last user message, we'd only
        # ever be adding to the cache, never looking up from it.
        # Why user messages? We're assuming that we're primarily running
        # inference as soon as user turns come in. In Anthropic, turns
        # strictly alternate between user and assistant.
        messages_with_markers = copy.deepcopy(messages)

        # Find the most recent two user messages, scanning from the end.
        user_message_indices = []
        for i in range(len(messages_with_markers) - 1, -1, -1):
            if messages_with_markers[i]["role"] == "user":
                user_message_indices.append(i)
                if len(user_message_indices) == 2:
                    break

        # Add cache control markers to the identified user messages
        for index in user_message_indices:
            add_cache_control_marker(messages_with_markers[index])
    except Exception as e:
        logger.error(f"Error adding cache control marker: {e}")

    return messages_with_markers
@staticmethod
def _to_anthropic_function_format(function: FunctionSchema) -> Dict[str, Any]:

View File

@@ -67,7 +67,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
return {
"system_instruction": messages.system_instruction,
"messages": messages.messages,
# NOTE; LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
# NOTE: LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
"tools": self.from_standard_tools(context.tools),
}
@@ -192,14 +192,14 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
def _from_standard_message(
self, message: LLMStandardMessage, already_have_system_instruction: bool
) -> Content | str:
"""Convert universal context message to Google Content object.
"""Convert standard universal context message to Google Content object.
Handles conversion of text, images, and function calls to Google's
format.
System instructions are returned as a plain string.
Args:
message: Message in universal context format.
message: Message in standard universal context format.
already_have_system_instruction: Whether we already have a system instruction
Returns:
@@ -308,5 +308,4 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
audio_bytes = base64.b64decode(input_audio["data"])
parts.append(Part(inline_data=Blob(mime_type="audio/wav", data=audio_bytes)))
message = Content(role=role, parts=parts)
return message
return Content(role=role, parts=parts)

View File

@@ -30,25 +30,17 @@ class LLMSwitcher(ServiceSwitcher[StrategyType]):
"""Get the currently active LLM, if any."""
return self.strategy.active_service
async def run_inference(
self, context: LLMContext, system_instruction: Optional[str] = None
) -> Optional[str]:
async def run_inference(self, context: LLMContext) -> Optional[str]:
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context, using the currently active LLM.
Args:
context: The LLM context containing conversation history.
system_instruction: Optional system instruction to guide the LLM's
behavior. You could also (again, optionally) provide a system
instruction directly in the context. If both are provided, the
one in the context takes precedence.
Returns:
The LLM's response as a string, or None if no response is generated.
"""
if self.active_llm:
return await self.active_llm.run_inference(
context=context, system_instruction=system_instruction
)
return await self.active_llm.run_inference(context=context)
return None
def register_function(

View File

@@ -24,7 +24,10 @@ from loguru import logger
from PIL import Image
from pydantic import BaseModel, Field
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
from pipecat.adapters.services.anthropic_adapter import (
AnthropicLLMAdapter,
AnthropicLLMInvocationParams,
)
from pipecat.frames.frames import (
ErrorFrame,
Frame,
@@ -112,7 +115,12 @@ class AnthropicLLMService(LLMService):
"""Input parameters for Anthropic model inference.
Parameters:
enable_prompt_caching_beta: Whether to enable beta prompt caching feature.
enable_prompt_caching: Whether to enable the prompt caching feature.
enable_prompt_caching_beta (deprecated): Whether to enable the beta prompt caching feature.
.. deprecated:: 0.0.84
Use the `enable_prompt_caching` parameter instead.
max_tokens: Maximum tokens to generate. Must be at least 1.
temperature: Sampling temperature between 0.0 and 1.0.
top_k: Top-k sampling parameter.
@@ -120,13 +128,26 @@ class AnthropicLLMService(LLMService):
extra: Additional parameters to pass to the API.
"""
enable_prompt_caching_beta: Optional[bool] = False
enable_prompt_caching: Optional[bool] = None
enable_prompt_caching_beta: Optional[bool] = None
max_tokens: Optional[int] = Field(default_factory=lambda: 4096, ge=1)
temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
def model_post_init(self, __context):
    """Post-initialization to handle deprecated parameters.

    Emits a ``DeprecationWarning`` when ``enable_prompt_caching_beta`` was
    given a value (anything other than ``None``).

    Args:
        __context: Post-init context argument; not used here.
    """
    if self.enable_prompt_caching_beta is not None:
        import warnings

        # Force this one warning to be shown, but restore the previous filter
        # state afterwards: calling simplefilter("always") outside of
        # catch_warnings() would change warning filtering for the whole process.
        with warnings.catch_warnings():
            warnings.simplefilter("always")
            warnings.warn(
                "enable_prompt_caching_beta is deprecated. Use enable_prompt_caching instead.",
                DeprecationWarning,
                stacklevel=2,
            )
def __init__(
self,
*,
@@ -159,7 +180,15 @@ class AnthropicLLMService(LLMService):
self._retry_on_timeout = retry_on_timeout
self._settings = {
"max_tokens": params.max_tokens,
"enable_prompt_caching_beta": params.enable_prompt_caching_beta or False,
"enable_prompt_caching": (
params.enable_prompt_caching
if params.enable_prompt_caching is not None
else (
params.enable_prompt_caching_beta
if params.enable_prompt_caching_beta is not None
else False
)
),
"temperature": params.temperature,
"top_k": params.top_k,
"top_p": params.top_p,
@@ -199,34 +228,28 @@ class AnthropicLLMService(LLMService):
response = await api_call(**params)
return response
async def run_inference(
self, context: LLMContext | OpenAILLMContext, system_instruction: Optional[str] = None
) -> Optional[str]:
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
Args:
context: The LLM context containing conversation history.
system_instruction: Optional system instruction to guide the LLM's
behavior. You could also (again, optionally) provide a system
instruction directly in the context. If both are provided, the
one in the context takes precedence.
Returns:
The LLM's response as a string, or None if no response is generated.
"""
messages = []
system = []
system = NOT_GIVEN
if isinstance(context, LLMContext):
# Future code will be something like this:
# adapter = self.get_llm_adapter()
# params: AnthropicLLMInvocationParams = adapter.get_llm_invocation_params(context)
# messages = params["messages"]
# system = params["system_instruction"]
raise NotImplementedError("Universal LLMContext is not yet supported for Anthropic.")
adapter: AnthropicLLMAdapter = self.get_llm_adapter()
params = adapter.get_llm_invocation_params(
context, enable_prompt_caching=self._settings["enable_prompt_caching"]
)
messages = params["messages"]
system = params["system"]
else:
context = AnthropicLLMContext.upgrade_to_anthropic(context)
messages = context.messages
system = getattr(context, "system", None) or system_instruction
system = getattr(context, "system", NOT_GIVEN)
# LLM completion
response = await self._client.messages.create(
@@ -239,15 +262,6 @@ class AnthropicLLMService(LLMService):
return response.content[0].text
@property
def enable_prompt_caching_beta(self) -> bool:
"""Check if prompt caching beta feature is enabled.
Returns:
True if prompt caching is enabled.
"""
return self._enable_prompt_caching_beta
def create_context_aggregator(
self,
context: OpenAILLMContext,
@@ -277,8 +291,31 @@ class AnthropicLLMService(LLMService):
assistant = AnthropicAssistantContextAggregator(context, params=assistant_params)
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
def _get_llm_invocation_params(
    self, context: OpenAILLMContext | LLMContext
) -> AnthropicLLMInvocationParams:
    """Build the system/messages/tools parameters for an Anthropic request.

    Args:
        context: Either a universal ``LLMContext`` or an Anthropic-specific
            context exposing ``system``, ``messages``, ``tools`` and
            ``get_messages_with_cache_control_markers()``.

    Returns:
        The ``system``, ``messages`` and ``tools`` invocation parameters, with
        cache control markers applied to the messages when the
        ``enable_prompt_caching`` setting is on.
    """
    # Universal LLMContext: the adapter does the conversion.
    if isinstance(context, LLMContext):
        adapter: AnthropicLLMAdapter = self.get_llm_adapter()
        params = adapter.get_llm_invocation_params(
            context, enable_prompt_caching=self._settings["enable_prompt_caching"]
        )
        return params

    # Anthropic-specific context
    messages = (
        context.get_messages_with_cache_control_markers()
        if self._settings["enable_prompt_caching"]
        else context.messages
    )
    return AnthropicLLMInvocationParams(
        system=context.system,
        messages=messages,
        # Fall back to an empty list when the context has no tools, as the
        # request construction did before this helper was introduced
        # ("tools": context.tools or []); a falsy tools value is never sent as-is.
        tools=context.tools or [],
    )
@traced_llm
async def _process_context(self, context: OpenAILLMContext):
async def _process_context(self, context: OpenAILLMContext | LLMContext):
# Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
# completion_tokens. We also estimate the completion tokens from output text
# and use that estimate if we are interrupted, because we almost certainly won't
@@ -294,24 +331,22 @@ class AnthropicLLMService(LLMService):
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()
params_from_context = self._get_llm_invocation_params(context)
if isinstance(context, LLMContext):
adapter = self.get_llm_adapter()
context_type_for_logging = "universal"
messages_for_logging = adapter.get_messages_for_logging(context)
else:
context_type_for_logging = "LLM-specific"
messages_for_logging = context.get_messages_for_logging()
logger.debug(
f"{self}: Generating chat [{context.system}] | {context.get_messages_for_logging()}"
f"{self}: Generating chat from {context_type_for_logging} context [{params_from_context['system']}] | {messages_for_logging}"
)
messages = context.messages
if self._settings["enable_prompt_caching_beta"]:
messages = context.get_messages_with_cache_control_markers()
api_call = self._client.messages.create
if self._settings["enable_prompt_caching_beta"]:
api_call = self._client.beta.prompt_caching.messages.create
await self.start_ttfb_metrics()
params = {
"tools": context.tools or [],
"system": context.system,
"messages": messages,
"model": self.model_name,
"max_tokens": self._settings["max_tokens"],
"stream": True,
@@ -320,9 +355,12 @@ class AnthropicLLMService(LLMService):
"top_p": self._settings["top_p"],
}
# Messages, system, tools
params.update(params_from_context)
params.update(self._settings["extra"])
response = await self._create_message_stream(api_call, params)
response = await self._create_message_stream(self._client.messages.create, params)
await self.stop_ttfb_metrics()
@@ -405,7 +443,10 @@ class AnthropicLLMService(LLMService):
prompt_tokens + cache_creation_input_tokens + cache_read_input_tokens
)
if total_input_tokens >= 1024:
context.turns_above_cache_threshold += 1
if hasattr(
context, "turns_above_cache_threshold"
): # LLMContext doesn't have this attribute
context.turns_above_cache_threshold += 1
await self.run_function_calls(function_calls)
@@ -451,7 +492,7 @@ class AnthropicLLMService(LLMService):
if isinstance(frame, OpenAILLMContextFrame):
context: "AnthropicLLMContext" = AnthropicLLMContext.upgrade_to_anthropic(frame.context)
elif isinstance(frame, LLMContextFrame):
raise NotImplementedError("Universal LLMContext is not yet supported for Anthropic.")
context = frame.context
elif isinstance(frame, LLMMessagesFrame):
context = AnthropicLLMContext.from_messages(frame.messages)
elif isinstance(frame, VisionImageRawFrame):
@@ -464,7 +505,7 @@ class AnthropicLLMService(LLMService):
await self._update_settings(frame.settings)
elif isinstance(frame, LLMEnablePromptCachingFrame):
logger.debug(f"Setting enable prompt caching to: [{frame.enable}]")
self._settings["enable_prompt_caching_beta"] = frame.enable
self._settings["enable_prompt_caching"] = frame.enable
else:
await self.push_frame(frame, direction)

View File

@@ -792,17 +792,11 @@ class AWSBedrockLLMService(LLMService):
"""
return True
async def run_inference(
self, context: LLMContext | OpenAILLMContext, system_instruction: Optional[str] = None
) -> Optional[str]:
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
Args:
context: The LLM context containing conversation history.
system_instruction: Optional system instruction to guide the LLM's
behavior. You could also (again, optionally) provide a system
instruction directly in the context. If both are provided, the
one in the context takes precedence.
Returns:
The LLM's response as a string, or None if no response is generated.
@@ -815,14 +809,14 @@ class AWSBedrockLLMService(LLMService):
# adapter = self.get_llm_adapter()
# params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
# messages = params["messages"]
# system = params["system_instruction"]
# system = params["system_instruction"] # [{"text": "system message"}]
raise NotImplementedError(
"Universal LLMContext is not yet supported for AWS Bedrock."
)
else:
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
messages = context.messages
system = getattr(context, "system", None) or system_instruction
system = getattr(context, "system", None) # [{"text": "system message"}]
# Determine if we're using Claude or Nova based on model ID
model_id = self.model_name
@@ -839,7 +833,7 @@ class AWSBedrockLLMService(LLMService):
}
if system:
request_params["system"] = [{"text": system}]
request_params["system"] = system
async with self._aws_session.client(
service_name="bedrock-runtime", **self._aws_params

View File

@@ -733,17 +733,11 @@ class GoogleLLMService(LLMService):
def _create_client(self, api_key: str, http_options: Optional[HttpOptions] = None):
self._client = genai.Client(api_key=api_key, http_options=http_options)
async def run_inference(
self, context: LLMContext | OpenAILLMContext, system_instruction: Optional[str] = None
) -> Optional[str]:
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
Args:
context: The LLM context containing conversation history.
system_instruction: Optional system instruction to guide the LLM's
behavior. You could also (again, optionally) provide a system
instruction directly in the context. If both are provided, the
one in the context takes precedence.
Returns:
The LLM's response as a string, or None if no response is generated.
@@ -758,7 +752,7 @@ class GoogleLLMService(LLMService):
else:
context = GoogleLLMContext.upgrade_to_google(context)
messages = context.messages
system = getattr(context, "system_message", None) or system_instruction
system = getattr(context, "system_message", None)
generation_config = GenerateContentConfig(system_instruction=system)
@@ -858,8 +852,7 @@ class GoogleLLMService(LLMService):
self, context: OpenAILLMContext
) -> AsyncIterator[GenerateContentResponse]:
logger.debug(
# f"{self}: Generating chat [{self._system_instruction}] | {context.get_messages_for_logging()}"
f"{self}: Generating chat from OpenAI context {context.get_messages_for_logging()}"
f"{self}: Generating chat from LLM-specific context [{context.system_message}] | {context.get_messages_for_logging()}"
)
params = GeminiLLMInvocationParams(
@@ -874,13 +867,12 @@ class GoogleLLMService(LLMService):
self, context: LLMContext
) -> AsyncIterator[GenerateContentResponse]:
adapter = self.get_llm_adapter()
logger.debug(
# f"{self}: Generating chat [{self._system_instruction}] | {context.get_messages_for_logging()}"
f"{self}: Generating chat from universal context {adapter.get_messages_for_logging(context)}"
)
params: GeminiLLMInvocationParams = adapter.get_llm_invocation_params(context)
logger.debug(
f"{self}: Generating chat from universal context [{params['system_instruction']}] | {adapter.get_messages_for_logging(context)}"
)
return await self._stream_content(params)
@traced_llm

View File

@@ -195,18 +195,13 @@ class LLMService(AIService):
"""
return self._adapter
async def run_inference(
self, context: LLMContext | OpenAILLMContext, system_instruction: Optional[str] = None
) -> Optional[str]:
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
Must be implemented by subclasses.
Args:
context: The LLM context containing conversation history.
system_instruction: Optional system instruction to guide the LLM's
behavior. You could also (again, optionally) provide a system
instruction directly in the context.
Returns:
The LLM's response as a string, or None if no response is generated.

View File

@@ -245,16 +245,11 @@ class BaseOpenAILLMService(LLMService):
params.update(self._settings["extra"])
return params
async def run_inference(
self, context: LLMContext | OpenAILLMContext, system_instruction: Optional[str] = None
) -> Optional[str]:
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
Args:
context: The LLM context containing conversation history.
system_instruction: Optional system instruction to guide the LLM's
behavior. You could also (again, optionally) provide a system
instruction directly in the context.
Returns:
The LLM's response as a string, or None if no response is generated.
@@ -279,7 +274,7 @@ class BaseOpenAILLMService(LLMService):
self, context: OpenAILLMContext
) -> AsyncStream[ChatCompletionChunk]:
logger.debug(
f"{self}: Generating chat from OpenAI context {context.get_messages_for_logging()}"
f"{self}: Generating chat from LLM-specific context {context.get_messages_for_logging()}"
)
messages: List[ChatCompletionMessageParam] = context.get_messages()