Implement LLMService.create_llm_specific_message() so that users don't need to just know what value of llm to provide to the LLMSpecificMessage constructor

This commit is contained in:
Paul Kompfner
2025-09-15 13:25:04 -04:00
parent 999e88c942
commit fe42187dc1
9 changed files with 170 additions and 189 deletions

View File

@@ -16,7 +16,12 @@ from typing import Any, Dict, Generic, List, TypeVar
from loguru import logger
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
from pipecat.processors.aggregators.llm_context import (
LLMContext,
LLMContextMessage,
LLMSpecificMessage,
NotGiven,
)
# Should be a TypedDict
TLLMInvocationParams = TypeVar("TLLMInvocationParams", bound=dict[str, Any])
@@ -38,6 +43,16 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
Subclasses must implement provider-specific conversion logic.
"""
@property
@abstractmethod
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for this LLM provider.
Returns:
The identifier string.
"""
pass
@abstractmethod
def get_llm_invocation_params(self, context: LLMContext, **kwargs) -> TLLMInvocationParams:
"""Get provider-specific LLM invocation parameters from a universal LLM context.
@@ -76,6 +91,28 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
"""
pass
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
Args:
message: The message content.
Returns:
A LLMSpecificMessage instance.
"""
return LLMSpecificMessage(llm=self.id_for_llm_specific_messages, message=message)
def get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
"""Get messages from the LLM context, including standard and LLM-specific messages.
Args:
context: The LLM context containing messages.
Returns:
List of messages including standard and LLM-specific messages.
"""
return context.get_messages(self.id_for_llm_specific_messages)
def from_standard_tools(self, tools: Any) -> List[Any] | NotGiven:
"""Convert tools from standard format to provider format.

View File

@@ -42,6 +42,11 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
to the specific format required by Anthropic's Claude models for function calling.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for Anthropic."""
return "anthropic"
def get_llm_invocation_params(
self, context: LLMContext, enable_prompt_caching: bool
) -> AnthropicLLMInvocationParams:
@@ -54,7 +59,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
Returns:
Dictionary of parameters for invoking Anthropic's LLM API.
"""
messages = self._from_universal_context_messages(self._get_messages(context))
messages = self._from_universal_context_messages(self.get_messages(context))
return {
"system": messages.system,
"messages": (
@@ -78,7 +83,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
List of messages in a format ready for logging about Anthropic.
"""
# Get messages in Anthropic's format
messages = self._from_universal_context_messages(self._get_messages(context)).messages
messages = self._from_universal_context_messages(self.get_messages(context)).messages
# Sanitize messages for logging
messages_for_logging = []
@@ -92,9 +97,6 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
messages_for_logging.append(msg)
return messages_for_logging
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
return context.get_messages("anthropic")
@dataclass
class ConvertedMessages:
"""Container for Anthropic-formatted messages converted from universal context."""

View File

@@ -31,6 +31,11 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
specific function-calling format, enabling tool use with Nova Sonic models.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for AWS Nova Sonic."""
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.")
def get_llm_invocation_params(self, context: LLMContext) -> AWSNovaSonicLLMInvocationParams:
"""Get AWS Nova Sonic-specific LLM invocation parameters from a universal LLM context.

View File

