Implement LLMService.create_llm_specific_message() so that users don't need to just know what value of llm to provide to the LLMSpecificMessage constructor
This commit is contained in:
@@ -16,7 +16,12 @@ from typing import Any, Dict, Generic, List, TypeVar
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
LLMSpecificMessage,
|
||||
NotGiven,
|
||||
)
|
||||
|
||||
# Should be a TypedDict
|
||||
TLLMInvocationParams = TypeVar("TLLMInvocationParams", bound=dict[str, Any])
|
||||
@@ -38,6 +43,16 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
|
||||
Subclasses must implement provider-specific conversion logic.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for this LLM provider.
|
||||
|
||||
Returns:
|
||||
The identifier string.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_llm_invocation_params(self, context: LLMContext, **kwargs) -> TLLMInvocationParams:
|
||||
"""Get provider-specific LLM invocation parameters from a universal LLM context.
|
||||
@@ -76,6 +91,28 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
|
||||
"""
|
||||
pass
|
||||
|
||||
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
|
||||
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
|
||||
|
||||
Args:
|
||||
message: The message content.
|
||||
|
||||
Returns:
|
||||
A LLMSpecificMessage instance.
|
||||
"""
|
||||
return LLMSpecificMessage(llm=self.id_for_llm_specific_messages, message=message)
|
||||
|
||||
def get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
"""Get messages from the LLM context, including standard and LLM-specific messages.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages.
|
||||
|
||||
Returns:
|
||||
List of messages including standard and LLM-specific messages.
|
||||
"""
|
||||
return context.get_messages(self.id_for_llm_specific_messages)
|
||||
|
||||
def from_standard_tools(self, tools: Any) -> List[Any] | NotGiven:
|
||||
"""Convert tools from standard format to provider format.
|
||||
|
||||
|
||||
@@ -42,6 +42,11 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
to the specific format required by Anthropic's Claude models for function calling.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for Anthropic."""
|
||||
return "anthropic"
|
||||
|
||||
def get_llm_invocation_params(
|
||||
self, context: LLMContext, enable_prompt_caching: bool
|
||||
) -> AnthropicLLMInvocationParams:
|
||||
@@ -54,7 +59,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for invoking Anthropic's LLM API.
|
||||
"""
|
||||
messages = self._from_universal_context_messages(self._get_messages(context))
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system": messages.system,
|
||||
"messages": (
|
||||
@@ -78,7 +83,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
List of messages in a format ready for logging about Anthropic.
|
||||
"""
|
||||
# Get messages in Anthropic's format
|
||||
messages = self._from_universal_context_messages(self._get_messages(context)).messages
|
||||
messages = self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
@@ -92,9 +97,6 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
messages_for_logging.append(msg)
|
||||
return messages_for_logging
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("anthropic")
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Anthropic-formatted messages converted from universal context."""
|
||||
|
||||
@@ -31,6 +31,11 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
|
||||
specific function-calling format, enabling tool use with Nova Sonic models.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for AWS Nova Sonic."""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.")
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> AWSNovaSonicLLMInvocationParams:
|
||||
"""Get AWS Nova Sonic-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
|
||||
@@ -42,6 +42,11 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
into AWS Bedrock's expected tool format for function calling capabilities.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for AWS Bedrock."""
|
||||
return "aws"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> AWSBedrockLLMInvocationParams:
|
||||
"""Get AWS Bedrock-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
@@ -51,7 +56,7 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for invoking AWS Bedrock's LLM API.
|
||||
"""
|
||||
messages = self._from_universal_context_messages(self._get_messages(context))
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system": messages.system,
|
||||
"messages": messages.messages,
|
||||
@@ -75,7 +80,7 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
List of messages in a format ready for logging about AWS Bedrock.
|
||||
"""
|
||||
# Get messages in Anthropic's format
|
||||
messages = self._from_universal_context_messages(self._get_messages(context)).messages
|
||||
messages = self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
@@ -89,9 +94,6 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
messages_for_logging.append(msg)
|
||||
return messages_for_logging
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("anthropic")
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Anthropic-formatted messages converted from universal context."""
|
||||
|
||||
@@ -54,6 +54,11 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
- Extracting and sanitizing messages from the LLM context for logging with Gemini.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for Google."""
|
||||
return "google"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> GeminiLLMInvocationParams:
|
||||
"""Get Gemini-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
@@ -63,7 +68,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for Gemini's API.
|
||||
"""
|
||||
messages = self._from_universal_context_messages(self._get_messages(context))
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system_instruction": messages.system_instruction,
|
||||
"messages": messages.messages,
|
||||
@@ -103,7 +108,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
List of messages in a format ready for logging about Gemini.
|
||||
"""
|
||||
# Get messages in Gemini's format
|
||||
messages = self._from_universal_context_messages(self._get_messages(context)).messages
|
||||
messages = self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
@@ -119,9 +124,6 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
messages_for_logging.append(obj)
|
||||
return messages_for_logging
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("google")
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Google-formatted messages converted from universal context."""
|
||||
|
||||
@@ -48,6 +48,11 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
- Extracting and sanitizing messages from the LLM context for logging about OpenAI.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for OpenAI."""
|
||||
return "openai"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> OpenAILLMInvocationParams:
|
||||
"""Get OpenAI-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
@@ -58,7 +63,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
Dictionary of parameters for OpenAI's ChatCompletion API.
|
||||
"""
|
||||
return {
|
||||
"messages": self._from_universal_context_messages(self._get_messages(context)),
|
||||
"messages": self._from_universal_context_messages(self.get_messages(context)),
|
||||
# NOTE; LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
"tools": self.from_standard_tools(context.tools),
|
||||
"tool_choice": context.tool_choice,
|
||||
@@ -92,7 +97,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
List of messages in a format ready for logging about OpenAI.
|
||||
"""
|
||||
msgs = []
|
||||
for message in self._get_messages(context):
|
||||
for message in self.get_messages(context):
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
@@ -105,9 +110,6 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("openai")
|
||||
|
||||
def _from_universal_context_messages(
|
||||
self, messages: List[LLMContextMessage]
|
||||
) -> List[ChatCompletionMessageParam]:
|
||||
|
||||
@@ -30,6 +30,11 @@ class OpenAIRealtimeLLMAdapter(BaseLLMAdapter):
|
||||
OpenAI's Realtime API for function calling capabilities.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for OpenAI Realtime."""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for OpenAI Realtime.")
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> OpenAIRealtimeLLMInvocationParams:
|
||||
"""Get OpenAI Realtime-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ from pipecat.frames.frames import (
|
||||
StartFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
@@ -195,6 +195,17 @@ class LLMService(AIService):
|
||||
"""
|
||||
return self._adapter
|
||||
|
||||
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
|
||||
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
|
||||
|
||||
Args:
|
||||
message: The message content.
|
||||
|
||||
Returns:
|
||||
A LLMSpecificMessage instance.
|
||||
"""
|
||||
return self.get_llm_adapter().create_llm_specific_message(message)
|
||||
|
||||
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
|
||||
|
||||
|
||||
@@ -11,37 +11,33 @@ These tests focus specifically on the "messages" field generation for different
|
||||
|
||||
For OpenAI adapter:
|
||||
1. LLMStandardMessage objects are passed through unchanged
|
||||
2. LLMSpecificMessage objects with llm='openai' are included and their content extracted
|
||||
3. LLMSpecificMessage objects with llm != 'openai' are filtered out
|
||||
4. Complex message structures (like multi-part content) are preserved
|
||||
5. System instructions are preserved throughout messages at any position
|
||||
2. LLMSpecificMessage objects with llm='openai' are included and others are filtered out
|
||||
3. Complex message structures (like multi-part content) are preserved
|
||||
4. System instructions are preserved throughout messages at any position
|
||||
|
||||
For Gemini adapter:
|
||||
1. LLMStandardMessage objects are converted to Gemini Content format
|
||||
2. LLMSpecificMessage objects with llm='google' are included unchanged
|
||||
3. LLMSpecificMessage objects with llm != 'google' are filtered out
|
||||
4. Complex message structures (image, audio, multi-text) are converted to appropriate Gemini format
|
||||
5. System messages are extracted as system_instruction (without duplication)
|
||||
6. Single system instruction is converted to user message when no other messages exist
|
||||
7. Multiple system instructions: first extracted, later ones converted to user messages
|
||||
2. LLMSpecificMessage objects with llm='google' are included and others are filtered out
|
||||
3. Complex message structures (image, audio, multi-text) are converted to appropriate Gemini format
|
||||
4. System messages are extracted as system_instruction (without duplication)
|
||||
5. Single system instruction is converted to user message when no other messages exist
|
||||
6. Multiple system instructions: first extracted, later ones converted to user messages
|
||||
|
||||
For Anthropic adapter:
|
||||
1. LLMStandardMessage objects are converted to Anthropic MessageParam format
|
||||
2. LLMSpecificMessage objects with llm='anthropic' are included unchanged
|
||||
3. LLMSpecificMessage objects with llm != 'anthropic' are filtered out
|
||||
4. Complex message structures (image, multi-text) are converted to appropriate Anthropic format
|
||||
5. System messages: first extracted as system parameter, later ones converted to user messages
|
||||
6. Consecutive messages with same role are merged into multi-content-block messages
|
||||
7. Empty text content is converted to "(empty)"
|
||||
2. LLMSpecificMessage objects with llm='anthropic' are included and others are filtered out
|
||||
3. Complex message structures (image, multi-text) are converted to appropriate Anthropic format
|
||||
4. System messages: first extracted as system parameter, later ones converted to user messages
|
||||
5. Consecutive messages with same role are merged into multi-content-block messages
|
||||
6. Empty text content is converted to "(empty)"
|
||||
|
||||
For AWS Bedrock adapter:
|
||||
1. LLMStandardMessage objects are converted to AWS Bedrock format
|
||||
2. LLMSpecificMessage objects with llm='anthropic' are included unchanged (uses Anthropic format)
|
||||
3. LLMSpecificMessage objects with llm != 'anthropic' are filtered out
|
||||
4. Complex message structures (image, multi-text) are converted to appropriate AWS Bedrock format
|
||||
5. System messages: first extracted as system parameter, later ones converted to user messages
|
||||
6. Consecutive messages with same role are merged into multi-content-block messages
|
||||
7. Empty text content is converted to "(empty)"
|
||||
2. LLMSpecificMessage objects with llm='aws' are included and others are filtered out
|
||||
3. Complex message structures (image, multi-text) are converted to appropriate AWS Bedrock format
|
||||
4. System messages: first extracted as system parameter, later ones converted to user messages
|
||||
5. Consecutive messages with same role are merged into multi-content-block messages
|
||||
6. Empty text content is converted to "(empty)"
|
||||
"""
|
||||
|
||||
import unittest
|
||||
@@ -89,51 +85,20 @@ class TestOpenAIGetLLMInvocationParams(unittest.TestCase):
|
||||
self.assertEqual(params["messages"][1]["content"], "Hello, how are you?")
|
||||
self.assertEqual(params["messages"][2]["content"], "I'm doing well, thank you for asking!")
|
||||
|
||||
def test_openai_specific_messages_included(self):
|
||||
"""Test that LLMSpecificMessage objects with llm='openai' are included."""
|
||||
# Create a mix of standard and OpenAI-specific messages
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
LLMSpecificMessage(
|
||||
llm="openai", message={"role": "user", "content": "OpenAI specific message"}
|
||||
),
|
||||
{"role": "assistant", "content": "Standard response"},
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify all messages are included (OpenAI-specific should be included)
|
||||
self.assertEqual(len(params["messages"]), 3)
|
||||
|
||||
# First message should be standard
|
||||
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
|
||||
|
||||
# Second message should be the OpenAI-specific one
|
||||
self.assertEqual(
|
||||
params["messages"][1], {"role": "user", "content": "OpenAI specific message"}
|
||||
)
|
||||
|
||||
# Third message should be standard
|
||||
self.assertEqual(params["messages"][2]["content"], "Standard response")
|
||||
|
||||
def test_non_openai_specific_messages_filtered_out(self):
|
||||
"""Test that LLMSpecificMessage objects with llm != 'openai' are filtered out."""
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that OpenAI-specific messages are included and others are filtered out."""
|
||||
# Create messages with different LLM-specific ones
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
LLMSpecificMessage(
|
||||
llm="anthropic", message={"role": "user", "content": "Anthropic specific message"}
|
||||
AnthropicLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Anthropic specific message"}
|
||||
),
|
||||
LLMSpecificMessage(
|
||||
llm="gemini", message={"role": "user", "content": "Gemini specific message"}
|
||||
GeminiLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Gemini specific message"}
|
||||
),
|
||||
{"role": "user", "content": "Standard user message"},
|
||||
LLMSpecificMessage(
|
||||
llm="openai", message={"role": "assistant", "content": "OpenAI specific response"}
|
||||
self.adapter.create_llm_specific_message(
|
||||
{"role": "assistant", "content": "OpenAI specific response"}
|
||||
),
|
||||
]
|
||||
|
||||
@@ -291,53 +256,20 @@ class TestGeminiGetLLMInvocationParams(unittest.TestCase):
|
||||
self.assertEqual(len(model_msg.parts), 1)
|
||||
self.assertEqual(model_msg.parts[0].text, "I'm doing well, thank you for asking!")
|
||||
|
||||
def test_gemini_specific_messages_included(self):
|
||||
"""Test that LLMSpecificMessage objects with llm='google' are included unchanged."""
|
||||
# Create a Gemini-specific message
|
||||
gemini_message = Content(role="user", parts=[Part(text="Gemini specific message")])
|
||||
|
||||
# Create a mix of standard and Gemini-specific messages
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
LLMSpecificMessage(llm="google", message=gemini_message),
|
||||
{"role": "assistant", "content": "Standard response"},
|
||||
]
|
||||
|
||||
# Create context with these messages
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify system instruction
|
||||
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
|
||||
|
||||
# Verify messages (2 total: gemini-specific user + converted model)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
|
||||
# First message should be the Gemini-specific one (unchanged)
|
||||
self.assertEqual(params["messages"][0], gemini_message)
|
||||
self.assertEqual(params["messages"][0].parts[0].text, "Gemini specific message")
|
||||
|
||||
# Second message should be converted standard message
|
||||
self.assertEqual(params["messages"][1].role, "model")
|
||||
self.assertEqual(params["messages"][1].parts[0].text, "Standard response")
|
||||
|
||||
def test_non_gemini_specific_messages_filtered_out(self):
|
||||
"""Test that LLMSpecificMessage objects with llm != 'google' are filtered out."""
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that Gemini-specific messages are included and others are filtered out."""
|
||||
# Create messages with different LLM-specific ones
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
LLMSpecificMessage(
|
||||
llm="openai", message={"role": "user", "content": "OpenAI specific message"}
|
||||
OpenAILLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "OpenAI specific message"}
|
||||
),
|
||||
LLMSpecificMessage(
|
||||
llm="anthropic", message={"role": "user", "content": "Anthropic specific message"}
|
||||
AnthropicLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Anthropic specific message"}
|
||||
),
|
||||
{"role": "user", "content": "Standard user message"},
|
||||
LLMSpecificMessage(
|
||||
llm="google",
|
||||
message=Content(role="model", parts=[Part(text="Gemini specific response")]),
|
||||
self.adapter.create_llm_specific_message(
|
||||
Content(role="model", parts=[Part(text="Gemini specific response")]),
|
||||
),
|
||||
]
|
||||
|
||||
@@ -584,8 +516,8 @@ class TestAnthropicGetLLMInvocationParams(unittest.TestCase):
|
||||
self.assertEqual(assistant_msg["role"], "assistant")
|
||||
self.assertEqual(assistant_msg["content"], "I'm doing well, thank you!")
|
||||
|
||||
def test_anthropic_specific_messages_included_unchanged(self):
|
||||
"""Test that LLMSpecificMessage objects with llm='anthropic' are included unchanged."""
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that Anthropic-specific messages are included and others are filtered out."""
|
||||
# Create anthropic-specific message content
|
||||
anthropic_message_content = {
|
||||
"role": "user",
|
||||
@@ -599,36 +531,14 @@ class TestAnthropicGetLLMInvocationParams(unittest.TestCase):
|
||||
}
|
||||
|
||||
messages = [
|
||||
LLMSpecificMessage(llm="anthropic", message=anthropic_message_content),
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Verify the anthropic-specific message is preserved
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
anthropic_msg = params["messages"][0]
|
||||
self.assertEqual(anthropic_msg["role"], "user")
|
||||
self.assertIsInstance(anthropic_msg["content"], list)
|
||||
self.assertEqual(len(anthropic_msg["content"]), 2)
|
||||
self.assertEqual(anthropic_msg["content"][0]["type"], "text")
|
||||
self.assertEqual(anthropic_msg["content"][0]["text"], "Hello")
|
||||
self.assertEqual(anthropic_msg["content"][1]["type"], "image")
|
||||
|
||||
def test_non_anthropic_specific_messages_filtered_out(self):
|
||||
"""Test that LLMSpecificMessage objects with llm != 'anthropic' are filtered out."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
LLMSpecificMessage(
|
||||
llm="openai", message={"role": "user", "content": "OpenAI specific"}
|
||||
{"role": "user", "content": "Standard message"},
|
||||
OpenAILLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "OpenAI specific"}
|
||||
),
|
||||
LLMSpecificMessage(
|
||||
llm="google", message={"role": "user", "content": "Google specific"}
|
||||
GeminiLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Google specific"}
|
||||
),
|
||||
self.adapter.create_llm_specific_message(anthropic_message_content),
|
||||
{"role": "assistant", "content": "Response"},
|
||||
]
|
||||
|
||||
@@ -638,9 +548,23 @@ class TestAnthropicGetLLMInvocationParams(unittest.TestCase):
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
|
||||
|
||||
# Should only have the 2 standard messages (openai and google specific filtered out)
|
||||
# Should only have 2 messages after merging consecutive user messages: merged user + standard response
|
||||
# (openai and google specific filtered out, standard + anthropic-specific merged)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
self.assertEqual(params["messages"][0]["content"], "Hello")
|
||||
|
||||
# First message: merged user message (standard + anthropic-specific)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
# Should have 3 content blocks: standard text + anthropic text + anthropic image
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
self.assertEqual(user_msg["content"][0]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][0]["text"], "Standard message")
|
||||
self.assertEqual(user_msg["content"][1]["type"], "text")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "Hello")
|
||||
self.assertEqual(user_msg["content"][2]["type"], "image")
|
||||
|
||||
# Second message: standard response
|
||||
self.assertEqual(params["messages"][1]["content"], "Response")
|
||||
|
||||
def test_consecutive_same_role_messages_merged(self):
|
||||
@@ -857,10 +781,10 @@ class TestAWSBedrockGetLLMInvocationParams(unittest.TestCase):
|
||||
self.assertEqual(len(assistant_msg["content"]), 1)
|
||||
self.assertEqual(assistant_msg["content"][0]["text"], "I'm doing well, thank you!")
|
||||
|
||||
def test_anthropic_specific_messages_included_unchanged(self):
|
||||
"""Test that LLMSpecificMessage objects with llm='anthropic' are included unchanged (AWS Bedrock uses Anthropic format)."""
|
||||
# Create anthropic-specific message content (which is what AWS Bedrock uses)
|
||||
anthropic_message_content = {
|
||||
def test_llm_specific_message_filtering(self):
|
||||
"""Test that AWS-specific messages are included and others are filtered out."""
|
||||
# Create aws-specific message content (which is what AWS Bedrock uses)
|
||||
aws_message_content = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"text": "Hello"},
|
||||
@@ -869,35 +793,14 @@ class TestAWSBedrockGetLLMInvocationParams(unittest.TestCase):
|
||||
}
|
||||
|
||||
messages = [
|
||||
LLMSpecificMessage(llm="anthropic", message=anthropic_message_content),
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
|
||||
# Create context
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Verify the anthropic-specific message is preserved
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
anthropic_msg = params["messages"][0]
|
||||
self.assertEqual(anthropic_msg["role"], "user")
|
||||
self.assertIsInstance(anthropic_msg["content"], list)
|
||||
self.assertEqual(len(anthropic_msg["content"]), 2)
|
||||
self.assertEqual(anthropic_msg["content"][0]["text"], "Hello")
|
||||
self.assertIn("image", anthropic_msg["content"][1])
|
||||
|
||||
def test_non_anthropic_specific_messages_filtered_out(self):
|
||||
"""Test that LLMSpecificMessage objects with llm != 'anthropic' are filtered out."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
LLMSpecificMessage(
|
||||
llm="openai", message={"role": "user", "content": "OpenAI specific"}
|
||||
{"role": "user", "content": "Standard message"},
|
||||
OpenAILLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "OpenAI specific"}
|
||||
),
|
||||
LLMSpecificMessage(
|
||||
llm="google", message={"role": "user", "content": "Google specific"}
|
||||
GeminiLLMAdapter().create_llm_specific_message(
|
||||
{"role": "user", "content": "Google specific"}
|
||||
),
|
||||
self.adapter.create_llm_specific_message(message=aws_message_content),
|
||||
{"role": "assistant", "content": "Response"},
|
||||
]
|
||||
|
||||
@@ -907,9 +810,21 @@ class TestAWSBedrockGetLLMInvocationParams(unittest.TestCase):
|
||||
# Get invocation params
|
||||
params = self.adapter.get_llm_invocation_params(context)
|
||||
|
||||
# Should only have the 2 standard messages (openai and google specific filtered out)
|
||||
# Should only have 2 messages after merging consecutive user messages: merged user + standard response
|
||||
# (openai and google specific filtered out, standard + aws-specific merged)
|
||||
self.assertEqual(len(params["messages"]), 2)
|
||||
self.assertEqual(params["messages"][0]["content"][0]["text"], "Hello")
|
||||
|
||||
# First message: merged user message (standard + aws-specific)
|
||||
user_msg = params["messages"][0]
|
||||
self.assertEqual(user_msg["role"], "user")
|
||||
self.assertIsInstance(user_msg["content"], list)
|
||||
# Should have 3 content blocks: standard text + aws text + aws image
|
||||
self.assertEqual(len(user_msg["content"]), 3)
|
||||
self.assertEqual(user_msg["content"][0]["text"], "Standard message")
|
||||
self.assertEqual(user_msg["content"][1]["text"], "Hello")
|
||||
self.assertIn("image", user_msg["content"][2])
|
||||
|
||||
# Second message: standard response
|
||||
self.assertEqual(params["messages"][1]["content"][0]["text"], "Response")
|
||||
|
||||
def test_consecutive_same_role_messages_merged(self):
|
||||
|
||||
Reference in New Issue
Block a user