diff --git a/src/pipecat/adapters/base_llm_adapter.py b/src/pipecat/adapters/base_llm_adapter.py index 2aae514e1..86e1a5cfa 100644 --- a/src/pipecat/adapters/base_llm_adapter.py +++ b/src/pipecat/adapters/base_llm_adapter.py @@ -16,7 +16,12 @@ from typing import Any, Dict, Generic, List, TypeVar from loguru import logger from pipecat.adapters.schemas.tools_schema import ToolsSchema -from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven +from pipecat.processors.aggregators.llm_context import ( + LLMContext, + LLMContextMessage, + LLMSpecificMessage, + NotGiven, +) # Should be a TypedDict TLLMInvocationParams = TypeVar("TLLMInvocationParams", bound=dict[str, Any]) @@ -38,6 +43,16 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]): Subclasses must implement provider-specific conversion logic. """ + @property + @abstractmethod + def id_for_llm_specific_messages(self) -> str: + """Get the identifier used in LLMSpecificMessage instances for this LLM provider. + + Returns: + The identifier string. + """ + pass + @abstractmethod def get_llm_invocation_params(self, context: LLMContext, **kwargs) -> TLLMInvocationParams: """Get provider-specific LLM invocation parameters from a universal LLM context. @@ -76,6 +91,28 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]): """ pass + def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage: + """Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext. + + Args: + message: The message content. + + Returns: + A LLMSpecificMessage instance. + """ + return LLMSpecificMessage(llm=self.id_for_llm_specific_messages, message=message) + + def get_messages(self, context: LLMContext) -> List[LLMContextMessage]: + """Get messages from the LLM context, including standard and LLM-specific messages. + + Args: + context: The LLM context containing messages. + + Returns: + List of messages including standard and LLM-specific messages. + """ + return context.get_messages(self.id_for_llm_specific_messages) + def from_standard_tools(self, tools: Any) -> List[Any] | NotGiven: """Convert tools from standard format to provider format. diff --git a/src/pipecat/adapters/services/anthropic_adapter.py b/src/pipecat/adapters/services/anthropic_adapter.py index a98475016..adfe81005 100644 --- a/src/pipecat/adapters/services/anthropic_adapter.py +++ b/src/pipecat/adapters/services/anthropic_adapter.py @@ -42,6 +42,11 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]): to the specific format required by Anthropic's Claude models for function calling. """ + @property + def id_for_llm_specific_messages(self) -> str: + """Get the identifier used in LLMSpecificMessage instances for Anthropic.""" + return "anthropic" + def get_llm_invocation_params( self, context: LLMContext, enable_prompt_caching: bool ) -> AnthropicLLMInvocationParams: @@ -54,7 +59,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]): Returns: Dictionary of parameters for invoking Anthropic's LLM API. """ - messages = self._from_universal_context_messages(self._get_messages(context)) + messages = self._from_universal_context_messages(self.get_messages(context)) return { "system": messages.system, "messages": ( @@ -78,7 +83,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]): List of messages in a format ready for logging about Anthropic. """ # Get messages in Anthropic's format - messages = self._from_universal_context_messages(self._get_messages(context)).messages + messages = self._from_universal_context_messages(self.get_messages(context)).messages # Sanitize messages for logging messages_for_logging = [] @@ -92,9 +97,6 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]): messages_for_logging.append(msg) return messages_for_logging - def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]: - return context.get_messages("anthropic") - @dataclass class ConvertedMessages: """Container for Anthropic-formatted messages converted from universal context.""" diff --git a/src/pipecat/adapters/services/aws_nova_sonic_adapter.py b/src/pipecat/adapters/services/aws_nova_sonic_adapter.py index 8da38e23b..64319d266 100644 --- a/src/pipecat/adapters/services/aws_nova_sonic_adapter.py +++ b/src/pipecat/adapters/services/aws_nova_sonic_adapter.py @@ -31,6 +31,11 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]): specific function-calling format, enabling tool use with Nova Sonic models. """ + @property + def id_for_llm_specific_messages(self) -> str: + """Get the identifier used in LLMSpecificMessage instances for AWS Nova Sonic.""" + raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.") + def get_llm_invocation_params(self, context: LLMContext) -> AWSNovaSonicLLMInvocationParams: """Get AWS Nova Sonic-specific LLM invocation parameters from a universal LLM context. diff --git a/src/pipecat/adapters/services/bedrock_adapter.py b/src/pipecat/adapters/services/bedrock_adapter.py index 2e5c2c62a..681dfb3dc 100644 --- a/src/pipecat/adapters/services/bedrock_adapter.py +++ b/src/pipecat/adapters/services/bedrock_adapter.py @@ -42,6 +42,11 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]): into AWS Bedrock's expected tool format for function calling capabilities. """ + @property + def id_for_llm_specific_messages(self) -> str: + """Get the identifier used in LLMSpecificMessage instances for AWS Bedrock.""" + return "aws" + def get_llm_invocation_params(self, context: LLMContext) -> AWSBedrockLLMInvocationParams: """Get AWS Bedrock-specific LLM invocation parameters from a universal LLM context. @@ -51,7 +56,7 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]): Returns: Dictionary of parameters for invoking AWS Bedrock's LLM API. """ - messages = self._from_universal_context_messages(self._get_messages(context)) + messages = self._from_universal_context_messages(self.get_messages(context)) return { "system": messages.system, "messages": messages.messages, @@ -75,7 +80,7 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]): List of messages in a format ready for logging about AWS Bedrock. """ # Get messages in Anthropic's format - messages = self._from_universal_context_messages(self._get_messages(context)).messages + messages = self._from_universal_context_messages(self.get_messages(context)).messages # Sanitize messages for logging messages_for_logging = [] @@ -89,9 +94,6 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]): messages_for_logging.append(msg) return messages_for_logging - def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]: - return context.get_messages("anthropic") - @dataclass class ConvertedMessages: """Container for Anthropic-formatted messages converted from universal context.""" diff --git a/src/pipecat/adapters/services/gemini_adapter.py b/src/pipecat/adapters/services/gemini_adapter.py index aa54138b3..63a86e6d2 100644 --- a/src/pipecat/adapters/services/gemini_adapter.py +++ b/src/pipecat/adapters/services/gemini_adapter.py @@ -54,6 +54,11 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]): - Extracting and sanitizing messages from the LLM context for logging with Gemini. """ + @property + def id_for_llm_specific_messages(self) -> str: + """Get the identifier used in LLMSpecificMessage instances for Google.""" + return "google" + def get_llm_invocation_params(self, context: LLMContext) -> GeminiLLMInvocationParams: """Get Gemini-specific LLM invocation parameters from a universal LLM context. @@ -63,7 +68,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]): Returns: Dictionary of parameters for Gemini's API. """ - messages = self._from_universal_context_messages(self._get_messages(context)) + messages = self._from_universal_context_messages(self.get_messages(context)) return { "system_instruction": messages.system_instruction, "messages": messages.messages, @@ -103,7 +108,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]): List of messages in a format ready for logging about Gemini. """ # Get messages in Gemini's format - messages = self._from_universal_context_messages(self._get_messages(context)).messages + messages = self._from_universal_context_messages(self.get_messages(context)).messages # Sanitize messages for logging messages_for_logging = [] @@ -119,9 +124,6 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]): messages_for_logging.append(obj) return messages_for_logging - def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]: - return context.get_messages("google") - @dataclass class ConvertedMessages: """Container for Google-formatted messages converted from universal context.""" diff --git a/src/pipecat/adapters/services/open_ai_adapter.py b/src/pipecat/adapters/services/open_ai_adapter.py index 8db89eacd..3fa22cefe 100644 --- a/src/pipecat/adapters/services/open_ai_adapter.py +++ b/src/pipecat/adapters/services/open_ai_adapter.py @@ -48,6 +48,11 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]): - Extracting and sanitizing messages from the LLM context for logging about OpenAI. """ + @property + def id_for_llm_specific_messages(self) -> str: + """Get the identifier used in LLMSpecificMessage instances for OpenAI.""" + return "openai" + def get_llm_invocation_params(self, context: LLMContext) -> OpenAILLMInvocationParams: """Get OpenAI-specific LLM invocation parameters from a universal LLM context. @@ -58,7 +63,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]): Dictionary of parameters for OpenAI's ChatCompletion API. """ return { - "messages": self._from_universal_context_messages(self._get_messages(context)), + "messages": self._from_universal_context_messages(self.get_messages(context)), # NOTE; LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN) "tools": self.from_standard_tools(context.tools), "tool_choice": context.tool_choice, @@ -92,7 +97,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]): List of messages in a format ready for logging about OpenAI. """ msgs = [] - for message in self._get_messages(context): + for message in self.get_messages(context): msg = copy.deepcopy(message) if "content" in msg: if isinstance(msg["content"], list): @@ -105,9 +110,6 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]): msgs.append(msg) return msgs - def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]: - return context.get_messages("openai") - def _from_universal_context_messages( self, messages: List[LLMContextMessage] ) -> List[ChatCompletionMessageParam]: diff --git a/src/pipecat/adapters/services/open_ai_realtime_adapter.py b/src/pipecat/adapters/services/open_ai_realtime_adapter.py index d2fd831ba..2ff629e2e 100644 --- a/src/pipecat/adapters/services/open_ai_realtime_adapter.py +++ b/src/pipecat/adapters/services/open_ai_realtime_adapter.py @@ -30,6 +30,11 @@ class OpenAIRealtimeLLMAdapter(BaseLLMAdapter): OpenAI's Realtime API for function calling capabilities. """ + @property + def id_for_llm_specific_messages(self) -> str: + """Get the identifier used in LLMSpecificMessage instances for OpenAI Realtime.""" + raise NotImplementedError("Universal LLMContext is not yet supported for OpenAI Realtime.") + def get_llm_invocation_params(self, context: LLMContext) -> OpenAIRealtimeLLMInvocationParams: """Get OpenAI Realtime-specific LLM invocation parameters from a universal LLM context. diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index e309609fe..6f87b95a5 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -44,7 +44,7 @@ from pipecat.frames.frames import ( StartFrame, UserImageRequestFrame, ) -from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage from pipecat.processors.aggregators.llm_response import ( LLMAssistantAggregatorParams, LLMUserAggregatorParams, @@ -195,6 +195,17 @@ class LLMService(AIService): """ return self._adapter + def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage: + """Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext. + + Args: + message: The message content. + + Returns: + A LLMSpecificMessage instance. + """ + return self.get_llm_adapter().create_llm_specific_message(message) + async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]: """Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context. diff --git a/tests/test_get_llm_invocation_params.py b/tests/test_get_llm_invocation_params.py index 10a5ddfd5..77d73201c 100644 --- a/tests/test_get_llm_invocation_params.py +++ b/tests/test_get_llm_invocation_params.py @@ -11,37 +11,33 @@ These tests focus specifically on the "messages" field generation for different For OpenAI adapter: 1. LLMStandardMessage objects are passed through unchanged -2. LLMSpecificMessage objects with llm='openai' are included and their content extracted -3. LLMSpecificMessage objects with llm != 'openai' are filtered out -4. Complex message structures (like multi-part content) are preserved -5. System instructions are preserved throughout messages at any position +2. LLMSpecificMessage objects with llm='openai' are included and others are filtered out +3. Complex message structures (like multi-part content) are preserved +4. System instructions are preserved throughout messages at any position For Gemini adapter: 1. LLMStandardMessage objects are converted to Gemini Content format -2. LLMSpecificMessage objects with llm='google' are included unchanged -3. LLMSpecificMessage objects with llm != 'google' are filtered out -4. Complex message structures (image, audio, multi-text) are converted to appropriate Gemini format -5. System messages are extracted as system_instruction (without duplication) -6. Single system instruction is converted to user message when no other messages exist -7. Multiple system instructions: first extracted, later ones converted to user messages +2. LLMSpecificMessage objects with llm='google' are included and others are filtered out +3. Complex message structures (image, audio, multi-text) are converted to appropriate Gemini format +4. System messages are extracted as system_instruction (without duplication) +5. Single system instruction is converted to user message when no other messages exist +6. Multiple system instructions: first extracted, later ones converted to user messages For Anthropic adapter: 1. LLMStandardMessage objects are converted to Anthropic MessageParam format -2. LLMSpecificMessage objects with llm='anthropic' are included unchanged -3. LLMSpecificMessage objects with llm != 'anthropic' are filtered out -4. Complex message structures (image, multi-text) are converted to appropriate Anthropic format -5. System messages: first extracted as system parameter, later ones converted to user messages -6. Consecutive messages with same role are merged into multi-content-block messages -7. Empty text content is converted to "(empty)" +2. LLMSpecificMessage objects with llm='anthropic' are included and others are filtered out +3. Complex message structures (image, multi-text) are converted to appropriate Anthropic format +4. System messages: first extracted as system parameter, later ones converted to user messages +5. Consecutive messages with same role are merged into multi-content-block messages +6. Empty text content is converted to "(empty)" For AWS Bedrock adapter: 1. LLMStandardMessage objects are converted to AWS Bedrock format -2. LLMSpecificMessage objects with llm='anthropic' are included unchanged (uses Anthropic format) -3. LLMSpecificMessage objects with llm != 'anthropic' are filtered out -4. Complex message structures (image, multi-text) are converted to appropriate AWS Bedrock format -5. System messages: first extracted as system parameter, later ones converted to user messages -6. Consecutive messages with same role are merged into multi-content-block messages -7. Empty text content is converted to "(empty)" +2. LLMSpecificMessage objects with llm='aws' are included and others are filtered out +3. Complex message structures (image, multi-text) are converted to appropriate AWS Bedrock format +4. System messages: first extracted as system parameter, later ones converted to user messages +5. Consecutive messages with same role are merged into multi-content-block messages +6. Empty text content is converted to "(empty)" """ import unittest @@ -89,51 +85,20 @@ class TestOpenAIGetLLMInvocationParams(unittest.TestCase): self.assertEqual(params["messages"][1]["content"], "Hello, how are you?") self.assertEqual(params["messages"][2]["content"], "I'm doing well, thank you for asking!") - def test_openai_specific_messages_included(self): - """Test that LLMSpecificMessage objects with llm='openai' are included.""" - # Create a mix of standard and OpenAI-specific messages - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - LLMSpecificMessage( - llm="openai", message={"role": "user", "content": "OpenAI specific message"} - ), - {"role": "assistant", "content": "Standard response"}, - ] - - # Create context with these messages - context = LLMContext(messages=messages) - - # Get invocation params - params = self.adapter.get_llm_invocation_params(context) - - # Verify all messages are included (OpenAI-specific should be included) - self.assertEqual(len(params["messages"]), 3) - - # First message should be standard - self.assertEqual(params["messages"][0]["content"], "You are a helpful assistant.") - - # Second message should be the OpenAI-specific one - self.assertEqual( - params["messages"][1], {"role": "user", "content": "OpenAI specific message"} - ) - - # Third message should be standard - self.assertEqual(params["messages"][2]["content"], "Standard response") - - def test_non_openai_specific_messages_filtered_out(self): - """Test that LLMSpecificMessage objects with llm != 'openai' are filtered out.""" + def test_llm_specific_message_filtering(self): + """Test that OpenAI-specific messages are included and others are filtered out.""" # Create messages with different LLM-specific ones messages = [ {"role": "system", "content": "You are a helpful assistant."}, - LLMSpecificMessage( - llm="anthropic", message={"role": "user", "content": "Anthropic specific message"} + AnthropicLLMAdapter().create_llm_specific_message( + {"role": "user", "content": "Anthropic specific message"} ), - LLMSpecificMessage( - llm="gemini", message={"role": "user", "content": "Gemini specific message"} + GeminiLLMAdapter().create_llm_specific_message( + {"role": "user", "content": "Gemini specific message"} ), {"role": "user", "content": "Standard user message"}, - LLMSpecificMessage( - llm="openai", message={"role": "assistant", "content": "OpenAI specific response"} + self.adapter.create_llm_specific_message( + {"role": "assistant", "content": "OpenAI specific response"} ), ] @@ -291,53 +256,20 @@ class TestGeminiGetLLMInvocationParams(unittest.TestCase): self.assertEqual(len(model_msg.parts), 1) self.assertEqual(model_msg.parts[0].text, "I'm doing well, thank you for asking!") - def test_gemini_specific_messages_included(self): - """Test that LLMSpecificMessage objects with llm='google' are included unchanged.""" - # Create a Gemini-specific message - gemini_message = Content(role="user", parts=[Part(text="Gemini specific message")]) - - # Create a mix of standard and Gemini-specific messages - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - LLMSpecificMessage(llm="google", message=gemini_message), - {"role": "assistant", "content": "Standard response"}, - ] - - # Create context with these messages - context = LLMContext(messages=messages) - - # Get invocation params - params = self.adapter.get_llm_invocation_params(context) - - # Verify system instruction - self.assertEqual(params["system_instruction"], "You are a helpful assistant.") - - # Verify messages (2 total: gemini-specific user + converted model) - self.assertEqual(len(params["messages"]), 2) - - # First message should be the Gemini-specific one (unchanged) - self.assertEqual(params["messages"][0], gemini_message) - self.assertEqual(params["messages"][0].parts[0].text, "Gemini specific message") - - # Second message should be converted standard message - self.assertEqual(params["messages"][1].role, "model") - self.assertEqual(params["messages"][1].parts[0].text, "Standard response") - - def test_non_gemini_specific_messages_filtered_out(self): - """Test that LLMSpecificMessage objects with llm != 'google' are filtered out.""" + def test_llm_specific_message_filtering(self): + """Test that Gemini-specific messages are included and others are filtered out.""" # Create messages with different LLM-specific ones messages = [ {"role": "system", "content": "You are a helpful assistant."}, - LLMSpecificMessage( - llm="openai", message={"role": "user", "content": "OpenAI specific message"} + OpenAILLMAdapter().create_llm_specific_message( + {"role": "user", "content": "OpenAI specific message"} ), - LLMSpecificMessage( - llm="anthropic", message={"role": "user", "content": "Anthropic specific message"} + AnthropicLLMAdapter().create_llm_specific_message( + {"role": "user", "content": "Anthropic specific message"} ), {"role": "user", "content": "Standard user message"}, - LLMSpecificMessage( - llm="google", - message=Content(role="model", parts=[Part(text="Gemini specific response")]), + self.adapter.create_llm_specific_message( + Content(role="model", parts=[Part(text="Gemini specific response")]), ), ] @@ -584,8 +516,8 @@ class TestAnthropicGetLLMInvocationParams(unittest.TestCase): self.assertEqual(assistant_msg["role"], "assistant") self.assertEqual(assistant_msg["content"], "I'm doing well, thank you!") - def test_anthropic_specific_messages_included_unchanged(self): - """Test that LLMSpecificMessage objects with llm='anthropic' are included unchanged.""" + def test_llm_specific_message_filtering(self): + """Test that Anthropic-specific messages are included and others are filtered out.""" # Create anthropic-specific message content anthropic_message_content = { "role": "user", @@ -599,36 +531,14 @@ class TestAnthropicGetLLMInvocationParams(unittest.TestCase): } messages = [ - LLMSpecificMessage(llm="anthropic", message=anthropic_message_content), - {"role": "assistant", "content": "Hi there!"}, - ] - - # Create context - context = LLMContext(messages=messages) - - # Get invocation params - params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False) - - # Verify the anthropic-specific message is preserved - self.assertEqual(len(params["messages"]), 2) - anthropic_msg = params["messages"][0] - self.assertEqual(anthropic_msg["role"], "user") - self.assertIsInstance(anthropic_msg["content"], list) - self.assertEqual(len(anthropic_msg["content"]), 2) - self.assertEqual(anthropic_msg["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["text"], "Hello") - self.assertEqual(anthropic_msg["content"][1]["type"], "image") - - def test_non_anthropic_specific_messages_filtered_out(self): - """Test that LLMSpecificMessage objects with llm != 'anthropic' are filtered out.""" - messages = [ - {"role": "user", "content": "Hello"}, - LLMSpecificMessage( - llm="openai", message={"role": "user", "content": "OpenAI specific"} + {"role": "user", "content": "Standard message"}, + OpenAILLMAdapter().create_llm_specific_message( + {"role": "user", "content": "OpenAI specific"} ), - LLMSpecificMessage( - llm="google", message={"role": "user", "content": "Google specific"} + GeminiLLMAdapter().create_llm_specific_message( + {"role": "user", "content": "Google specific"} ), + self.adapter.create_llm_specific_message(anthropic_message_content), {"role": "assistant", "content": "Response"}, ] @@ -638,9 +548,23 @@ class TestAnthropicGetLLMInvocationParams(unittest.TestCase): # Get invocation params params = self.adapter.get_llm_invocation_params(context, enable_prompt_caching=False) - # Should only have the 2 standard messages (openai and google specific filtered out) + # Should only have 2 messages after merging consecutive user messages: merged user + standard response + # (openai and google specific filtered out, standard + anthropic-specific merged) self.assertEqual(len(params["messages"]), 2) - self.assertEqual(params["messages"][0]["content"], "Hello") + + # First message: merged user message (standard + anthropic-specific) + user_msg = params["messages"][0] + self.assertEqual(user_msg["role"], "user") + self.assertIsInstance(user_msg["content"], list) + # Should have 3 content blocks: standard text + anthropic text + anthropic image + self.assertEqual(len(user_msg["content"]), 3) + self.assertEqual(user_msg["content"][0]["type"], "text") + self.assertEqual(user_msg["content"][0]["text"], "Standard message") + self.assertEqual(user_msg["content"][1]["type"], "text") + self.assertEqual(user_msg["content"][1]["text"], "Hello") + self.assertEqual(user_msg["content"][2]["type"], "image") + + # Second message: standard response self.assertEqual(params["messages"][1]["content"], "Response") def test_consecutive_same_role_messages_merged(self): @@ -857,10 +781,10 @@ class TestAWSBedrockGetLLMInvocationParams(unittest.TestCase): self.assertEqual(len(assistant_msg["content"]), 1) self.assertEqual(assistant_msg["content"][0]["text"], "I'm doing well, thank you!") - def test_anthropic_specific_messages_included_unchanged(self): - """Test that LLMSpecificMessage objects with llm='anthropic' are included unchanged (AWS Bedrock uses Anthropic format).""" - # Create anthropic-specific message content (which is what AWS Bedrock uses) - anthropic_message_content = { + def test_llm_specific_message_filtering(self): + """Test that AWS-specific messages are included and others are filtered out.""" + # Create aws-specific message content (which is what AWS Bedrock uses) + aws_message_content = { "role": "user", "content": [ {"text": "Hello"}, @@ -869,35 +793,14 @@ class TestAWSBedrockGetLLMInvocationParams(unittest.TestCase): } messages = [ - LLMSpecificMessage(llm="anthropic", message=anthropic_message_content), - {"role": "assistant", "content": "Hi there!"}, - ] - - # Create context - context = LLMContext(messages=messages) - - # Get invocation params - params = self.adapter.get_llm_invocation_params(context) - - # Verify the anthropic-specific message is preserved - self.assertEqual(len(params["messages"]), 2) - anthropic_msg = params["messages"][0] - self.assertEqual(anthropic_msg["role"], "user") - self.assertIsInstance(anthropic_msg["content"], list) - self.assertEqual(len(anthropic_msg["content"]), 2) - self.assertEqual(anthropic_msg["content"][0]["text"], "Hello") - self.assertIn("image", anthropic_msg["content"][1]) - - def test_non_anthropic_specific_messages_filtered_out(self): - """Test that LLMSpecificMessage objects with llm != 'anthropic' are filtered out.""" - messages = [ - {"role": "user", "content": "Hello"}, - LLMSpecificMessage( - llm="openai", message={"role": "user", "content": "OpenAI specific"} + {"role": "user", "content": "Standard message"}, + OpenAILLMAdapter().create_llm_specific_message( + {"role": "user", "content": "OpenAI specific"} ), - LLMSpecificMessage( - llm="google", message={"role": "user", "content": "Google specific"} + GeminiLLMAdapter().create_llm_specific_message( + {"role": "user", "content": "Google specific"} ), + self.adapter.create_llm_specific_message(message=aws_message_content), {"role": "assistant", "content": "Response"}, ] @@ -907,9 +810,21 @@ class TestAWSBedrockGetLLMInvocationParams(unittest.TestCase): # Get invocation params params = self.adapter.get_llm_invocation_params(context) - # Should only have the 2 standard messages (openai and google specific filtered out) + # Should only have 2 messages after merging consecutive user messages: merged user + standard response + # (openai and google specific filtered out, standard + aws-specific merged) self.assertEqual(len(params["messages"]), 2) - self.assertEqual(params["messages"][0]["content"][0]["text"], "Hello") + + # First message: merged user message (standard + aws-specific) + user_msg = params["messages"][0] + self.assertEqual(user_msg["role"], "user") + self.assertIsInstance(user_msg["content"], list) + # Should have 3 content blocks: standard text + aws text + aws image + self.assertEqual(len(user_msg["content"]), 3) + self.assertEqual(user_msg["content"][0]["text"], "Standard message") + self.assertEqual(user_msg["content"][1]["text"], "Hello") + self.assertIn("image", user_msg["content"][2]) + + # Second message: standard response self.assertEqual(params["messages"][1]["content"][0]["text"], "Response") def test_consecutive_same_role_messages_merged(self):