@@ -42,6 +42,11 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
into AWS Bedrock's expected tool format for function calling capabilities.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for AWS Bedrock."""
return "aws"
def get_llm_invocation_params(self, context: LLMContext) -> AWSBedrockLLMInvocationParams:
"""Get AWS Bedrock-specific LLM invocation parameters from a universal LLM context.
@@ -51,7 +56,7 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
Returns:
Dictionary of parameters for invoking AWS Bedrock's LLM API.
"""
messages = self._from_universal_context_messages(self._get_messages(context))
messages = self._from_universal_context_messages(self.get_messages(context))
return {
"system": messages.system,
"messages": messages.messages,
@@ -75,7 +80,7 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
List of messages in a format ready for logging about AWS Bedrock.
"""
# Get messages in Anthropic's format
messages = self._from_universal_context_messages(self._get_messages(context)).messages
messages = self._from_universal_context_messages(self.get_messages(context)).messages
# Sanitize messages for logging
messages_for_logging = []
@@ -89,9 +94,6 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
messages_for_logging.append(msg)
return messages_for_logging
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
return context.get_messages("anthropic")
@dataclass
class ConvertedMessages:
"""Container for Anthropic-formatted messages converted from universal context."""

View File

@@ -54,6 +54,11 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
- Extracting and sanitizing messages from the LLM context for logging with Gemini.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for Google."""
return "google"
def get_llm_invocation_params(self, context: LLMContext) -> GeminiLLMInvocationParams:
"""Get Gemini-specific LLM invocation parameters from a universal LLM context.
@@ -63,7 +68,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
Returns:
Dictionary of parameters for Gemini's API.
"""
messages = self._from_universal_context_messages(self._get_messages(context))
messages = self._from_universal_context_messages(self.get_messages(context))
return {
"system_instruction": messages.system_instruction,
"messages": messages.messages,
@@ -103,7 +108,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
List of messages in a format ready for logging about Gemini.
"""
# Get messages in Gemini's format
messages = self._from_universal_context_messages(self._get_messages(context)).messages
messages = self._from_universal_context_messages(self.get_messages(context)).messages
# Sanitize messages for logging
messages_for_logging = []
@@ -119,9 +124,6 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
messages_for_logging.append(obj)
return messages_for_logging
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
return context.get_messages("google")
@dataclass
class ConvertedMessages:
"""Container for Google-formatted messages converted from universal context."""

View File

@@ -48,6 +48,11 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
- Extracting and sanitizing messages from the LLM context for logging about OpenAI.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for OpenAI."""
return "openai"
def get_llm_invocation_params(self, context: LLMContext) -> OpenAILLMInvocationParams:
"""Get OpenAI-specific LLM invocation parameters from a universal LLM context.
@@ -58,7 +63,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
Dictionary of parameters for OpenAI's ChatCompletion API.
"""
return {
"messages": self._from_universal_context_messages(self._get_messages(context)),
"messages": self._from_universal_context_messages(self.get_messages(context)),
# NOTE; LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
"tools": self.from_standard_tools(context.tools),
"tool_choice": context.tool_choice,
@@ -92,7 +97,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
List of messages in a format ready for logging about OpenAI.
"""
msgs = []
for message in self._get_messages(context):
for message in self.get_messages(context):
msg = copy.deepcopy(message)
if "content" in msg:
if isinstance(msg["content"], list):
@@ -105,9 +110,6 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
msgs.append(msg)
return msgs
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
return context.get_messages("openai")
def _from_universal_context_messages(
self, messages: List[LLMContextMessage]
) -> List[ChatCompletionMessageParam]:

View File

@@ -30,6 +30,11 @@ class OpenAIRealtimeLLMAdapter(BaseLLMAdapter):
OpenAI's Realtime API for function calling capabilities.
"""
@property
def id_for_llm_specific_messages(self) -> str:
"""Get the identifier used in LLMSpecificMessage instances for OpenAI Realtime."""
raise NotImplementedError("Universal LLMContext is not yet supported for OpenAI Realtime.")
def get_llm_invocation_params(self, context: LLMContext) -> OpenAIRealtimeLLMInvocationParams:
"""Get OpenAI Realtime-specific LLM invocation parameters from a universal LLM context.

View File

@@ -44,7 +44,7 @@ from pipecat.frames.frames import (
StartFrame,
UserImageRequestFrame,
)
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMUserAggregatorParams,
@@ -195,6 +195,17 @@ class LLMService(AIService):
"""
return self._adapter
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
Args:
message: The message content.
Returns:
A LLMSpecificMessage instance.
"""
return self.get_llm_adapter().create_llm_specific_message(message)
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.

View File

@@ -11,37 +11,33 @@ These tests focus specifically on the "messages" field generation for different
For OpenAI adapter:
1. LLMStandardMessage objects are passed through unchanged
2. LLMSpecificMessage objects with llm='openai' are included and their content extracted
3. LLMSpecificMessage objects with llm != 'openai' are filtered out
4. Complex message structures (like multi-part content) are preserved
5. System instructions are preserved throughout messages at any position
2. LLMSpecificMessage objects with llm='openai' are included and others are filtered out
3. Complex message structures (like multi-part content) are preserved
4. System instructions are preserved throughout messages at any position
For Gemini adapter:
1. LLMStandardMessage objects are converted to Gemini Content format
2. LLMSpecificMessage objects with llm='google' are included unchanged
3. LLMSpecificMessage objects with llm != 'google' are filtered out
4. Complex message structures (image, audio, multi-text) are converted to appropriate Gemini format
5. System messages are extracted as system_instruction (without duplication)
6. Single system instruction is converted to user message when no other messages exist
7. Multiple system instructions: first extracted, later ones converted to user messages
2. LLMSpecificMessage objects with llm='google' are included and others are filtered out
3. Complex message structures (image, audio, multi-text) are converted to appropriate Gemini format
4. System messages are extracted as system_instruction (without duplication)
5. Single system instruction is converted to user message when no other messages exist
6. Multiple system instructions: first extracted, later ones converted to user messages
For Anthropic adapter:
1. LLMStandardMessage objects are converted to Anthropic MessageParam format
2. LLMSpecificMessage objects with llm='anthropic' are included unchanged
3. LLMSpecificMessage objects with llm != 'anthropic' are filtered out
4. Complex message structures (image, multi-text) are converted to appropriate Anthropic format
5. System messages: first extracted as system parameter, later ones converted to user messages
6. Consecutive messages with same role are merged into multi-content-block messages
7. Empty text content is converted to "(empty)"
2. LLMSpecificMessage objects with llm='anthropic' are included and others are filtered out
3. Complex message structures (image, multi-text) are converted to appropriate Anthropic format
4. System messages: first extracted as system parameter, later ones converted to user messages
5. Consecutive messages with same role are merged into multi-content-block messages
6. Empty text content is converted to "(empty)"
For AWS Bedrock adapter:
1. LLMStandardMessage objects are converted to AWS Bedrock format
2. LLMSpecificMessage objects with llm='anthropic' are included unchanged (uses Anthropic format)
3. LLMSpecificMessage objects with llm != 'anthropic' are filtered out
4. Complex message structures (image, multi-text) are converted to appropriate AWS Bedrock format
5. System messages: first extracted as system parameter, later ones converted to user messages
6. Consecutive messages with same role are merged into multi-content-block messages
7. Empty text content is converted to "(empty)"
2. LLMSpecificMessage objects with llm='aws' are included and others are filtered out
3. Complex message structures (image, multi-text) are converted to appropriate AWS Bedrock format
4. System messages: first extracted as system parameter, later ones converted to user messages
5. Consecutive messages with same role are merged into multi-content-block messages
6. Empty text content is converted to "(empty)"
"""
import unittest
@@ -89,51 +85,20 @@ class TestOpenAIGetLLMInvocationParams(unittest.TestCase):
self.assertEqual(params["messages"][1]["content"], "Hello, how are you?")
self.assertEqual(params["messages"][2]["content"], "I'm doing well, thank you for asking!")
def test_openai_specific_messages_included(self):
"""Test that LLMSpecificMessage objects with llm='openai' are included."""
# Create a mix of standard and OpenAI-specific messages
messages = [
{"role": "system", "content": "You are a helpful assistant."},
LLMSpecificMessage(
llm="openai", message={"role": "user", "content": "OpenAI specific message"}
),
{"role": "assistant", "content": "Standard response"},
]
# Create context with these messages
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Verify all messages are included (OpenAI-specific should be included)
self.assertEqual(len(params["messages"]), 3)
# First message should be standard
self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.")
# Second message should be the OpenAI-specific one
self.assertEqual(
params["messages"][1], {"role": "user", "content": "OpenAI specific message"}
)
# Third message should be standard
self.assertEqual(params["messages"][2]["content"], "Standard response")
def test_non_openai_specific_messages_filtered_out(self):
"""Test that LLMSpecificMessage objects with llm != 'openai' are filtered out."""
def test_llm_specific_message_filtering(self):
"""Test that OpenAI-specific messages are included and others are filtered out."""
# Create messages with different LLM-specific ones
messages = [
{"role": "system", "content": "You are a helpful assistant."},
LLMSpecificMessage(
llm="anthropic", message={"role": "user", "content": "Anthropic specific message"}
AnthropicLLMAdapter().create_llm_specific_message(
{"role": "user", "content": "Anthropic specific message"}
),
LLMSpecificMessage(
llm="gemini", message={"role": "user", "content": "Gemini specific message"}
GeminiLLMAdapter().create_llm_specific_message(
{"role": "user", "content": "Gemini specific message"}
),
{"role": "user", "content": "Standard user message"},
LLMSpecificMessage(
llm="openai", message={"role": "assistant", "content": "OpenAI specific response"}
self.adapter.create_llm_specific_message(
{"role": "assistant", "content": "OpenAI specific response"}
),
]
@@ -291,53 +256,20 @@ class TestGeminiGetLLMInvocationParams(unittest.TestCase):
self.assertEqual(len(model_msg.parts), 1)
self.assertEqual(model_msg.parts[0].text, "I'm doing well, thank you for asking!")
def test_gemini_specific_messages_included(self):
"""Test that LLMSpecificMessage objects with llm='google' are included unchanged."""
# Create a Gemini-specific message
gemini_message = Content(role="user", parts=[Part(text="Gemini specific message")])
# Create a mix of standard and Gemini-specific messages
messages = [
{"role": "system", "content": "You are a helpful assistant."},
LLMSpecificMessage(llm="google", message=gemini_message),
{"role": "assistant", "content": "Standard response"},
]
# Create context with these messages
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Verify system instruction
self.assertEqual(params["system_instruction"], "You are a helpful assistant.")
# Verify messages (2 total: gemini-specific user + converted model)
self.assertEqual(len(params["messages"]), 2)
# First message should be the Gemini-specific one (unchanged)
self.assertEqual(params["messages"][0], gemini_message)
self.assertEqual(params["messages"][0].parts[0].text, "Gemini specific message")
# Second message should be converted standard message
self.assertEqual(params["messages"][1].role, "model")
self.assertEqual(params["messages"][1].parts[0].text, "Standard response")
def test_non_gemini_specific_messages_filtered_out(self):
"""Test that LLMSpecificMessage objects with llm != 'google' are filtered out."""
def test_llm_specific_message_filtering(self):
"""Test that Gemini-specific messages are included and others are filtered out."""
# Create messages with different LLM-specific ones
messages = [
{"role": "system", "content": "You are a helpful assistant."},
LLMSpecificMessage(
llm="openai", message={"role": "user", "content": "OpenAI specific message"}
OpenAILLMAdapter().create_llm_specific_message(
{"role": "user", "content": "OpenAI specific message"}
),
LLMSpecificMessage(
llm="anthropic", message={"role": "user", "content": "Anthropic specific message"}
AnthropicLLMAdapter().create_llm_specific_message(
{"role": "user", "content": "Anthropic specific message"}
),
{"role": "user", "content": "Standard user message"},
LLMSpecificMessage(
llm="google",
message=Content(role="model", parts=[Part(text="Gemini specific response")]),
self.adapter.create_llm_specific_message(
Content(role="model", parts=[Part(text="Gemini specific response")]),
),
]
@@ -584,8 +516,8 @@ class TestAnthropicGetLLMInvocationParams(unittest.TestCase):
self.assertEqual(assistant_msg["role"], "assistant")
self.assertEqual(assistant_msg["content"], "I'm doing well, thank you!")
def test_anthropic_specific_messages_included_unchanged(self):
"""Test that LLMSpecificMessage objects with llm='anthropic' are included unchanged."""
def test_llm_specific_message_filtering(self):
"""Test that Anthropic-specific messages are included and others are filtered out."""
# Create anthropic-specific message content
anthropic_message_content = {
"role": "user",
@@ -599,36 +531,14 @@ class TestAnthropicGetLLMInvocationParams(unittest.TestCase):
}
messages = [
LLMSpecificMessage(llm="anthropic", message=anthropic_message_content),
{"role": "assistant", "content": "Hi there!"},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# Verify the anthropic-specific message is preserved
self.assertEqual(len(params["messages"]), 2)
anthropic_msg = params["messages"][0]
self.assertEqual(anthropic_msg["role"], "user")
self.assertIsInstance(anthropic_msg["content"], list)
self.assertEqual(len(anthropic_msg["content"]), 2)
self.assertEqual(anthropic_msg["content"][0]["type"], "text")
self.assertEqual(anthropic_msg["content"][0]["text"], "Hello")
self.assertEqual(anthropic_msg["content"][1]["type"], "image")
def test_non_anthropic_specific_messages_filtered_out(self):
"""Test that LLMSpecificMessage objects with llm != 'anthropic' are filtered out."""
messages = [
{"role": "user", "content": "Hello"},
LLMSpecificMessage(
llm="openai", message={"role": "user", "content": "OpenAI specific"}
{"role": "user", "content": "Standard message"},
OpenAILLMAdapter().create_llm_specific_message(
{"role": "user", "content": "OpenAI specific"}
),
LLMSpecificMessage(
llm="google", message={"role": "user", "content": "Google specific"}
GeminiLLMAdapter().create_llm_specific_message(
{"role": "user", "content": "Google specific"}
),
self.adapter.create_llm_specific_message(anthropic_message_content),
{"role": "assistant", "content": "Response"},
]
@@ -638,9 +548,23 @@ class TestAnthropicGetLLMInvocationParams(unittest.TestCase):
# Get invocation params
params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False)
# Should only have the 2 standard messages (openai and google specific filtered out)
# Should only have 2 messages after merging consecutive user messages: merged user + standard response
# (openai and google specific filtered out, standard + anthropic-specific merged)
self.assertEqual(len(params["messages"]), 2)
self.assertEqual(params["messages"][0]["content"], "Hello")
# First message: merged user message (standard + anthropic-specific)
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
# Should have 3 content blocks: standard text + anthropic text + anthropic image
self.assertEqual(len(user_msg["content"]), 3)
self.assertEqual(user_msg["content"][0]["type"], "text")
self.assertEqual(user_msg["content"][0]["text"], "Standard message")
self.assertEqual(user_msg["content"][1]["type"], "text")
self.assertEqual(user_msg["content"][1]["text"], "Hello")
self.assertEqual(user_msg["content"][2]["type"], "image")
# Second message: standard response
self.assertEqual(params["messages"][1]["content"], "Response")
def test_consecutive_same_role_messages_merged(self):
@@ -857,10 +781,10 @@ class TestAWSBedrockGetLLMInvocationParams(unittest.TestCase):
self.assertEqual(len(assistant_msg["content"]), 1)
self.assertEqual(assistant_msg["content"][0]["text"], "I'm doing well, thank you!")
def test_anthropic_specific_messages_included_unchanged(self):
"""Test that LLMSpecificMessage objects with llm='anthropic' are included unchanged (AWS Bedrock uses Anthropic format)."""
# Create anthropic-specific message content (which is what AWS Bedrock uses)
anthropic_message_content = {
def test_llm_specific_message_filtering(self):
"""Test that AWS-specific messages are included and others are filtered out."""
# Create aws-specific message content (which is what AWS Bedrock uses)
aws_message_content = {
"role": "user",
"content": [
{"text": "Hello"},
@@ -869,35 +793,14 @@ class TestAWSBedrockGetLLMInvocationParams(unittest.TestCase):
}
messages = [
LLMSpecificMessage(llm="anthropic", message=anthropic_message_content),
{"role": "assistant", "content": "Hi there!"},
]
# Create context
context = LLMContext(messages=messages)
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Verify the anthropic-specific message is preserved
self.assertEqual(len(params["messages"]), 2)
anthropic_msg = params["messages"][0]
self.assertEqual(anthropic_msg["role"], "user")
self.assertIsInstance(anthropic_msg["content"], list)
self.assertEqual(len(anthropic_msg["content"]), 2)
self.assertEqual(anthropic_msg["content"][0]["text"], "Hello")
self.assertIn("image", anthropic_msg["content"][1])
def test_non_anthropic_specific_messages_filtered_out(self):
"""Test that LLMSpecificMessage objects with llm != 'anthropic' are filtered out."""
messages = [
{"role": "user", "content": "Hello"},
LLMSpecificMessage(
llm="openai", message={"role": "user", "content": "OpenAI specific"}
{"role": "user", "content": "Standard message"},
OpenAILLMAdapter().create_llm_specific_message(
{"role": "user", "content": "OpenAI specific"}
),
LLMSpecificMessage(
llm="google", message={"role": "user", "content": "Google specific"}
GeminiLLMAdapter().create_llm_specific_message(
{"role": "user", "content": "Google specific"}
),
self.adapter.create_llm_specific_message(message=aws_message_content),
{"role": "assistant", "content": "Response"},
]
@@ -907,9 +810,21 @@ class TestAWSBedrockGetLLMInvocationParams(unittest.TestCase):
# Get invocation params
params = self.adapter.get_llm_invocation_params(context)
# Should only have the 2 standard messages (openai and google specific filtered out)
# Should only have 2 messages after merging consecutive user messages: merged user + standard response
# (openai and google specific filtered out, standard + aws-specific merged)
self.assertEqual(len(params["messages"]), 2)
self.assertEqual(params["messages"][0]["content"][0]["text"], "Hello")
# First message: merged user message (standard + aws-specific)
user_msg = params["messages"][0]
self.assertEqual(user_msg["role"], "user")
self.assertIsInstance(user_msg["content"], list)
# Should have 3 content blocks: standard text + aws text + aws image
self.assertEqual(len(user_msg["content"]), 3)
self.assertEqual(user_msg["content"][0]["text"], "Standard message")
self.assertEqual(user_msg["content"][1]["text"], "Hello")
self.assertIn("image", user_msg["content"][2])
# Second message: standard response
self.assertEqual(params["messages"][1]["content"][0]["text"], "Response")
def test_consecutive_same_role_messages_merged(self):