From 4d9873b613987e5bd862a808386bc8bb2a256b58 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 25 Sep 2025 09:46:33 -0400 Subject: [PATCH] Update `AWSNovaSonicLLMService` to work with `LLMContext` and `LLMContextAggregatorPair` --- CHANGELOG.md | 12 + .../20a-persistent-context-openai.py | 4 +- .../20c-persistent-context-anthropic.py | 4 +- .../20d-persistent-context-gemini.py | 5 +- .../20e-persistent-context-aws-nova-sonic.py | 6 +- examples/foundational/40-aws-nova-sonic.py | 9 +- .../services/aws_nova_sonic_adapter.py | 122 +++++- .../processors/aggregators/llm_context.py | 41 +- src/pipecat/services/aws_nova_sonic/aws.py | 269 ++++++++----- .../services/aws_nova_sonic/context.py | 367 ------------------ tests/test_llm_context.py | 208 ++++++++++ 11 files changed, 571 insertions(+), 476 deletions(-) delete mode 100644 src/pipecat/services/aws_nova_sonic/context.py create mode 100644 tests/test_llm_context.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c928028f..16c7710c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Expanded support for universal `LLMContext` to `AWSNovaSonicLLMService`. + As a reminder, the context-setup pattern when using `LLMContext` is: + + ```python + context = LLMContext(messages, tools) + context_aggregator = LLMContextAggregatorPair(context) + ``` + + (Note that even though `AWSNovaSonicLLMService` now supports the universal + `LLMContext`, it is not meant to be swapped out for another LLM service at + runtime.) + - Include OpenAI-based LLM services cached tokens to `MetricsFrame`. ## Fixed diff --git a/examples/foundational/20a-persistent-context-openai.py b/examples/foundational/20a-persistent-context-openai.py index 93c1fa438..a3f9f005a 100644 --- a/examples/foundational/20a-persistent-context-openai.py +++ b/examples/foundational/20a-persistent-context-openai.py @@ -67,11 +67,11 @@ async def save_conversation(params: FunctionCallParams): timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") filename = f"{BASE_FILENAME}{timestamp}.json" logger.debug( - f"writing conversation to {filename}\n{json.dumps(params.context.get_messages(), indent=4)}" + f"writing conversation to {filename}\n{json.dumps(params.context.get_messages_for_persistent_storage(), indent=4)}" ) try: with open(filename, "w") as file: - messages = params.context.get_messages() + messages = params.context.get_messages_for_persistent_storage() # remove the last message, which is the instruction we just gave to save the conversation messages.pop() json.dump(messages, file, indent=2) diff --git a/examples/foundational/20c-persistent-context-anthropic.py b/examples/foundational/20c-persistent-context-anthropic.py index 411a976b8..7a0953290 100644 --- a/examples/foundational/20c-persistent-context-anthropic.py +++ b/examples/foundational/20c-persistent-context-anthropic.py @@ -68,12 +68,12 @@ async def save_conversation(params: FunctionCallParams): timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") filename = f"{BASE_FILENAME}{timestamp}.json" logger.debug( - f"writing conversation to {filename}\n{json.dumps(params.context.get_messages(), indent=4)}" + f"writing conversation to {filename}\n{json.dumps(params.context.get_messages_for_persistent_storage(), indent=4)}" ) try: with open(filename, "w") as file: # todo: extract 'system' into the first message in the list - messages = params.context.get_messages() + messages = params.context.get_messages_for_persistent_storage() # remove the last message, which is the instruction we just gave to save the conversation messages.pop() json.dump(messages, file, indent=2) diff --git a/examples/foundational/20d-persistent-context-gemini.py b/examples/foundational/20d-persistent-context-gemini.py index 8dad8148d..0e52c0308 100644 --- a/examples/foundational/20d-persistent-context-gemini.py +++ b/examples/foundational/20d-persistent-context-gemini.py @@ -86,12 +86,11 @@ async def save_conversation(params: FunctionCallParams): timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") filename = f"{BASE_FILENAME}{timestamp}.json" logger.debug( - f"writing conversation to {filename}\n{json.dumps(params.context.get_messages(), indent=4)}" + f"writing conversation to {filename}\n{json.dumps(params.context.get_messages_for_persistent_storage(), indent=4)}" ) try: with open(filename, "w") as file: - # todo: extract 'system' into the first message in the list - messages = params.context.get_messages() + messages = params.context.get_messages_for_persistent_storage() # remove the last message (the instruction to save the context) messages.pop() json.dump(messages, file, indent=2) diff --git a/examples/foundational/20e-persistent-context-aws-nova-sonic.py b/examples/foundational/20e-persistent-context-aws-nova-sonic.py index bd3d9d545..9f364ef6c 100644 --- a/examples/foundational/20e-persistent-context-aws-nova-sonic.py +++ b/examples/foundational/20e-persistent-context-aws-nova-sonic.py @@ -20,6 +20,8 @@ from pipecat.frames.frames import LLMRunFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.runner.types import RunnerArguments from pipecat.runner.utils import create_transport @@ -223,13 +225,13 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames) llm.register_function("load_conversation", load_conversation) - context = OpenAILLMContext( + context = LLMContext( messages=[ {"role": "system", "content": f"{system_instruction}"}, ], tools=tools, ) - context_aggregator = llm.create_context_aggregator(context) + context_aggregator = LLMContextAggregatorPair(context) pipeline = Pipeline( [ diff --git a/examples/foundational/40-aws-nova-sonic.py b/examples/foundational/40-aws-nova-sonic.py index de7bbf638..c2aea298b 100644 --- a/examples/foundational/40-aws-nova-sonic.py +++ b/examples/foundational/40-aws-nova-sonic.py @@ -18,7 +18,8 @@ from pipecat.frames.frames import LLMRunFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair from pipecat.runner.types import RunnerArguments from pipecat.runner.utils import create_transport from pipecat.services.aws_nova_sonic import AWSNovaSonicLLMService @@ -119,9 +120,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): llm.register_function("get_current_weather", fetch_weather_from_api) # Set up context and context management. - # AWSNovaSonicService will adapt OpenAI LLM context objects with standard message format to - # what's expected by Nova Sonic. - context = OpenAILLMContext( + context = LLMContext( messages=[ {"role": "system", "content": f"{system_instruction}"}, { @@ -131,7 +130,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): ], tools=tools, ) - context_aggregator = llm.create_context_aggregator(context) + context_aggregator = LLMContextAggregatorPair(context) # Build the pipeline pipeline = Pipeline( diff --git a/src/pipecat/adapters/services/aws_nova_sonic_adapter.py b/src/pipecat/adapters/services/aws_nova_sonic_adapter.py index 64319d266..60f12798b 100644 --- a/src/pipecat/adapters/services/aws_nova_sonic_adapter.py +++ b/src/pipecat/adapters/services/aws_nova_sonic_adapter.py @@ -6,13 +6,47 @@ """AWS Nova Sonic LLM adapter for Pipecat.""" +import copy import json -from typing import Any, Dict, List, TypedDict +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional, TypedDict + +from loguru import logger from pipecat.adapters.base_llm_adapter import BaseLLMAdapter from pipecat.adapters.schemas.function_schema import FunctionSchema from pipecat.adapters.schemas.tools_schema import ToolsSchema -from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage + + +class Role(Enum): + """Roles supported in AWS Nova Sonic conversations. + + Parameters: + SYSTEM: System-level messages (not used in conversation history). + USER: Messages sent by the user. + ASSISTANT: Messages sent by the assistant. + TOOL: Messages sent by tools (not used in conversation history). + """ + + SYSTEM = "SYSTEM" + USER = "USER" + ASSISTANT = "ASSISTANT" + TOOL = "TOOL" + + +@dataclass +class AWSNovaSonicConversationHistoryMessage: + """A single message in AWS Nova Sonic conversation history. + + Parameters: + role: The role of the message sender (USER or ASSISTANT only). + text: The text content of the message. + """ + + role: Role # only USER and ASSISTANT + text: str class AWSNovaSonicLLMInvocationParams(TypedDict): @@ -21,7 +55,9 @@ class AWSNovaSonicLLMInvocationParams(TypedDict): This is a placeholder until support for universal LLMContext machinery is added for AWS Nova Sonic. """ - pass + system_instruction: Optional[str] + messages: List[AWSNovaSonicConversationHistoryMessage] + tools: List[Dict[str, Any]] class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]): @@ -34,7 +70,7 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]): @property def id_for_llm_specific_messages(self) -> str: """Get the identifier used in LLMSpecificMessage instances for AWS Nova Sonic.""" - raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.") + return "aws-nova-sonic" def get_llm_invocation_params(self, context: LLMContext) -> AWSNovaSonicLLMInvocationParams: """Get AWS Nova Sonic-specific LLM invocation parameters from a universal LLM context. @@ -47,7 +83,13 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]): Returns: Dictionary of parameters for invoking AWS Nova Sonic's LLM API. """ - raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.") + messages = self._from_universal_context_messages(self.get_messages(context)) + return { + "system_instruction": messages.system_instruction, + "messages": messages.messages, + # NOTE: LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN) + "tools": self.from_standard_tools(context.tools) or [], + } def get_messages_for_logging(self, context) -> List[Dict[str, Any]]: """Get messages from a universal LLM context in a format ready for logging about AWS Nova Sonic. @@ -62,7 +104,75 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]): Returns: List of messages in a format ready for logging about AWS Nova Sonic. """ - raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.") + return self._from_universal_context_messages(self.get_messages(context)).messages + + @dataclass + class ConvertedMessages: + """Container for Google-formatted messages converted from universal context.""" + + messages: List[AWSNovaSonicConversationHistoryMessage] + system_instruction: Optional[str] = None + + def _from_universal_context_messages( + self, universal_context_messages: List[LLMContextMessage] + ) -> ConvertedMessages: + system_instruction = None + messages = [] + + # Bail if there are no messages + if not universal_context_messages: + return self.ConvertedMessages() + + universal_context_messages = copy.deepcopy(universal_context_messages) + + # If we have a "system" message as our first message, let's pull that out into "instruction" + if universal_context_messages[0].get("role") == "system": + system = universal_context_messages.pop(0) + content = system.get("content") + if isinstance(content, str): + system_instruction = content + elif isinstance(content, list): + system_instruction = content[0].get("text") + if system_instruction: + self._system_instruction = system_instruction + + # Process remaining messages to fill out conversation history. + # Nova Sonic supports "user" and "assistant" messages in history. + for universal_context_message in universal_context_messages: + message = self._from_universal_context_message(universal_context_message) + if message: + messages.append(message) + + return self.ConvertedMessages(messages=messages, system_instruction=system_instruction) + + def _from_universal_context_message(self, message) -> AWSNovaSonicConversationHistoryMessage: + """Convert standard message format to Nova Sonic format. + + Args: + message: Standard message dictionary to convert. + + Returns: + Nova Sonic conversation history message, or None if not convertible. + """ + role = message.get("role") + if message.get("role") == "user" or message.get("role") == "assistant": + content = message.get("content") + if isinstance(message.get("content"), list): + content = "" + for c in message.get("content"): + if c.get("type") == "text": + content += " " + c.get("text") + else: + logger.error( + f"Unhandled content type in context message: {c.get('type')} - {message}" + ) + # There won't be content if this is an assistant tool call entry. + # We're ignoring those since they can't be loaded into AWS Nova Sonic conversation + # history + if content: + return AWSNovaSonicConversationHistoryMessage(role=Role[role.upper()], text=content) + # NOTE: we're ignoring messages with role "tool" since they can't be loaded into AWS Nova + # Sonic conversation history @staticmethod def _to_aws_nova_sonic_function_format(function: FunctionSchema) -> Dict[str, Any]: diff --git a/src/pipecat/processors/aggregators/llm_context.py b/src/pipecat/processors/aggregators/llm_context.py index 8b677cf02..e98bc743e 100644 --- a/src/pipecat/processors/aggregators/llm_context.py +++ b/src/pipecat/processors/aggregators/llm_context.py @@ -15,9 +15,10 @@ service-specific adapter. """ import base64 +import copy import io from dataclasses import dataclass -from typing import Any, List, Optional, TypeAlias, Union +from typing import TYPE_CHECKING, Any, List, Optional, TypeAlias, Union from loguru import logger from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN @@ -31,6 +32,9 @@ from PIL import Image from pipecat.adapters.schemas.tools_schema import ToolsSchema from pipecat.frames.frames import AudioRawFrame +if TYPE_CHECKING: + from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext + # "Re-export" types from OpenAI that we're using as universal context types. # NOTE: if universal message types need to someday diverge from OpenAI's, we # should consider managing our own definitions. But we should do so carefully, @@ -65,6 +69,26 @@ class LLMContext: and content formatting. """ + @staticmethod + def from_openai_context(openai_context: "OpenAILLMContext") -> "LLMContext": + """Create a universal LLM context from an OpenAI-specific context. + + NOTE: this should only be used internally, for facilitating migration + from OpenAILLMContext to LLMContext. New user code should use + LLMContext directly. + + Args: + openai_context: The OpenAI LLM context to convert. + + Returns: + New LLMContext instance with converted messages and settings. + """ + return LLMContext( + messages=openai_context.get_messages(), + tools=openai_context.tools, + tool_choice=openai_context.tool_choice, + ) + def __init__( self, messages: Optional[List[LLMContextMessage]] = None, @@ -107,6 +131,21 @@ class LLMContext: ) return filtered_messages + def get_messages_for_persistent_storage( + self, system_instruction: Optional[str] = None + ) -> List[LLMContextMessage]: + """Get messages formatted for persistent storage. + + Args: + system_instruction: Optional system instruction to ensure is + included as the first message in the returned list, if not + already present. + """ + messages = copy.deepcopy(self.get_messages()) + if system_instruction and (not messages or messages[0].get("role") != "system"): + messages.insert(0, {"role": "system", "content": system_instruction}) + return messages + @property def tools(self) -> ToolsSchema | NotGiven: """Get the tools list. diff --git a/src/pipecat/services/aws_nova_sonic/aws.py b/src/pipecat/services/aws_nova_sonic/aws.py index 8c5a23f3c..8a84684cf 100644 --- a/src/pipecat/services/aws_nova_sonic/aws.py +++ b/src/pipecat/services/aws_nova_sonic/aws.py @@ -25,7 +25,7 @@ from loguru import logger from pydantic import BaseModel, Field from pipecat.adapters.schemas.tools_schema import ToolsSchema -from pipecat.adapters.services.aws_nova_sonic_adapter import AWSNovaSonicLLMAdapter +from pipecat.adapters.services.aws_nova_sonic_adapter import AWSNovaSonicLLMAdapter, Role from pipecat.frames.frames import ( BotStoppedSpeakingFrame, CancelFrame, @@ -33,36 +33,36 @@ from pipecat.frames.frames import ( Frame, FunctionCallFromLLM, InputAudioRawFrame, - InterimTranscriptionFrame, + InterruptionFrame, LLMContextFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, - LLMTextFrame, StartFrame, TranscriptionFrame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame, TTSTextFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, ) +from pipecat.processors.aggregators.llm_context import LLMContext from pipecat.processors.aggregators.llm_response import ( LLMAssistantAggregatorParams, LLMUserAggregatorParams, ) +from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.aws_nova_sonic.context import ( - AWSNovaSonicAssistantContextAggregator, - AWSNovaSonicContextAggregatorPair, - AWSNovaSonicLLMContext, - AWSNovaSonicUserContextAggregator, - Role, -) -from pipecat.services.aws_nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame from pipecat.services.llm_service import LLMService +from pipecat.services.openai.llm import ( + OpenAIAssistantContextAggregator, + OpenAIContextAggregatorPair, + OpenAIUserContextAggregator, +) from pipecat.utils.time import time_now_iso8601 try: @@ -231,7 +231,7 @@ class AWSNovaSonicLLMService(LLMService): self._system_instruction = system_instruction self._tools = tools self._send_transcription_frames = send_transcription_frames - self._context: Optional[AWSNovaSonicLLMContext] = None + self._context: Optional[LLMContext] = None self._stream: Optional[ DuplexEventStream[ InvokeModelWithBidirectionalStreamInput, @@ -244,12 +244,17 @@ class AWSNovaSonicLLMService(LLMService): self._input_audio_content_name: Optional[str] = None self._content_being_received: Optional[CurrentContent] = None self._assistant_is_responding = False + self._needs_re_push_assistant_text = False self._ready_to_send_context = False self._handling_bot_stopped_speaking = False self._triggering_assistant_response = False + self._waiting_for_trigger_transcription = False self._disconnecting = False self._connected_time: Optional[float] = None self._wants_connection = False + self._user_text_buffer = "" + self._assistant_text_buffer = "" + self._completed_tool_calls = set() file_path = files("pipecat.services.aws_nova_sonic").joinpath("ready.wav") with wave.open(file_path.open("rb"), "rb") as wav_file: @@ -302,12 +307,12 @@ class AWSNovaSonicLLMService(LLMService): logger.debug("Resetting conversation") await self._handle_bot_stopped_speaking(delay_to_catch_trailing_assistant_text=False) - # Carry over previous context through disconnect + # Grab context to carry through disconnect/reconnect context = self._context - await self._disconnect() - self._context = context + await self._disconnect() await self._start_connecting() + await self._handle_context(context) # # frame processing @@ -322,28 +327,35 @@ class AWSNovaSonicLLMService(LLMService): """ await super().process_frame(frame, direction) - if isinstance(frame, OpenAILLMContextFrame): - await self._handle_context(frame.context) - elif isinstance(frame, LLMContextFrame): - raise NotImplementedError( - "Universal LLMContext is not yet supported for AWS Nova Sonic." + if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)): + context = ( + frame.context + if isinstance(frame, LLMContextFrame) + else LLMContext.from_openai_context(frame.context) ) + await self._handle_context(context) elif isinstance(frame, InputAudioRawFrame): await self._handle_input_audio_frame(frame) elif isinstance(frame, BotStoppedSpeakingFrame): await self._handle_bot_stopped_speaking(delay_to_catch_trailing_assistant_text=True) - elif isinstance(frame, AWSNovaSonicFunctionCallResultFrame): - await self._handle_function_call_result(frame) + elif isinstance(frame, InterruptionFrame): + await self._handle_interruption_frame() await self.push_frame(frame, direction) - async def _handle_context(self, context: OpenAILLMContext): + async def _handle_context(self, context: LLMContext): + if self._disconnecting: + return + if not self._context: - # We got our initial context - try to finish connecting - self._context = AWSNovaSonicLLMContext.upgrade_to_nova_sonic( - context, self._system_instruction - ) + # We got our initial context + # Try to finish connecting + self._context = context await self._finish_connecting_if_context_available() + else: + # We got an updated context + # Send results for any newly-completed function calls + await self._process_completed_function_calls(send_new_results=True) async def _handle_input_audio_frame(self, frame: InputAudioRawFrame): # Wait until we're done sending the assistant response trigger audio before sending audio @@ -393,9 +405,9 @@ class AWSNovaSonicLLMService(LLMService): else: await finalize_assistant_response() - async def _handle_function_call_result(self, frame: AWSNovaSonicFunctionCallResultFrame): - result = frame.result_frame - await self._send_tool_result(tool_call_id=result.tool_call_id, result=result.result) + async def _handle_interruption_frame(self): + if self._assistant_is_responding: + self._needs_re_push_assistant_text = True # # LLM communication: lifecycle @@ -431,6 +443,17 @@ class AWSNovaSonicLLMService(LLMService): logger.error(f"{self} initialization error: {e}") self._disconnect() + async def _process_completed_function_calls(self, send_new_results: bool): + # Check for set of completed function calls in the context + for message in self._context.get_messages(): + if message.get("role") and message.get("content") != "IN_PROGRESS": + tool_call_id = message.get("tool_call_id") + if tool_call_id and tool_call_id not in self._completed_tool_calls: + # Found a newly-completed function call - send the result to the service + if send_new_results: + await self._send_tool_result(tool_call_id, message.get("content")) + self._completed_tool_calls.add(tool_call_id) + async def _finish_connecting_if_context_available(self): # We can only finish connecting once we've gotten our initial context and we're ready to # send it @@ -439,30 +462,38 @@ class AWSNovaSonicLLMService(LLMService): logger.info("Finishing connecting (setting up session)...") + # Initialize our bookkeeping of already-completed tool calls in the + # context + await self._process_completed_function_calls(send_new_results=False) + # Read context - history = self._context.get_messages_for_initializing_history() + adapter: AWSNovaSonicLLMAdapter = self.get_llm_adapter() + llm_connection_params = adapter.get_llm_invocation_params(self._context) # Send prompt start event, specifying tools. # Tools from context take priority over self._tools. tools = ( - self._context.tools - if self._context.tools - else self.get_llm_adapter().from_standard_tools(self._tools) + llm_connection_params["tools"] + if llm_connection_params["tools"] + else adapter.from_standard_tools(self._tools) ) logger.debug(f"Using tools: {tools}") await self._send_prompt_start_event(tools) # Send system instruction. # Instruction from context takes priority over self._system_instruction. - # (NOTE: this prioritizing occurred automatically behind the scenes: the context was - # initialized with self._system_instruction and then updated itself from its messages when - # get_messages_for_initializing_history() was called). - logger.debug(f"Using system instruction: {history.system_instruction}") - if history.system_instruction: - await self._send_text_event(text=history.system_instruction, role=Role.SYSTEM) + system_instruction = ( + llm_connection_params["system_instruction"] + if llm_connection_params["system_instruction"] + else self._system_instruction + ) + logger.debug(f"Using system instruction: {system_instruction}") + if system_instruction: + await self._send_text_event(text=system_instruction, role=Role.SYSTEM) # Send conversation history - for message in history.messages: + for message in llm_connection_params["messages"]: + # logger.debug(f"Seeding conversation history with message: {message}") await self._send_text_event(text=message.text, role=message.role) # Start audio input @@ -492,9 +523,12 @@ class AWSNovaSonicLLMService(LLMService): await self._send_session_end_events() self._client = None + # Clean up context + self._context = None + # Clean up stream if self._stream: - await self._stream.input_stream.close() + await self._stream.close() self._stream = None # NOTE: see explanation of HACK, below @@ -510,15 +544,23 @@ class AWSNovaSonicLLMService(LLMService): self._receive_task = None # Reset remaining connection-specific state + # Should be all private state except: + # - _wants_connection + # - _assistant_response_trigger_audio self._prompt_name = None self._input_audio_content_name = None self._content_being_received = None self._assistant_is_responding = False + self._needs_re_push_assistant_text = False self._ready_to_send_context = False self._handling_bot_stopped_speaking = False self._triggering_assistant_response = False + self._waiting_for_trigger_transcription = False self._disconnecting = False self._connected_time = None + self._user_text_buffer = "" + self._assistant_text_buffer = "" + self._completed_tool_calls = set() logger.info("Finished disconnecting") except Exception as e: @@ -830,6 +872,10 @@ class AWSNovaSonicLLMService(LLMService): # Handle the LLM completion ending await self._handle_completion_end_event(event_json) except Exception as e: + if self._disconnecting: + # Errors are kind of expected while disconnecting, so just + # ignore them and do nothing + return logger.error(f"{self} error processing responses: {e}") if self._wants_connection: await self.reset_conversation() @@ -960,7 +1006,7 @@ class AWSNovaSonicLLMService(LLMService): async def _report_assistant_response_started(self): logger.debug("Assistant response started") - # Report that the assistant has started their response. + # Report the start of the assistant response. await self.push_frame(LLMFullResponseStartFrame()) # Report that equivalent of TTS (this is a speech-to-speech model) started @@ -972,23 +1018,16 @@ class AWSNovaSonicLLMService(LLMService): logger.debug(f"Assistant response text added: {text}") - # Report some text added to the ongoing assistant response - await self.push_frame(LLMTextFrame(text)) - - # Report some text added to the *equivalent* of TTS (this is a speech-to-speech model) + # Report the text of the assistant response. await self.push_frame(TTSTextFrame(text)) - # TODO: this is a (hopefully temporary) HACK. Here we directly manipulate the context rather - # than relying on the frames pushed to the assistant context aggregator. The pattern of - # receiving full-sentence text after the assistant has spoken does not easily fit with the - # Pipecat expectation of chunks of text streaming in while the assistant is speaking. - # Interruption handling was especially challenging. Rather than spend days trying to fit a - # square peg in a round hole, I decided on this hack for the time being. We can most cleanly - # abandon this hack if/when AWS Nova Sonic implements streaming smaller text chunks - # interspersed with audio. Note that when we move away from this hack, we need to make sure - # that on an interruption we avoid sending LLMFullResponseEndFrame, which gets the - # LLMAssistantContextAggregator into a bad state. - self._context.buffer_assistant_text(text) + # HACK: here we're also buffering the assistant text ourselves as a + # backup rather than relying solely on the assistant context aggregator + # to do it, because the text arrives from Nova Sonic only after all the + # assistant audio frames have been pushed, meaning that if an + # interruption frame were to arrive we would lose all of it (the text + # frames sitting in the queue would be wiped). + self._assistant_text_buffer += text async def _report_assistant_response_ended(self): if not self._context: # should never happen @@ -996,14 +1035,25 @@ class AWSNovaSonicLLMService(LLMService): logger.debug("Assistant response ended") - # Report that the assistant has finished their response. + # If an interruption frame arrived while the assistant was responding + # we probably lost all of the assistant text (see HACK, above), so + # re-push it downstream to the aggregator now. + if self._needs_re_push_assistant_text: + # We also need to re-push the LLMFullResponseStartFrame since the + # TTSTextFrame would be ignored otherwise (the interruption frame + # would have cleared the assistant aggregator state). + await self.push_frame(LLMFullResponseStartFrame()) + await self.push_frame(TTSTextFrame(self._assistant_text_buffer)) + self._needs_re_push_assistant_text = False + + # Report the end of the assistant response. await self.push_frame(LLMFullResponseEndFrame()) # Report that equivalent of TTS (this is a speech-to-speech model) stopped. await self.push_frame(TTSStoppedFrame()) - # For an explanation of this hack, see _report_assistant_response_text_added. - self._context.flush_aggregated_assistant_text() + # Clear out the buffered assistant text + self._assistant_text_buffer = "" # # user transcription reporting @@ -1020,33 +1070,71 @@ class AWSNovaSonicLLMService(LLMService): logger.debug(f"User transcription text added: {text}") - # Manually add new user transcription text to context. - # We can't rely on the user context aggregator to do this since it's upstream from the LLM. - self._context.buffer_user_text(text) - - # Report that some new user transcription text is available. - if self._send_transcription_frames: - await self.push_frame( - InterimTranscriptionFrame(text=text, user_id="", timestamp=time_now_iso8601()) - ) + # HACK: here we're buffering the user text ourselves rather than + # relying on the upstream user context aggregator to do it, because the + # text arrives in fairly large chunks spaced fairly far apart in time. + # That means the user text would be split between different messages in + # context. Even if we sent placeholder InterimTranscriptionFrames in + # between each TranscriptionFrame to tell the aggregator to hold off on + # finalizing the user message, the aggregator would likely get the last + # chunk too late. + self._user_text_buffer += f" {text}" if self._user_text_buffer else text async def _report_user_transcription_ended(self): if not self._context: # should never happen return - # Manually add user transcription to context (if any has been buffered). - # We can't rely on the user context aggregator to do this since it's upstream from the LLM. - transcription = self._context.flush_aggregated_user_text() - - if not transcription: - return - logger.debug(f"User transcription ended") - if self._send_transcription_frames: - await self.push_frame( - TranscriptionFrame(text=transcription, user_id="", timestamp=time_now_iso8601()) + # Report to the upstream user context aggregator that some new user + # transcription text is available. + + # HACK: Check if this transcription was triggered by our own + # assistant response trigger. If so, we need to wrap it with + # UserStarted/StoppedSpeakingFrames; otherwise the user aggregator + # would fire an EmulatedUserStartedSpeakingFrame, which would + # trigger an interruption, which would prevent us from writing the + # assistant response to context. + # + # Sending an EmulateUserStartedSpeakingFrame ourselves doesn't + # work: it just causes the interruption we're trying to avoid. + # + # Setting enable_emulated_vad_interruptions also doesn't work: at + # the time the user aggregator receives the TranscriptionFrame, it + # doesn't yet know the assistant has started responding, so it + # doesn't know that emulating the user starting to speak would + # cause an interruption. + should_wrap_in_user_started_stopped_speaking_frames = ( + self._waiting_for_trigger_transcription + and self._user_text_buffer.strip().lower() == "ready" + ) + + # Start wrapping the upstream transcription in UserStarted/StoppedSpeakingFrames if needed + if should_wrap_in_user_started_stopped_speaking_frames: + logger.debug( + "Wrapping assistant response trigger transcription with upstream UserStarted/StoppedSpeakingFrames" ) + await self.push_frame(UserStartedSpeakingFrame(), direction=FrameDirection.UPSTREAM) + + # Send the transcription upstream for the user context aggregator + frame = TranscriptionFrame( + text=self._user_text_buffer, user_id="", timestamp=time_now_iso8601() + ) + await self.push_frame(frame, direction=FrameDirection.UPSTREAM) + + # Finish wrapping the upstream transcription in UserStarted/StoppedSpeakingFrames if needed + if should_wrap_in_user_started_stopped_speaking_frames: + await self.push_frame(UserStoppedSpeakingFrame(), direction=FrameDirection.UPSTREAM) + + # Also send the transcription downstream if requested + if self._send_transcription_frames: + await self.push_frame(frame, direction=FrameDirection.DOWNSTREAM) + + # Clear out the buffered user text + self._user_text_buffer = "" + + # We're no longer waiting for a trigger transcription + self._waiting_for_trigger_transcription = False # # context @@ -1058,23 +1146,26 @@ class AWSNovaSonicLLMService(LLMService): *, user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), - ) -> AWSNovaSonicContextAggregatorPair: + ) -> LLMContextAggregatorPair: """Create context aggregator pair for managing conversation context. + NOTE: this method exists only for backward compatibility. New code + should instead do: + context = LLMContext(...) + context_aggregator = LLMContextAggregatorPair(context) + Args: - context: The OpenAI LLM context to upgrade. + context: The OpenAI LLM context. user_params: Parameters for the user context aggregator. assistant_params: Parameters for the assistant context aggregator. Returns: A pair of user and assistant context aggregators. """ - context.set_llm_adapter(self.get_llm_adapter()) - - user = AWSNovaSonicUserContextAggregator(context=context, params=user_params) - assistant = AWSNovaSonicAssistantContextAggregator(context=context, params=assistant_params) - - return AWSNovaSonicContextAggregatorPair(user, assistant) + context = LLMContext.from_openai_context(context) + return LLMContextAggregatorPair( + context, user_params=user_params, assistant_params=assistant_params + ) # # assistant response trigger (HACK) @@ -1112,6 +1203,8 @@ class AWSNovaSonicLLMService(LLMService): try: logger.debug("Sending assistant response trigger...") + self._waiting_for_trigger_transcription = True + chunk_duration = 0.02 # what we might get from InputAudioRawFrame chunk_size = int( chunk_duration diff --git a/src/pipecat/services/aws_nova_sonic/context.py b/src/pipecat/services/aws_nova_sonic/context.py deleted file mode 100644 index 0ce5ce033..000000000 --- a/src/pipecat/services/aws_nova_sonic/context.py +++ /dev/null @@ -1,367 +0,0 @@ -# -# Copyright (c) 2025, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -"""Context management for AWS Nova Sonic LLM service. - -This module provides specialized context aggregators and message handling for AWS Nova Sonic, -including conversation history management and role-specific message processing. -""" - -import copy -from dataclasses import dataclass, field -from enum import Enum - -from loguru import logger - -from pipecat.frames.frames import ( - BotStoppedSpeakingFrame, - DataFrame, - Frame, - FunctionCallResultFrame, - InterruptionFrame, - LLMFullResponseEndFrame, - LLMFullResponseStartFrame, - LLMMessagesAppendFrame, - LLMMessagesUpdateFrame, - LLMSetToolChoiceFrame, - LLMSetToolsFrame, - TextFrame, - UserImageRawFrame, -) -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.aws_nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame -from pipecat.services.openai.llm import ( - OpenAIAssistantContextAggregator, - OpenAIUserContextAggregator, -) - - -class Role(Enum): - """Roles supported in AWS Nova Sonic conversations. - - Parameters: - SYSTEM: System-level messages (not used in conversation history). - USER: Messages sent by the user. - ASSISTANT: Messages sent by the assistant. - TOOL: Messages sent by tools (not used in conversation history). - """ - - SYSTEM = "SYSTEM" - USER = "USER" - ASSISTANT = "ASSISTANT" - TOOL = "TOOL" - - -@dataclass -class AWSNovaSonicConversationHistoryMessage: - """A single message in AWS Nova Sonic conversation history. - - Parameters: - role: The role of the message sender (USER or ASSISTANT only). - text: The text content of the message. - """ - - role: Role # only USER and ASSISTANT - text: str - - -@dataclass -class AWSNovaSonicConversationHistory: - """Complete conversation history for AWS Nova Sonic initialization. - - Parameters: - system_instruction: System-level instruction for the conversation. - messages: List of conversation messages between user and assistant. - """ - - system_instruction: str = None - messages: list[AWSNovaSonicConversationHistoryMessage] = field(default_factory=list) - - -class AWSNovaSonicLLMContext(OpenAILLMContext): - """Specialized LLM context for AWS Nova Sonic service. - - Extends OpenAI context with Nova Sonic-specific message handling, - conversation history management, and text buffering capabilities. - """ - - def __init__(self, messages=None, tools=None, **kwargs): - """Initialize AWS Nova Sonic LLM context. - - Args: - messages: Initial messages for the context. - tools: Available tools for the context. - **kwargs: Additional arguments passed to parent class. - """ - super().__init__(messages=messages, tools=tools, **kwargs) - self.__setup_local() - - def __setup_local(self, system_instruction: str = ""): - self._assistant_text = "" - self._user_text = "" - self._system_instruction = system_instruction - - @staticmethod - def upgrade_to_nova_sonic( - obj: OpenAILLMContext, system_instruction: str - ) -> "AWSNovaSonicLLMContext": - """Upgrade an OpenAI context to AWS Nova Sonic context. - - Args: - obj: The OpenAI context to upgrade. - system_instruction: System instruction for the context. - - Returns: - The upgraded AWS Nova Sonic context. - """ - if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSNovaSonicLLMContext): - obj.__class__ = AWSNovaSonicLLMContext - obj.__setup_local(system_instruction) - return obj - - # NOTE: this method has the side-effect of updating _system_instruction from messages - def get_messages_for_initializing_history(self) -> AWSNovaSonicConversationHistory: - """Get conversation history for initializing AWS Nova Sonic session. - - Processes stored messages and extracts system instruction and conversation - history in the format expected by AWS Nova Sonic. - - Returns: - Formatted conversation history with system instruction and messages. - """ - history = AWSNovaSonicConversationHistory(system_instruction=self._system_instruction) - - # Bail if there are no messages - if not self.messages: - return history - - messages = copy.deepcopy(self.messages) - - # If we have a "system" message as our first message, let's pull that out into "instruction" - if messages[0].get("role") == "system": - system = messages.pop(0) - content = system.get("content") - if isinstance(content, str): - history.system_instruction = content - elif isinstance(content, list): - history.system_instruction = content[0].get("text") - if history.system_instruction: - self._system_instruction = history.system_instruction - - # Process remaining messages to fill out conversation history. - # Nova Sonic supports "user" and "assistant" messages in history. - for message in messages: - history_message = self.from_standard_message(message) - if history_message: - history.messages.append(history_message) - - return history - - def get_messages_for_persistent_storage(self): - """Get messages formatted for persistent storage. - - Returns: - List of messages including system instruction if present. - """ - messages = super().get_messages_for_persistent_storage() - # If we have a system instruction and messages doesn't already contain it, add it - if self._system_instruction and not (messages and messages[0].get("role") == "system"): - messages.insert(0, {"role": "system", "content": self._system_instruction}) - return messages - - def from_standard_message(self, message) -> AWSNovaSonicConversationHistoryMessage: - """Convert standard message format to Nova Sonic format. - - Args: - message: Standard message dictionary to convert. - - Returns: - Nova Sonic conversation history message, or None if not convertible. - """ - role = message.get("role") - if message.get("role") == "user" or message.get("role") == "assistant": - content = message.get("content") - if isinstance(message.get("content"), list): - content = "" - for c in message.get("content"): - if c.get("type") == "text": - content += " " + c.get("text") - else: - logger.error( - f"Unhandled content type in context message: {c.get('type')} - {message}" - ) - # There won't be content if this is an assistant tool call entry. - # We're ignoring those since they can't be loaded into AWS Nova Sonic conversation - # history - if content: - return AWSNovaSonicConversationHistoryMessage(role=Role[role.upper()], text=content) - # NOTE: we're ignoring messages with role "tool" since they can't be loaded into AWS Nova - # Sonic conversation history - - def buffer_user_text(self, text): - """Buffer user text for later flushing to context. - - Args: - text: User text to buffer. - """ - self._user_text += f" {text}" if self._user_text else text - # logger.debug(f"User text buffered: {self._user_text}") - - def flush_aggregated_user_text(self) -> str: - """Flush buffered user text to context as a complete message. - - Returns: - The flushed user text, or empty string if no text was buffered. - """ - if not self._user_text: - return "" - user_text = self._user_text - message = { - "role": "user", - "content": [{"type": "text", "text": user_text}], - } - self._user_text = "" - self.add_message(message) - # logger.debug(f"Context updated (user): {self.get_messages_for_logging()}") - return user_text - - def buffer_assistant_text(self, text): - """Buffer assistant text for later flushing to context. - - Args: - text: Assistant text to buffer. - """ - self._assistant_text += text - # logger.debug(f"Assistant text buffered: {self._assistant_text}") - - def flush_aggregated_assistant_text(self): - """Flush buffered assistant text to context as a complete message.""" - if not self._assistant_text: - return - message = { - "role": "assistant", - "content": [{"type": "text", "text": self._assistant_text}], - } - self._assistant_text = "" - self.add_message(message) - # logger.debug(f"Context updated (assistant): {self.get_messages_for_logging()}") - - -@dataclass -class AWSNovaSonicMessagesUpdateFrame(DataFrame): - """Frame containing updated AWS Nova Sonic context. - - Parameters: - context: The updated AWS Nova Sonic LLM context. - """ - - context: AWSNovaSonicLLMContext - - -class AWSNovaSonicUserContextAggregator(OpenAIUserContextAggregator): - """Context aggregator for user messages in AWS Nova Sonic conversations. - - Extends the OpenAI user context aggregator to emit Nova Sonic-specific - context update frames. - """ - - async def process_frame( - self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM - ): - """Process frames and emit Nova Sonic-specific context updates. - - Args: - frame: The frame to process. - direction: The direction the frame is traveling. - """ - await super().process_frame(frame, direction) - - # Parent does not push LLMMessagesUpdateFrame - if isinstance(frame, LLMMessagesUpdateFrame): - await self.push_frame(AWSNovaSonicMessagesUpdateFrame(context=self._context)) - - -class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator): - """Context aggregator for assistant messages in AWS Nova Sonic conversations. - - Provides specialized handling for assistant responses and function calls - in AWS Nova Sonic context, with custom frame processing logic. - """ - - async def process_frame(self, frame: Frame, direction: FrameDirection): - """Process frames with Nova Sonic-specific logic. - - Args: - frame: The frame to process. - direction: The direction the frame is traveling. - """ - # HACK: For now, disable the context aggregator by making it just pass through all frames - # that the parent handles (except the function call stuff, which we still need). - # For an explanation of this hack, see - # AWSNovaSonicLLMService._report_assistant_response_text_added. - if isinstance( - frame, - ( - InterruptionFrame, - LLMFullResponseStartFrame, - LLMFullResponseEndFrame, - TextFrame, - LLMMessagesAppendFrame, - LLMMessagesUpdateFrame, - LLMSetToolsFrame, - LLMSetToolChoiceFrame, - UserImageRawFrame, - BotStoppedSpeakingFrame, - ), - ): - await self.push_frame(frame, direction) - else: - await super().process_frame(frame, direction) - - async def handle_function_call_result(self, frame: FunctionCallResultFrame): - """Handle function call results for AWS Nova Sonic. - - Args: - frame: The function call result frame to handle. - """ - await super().handle_function_call_result(frame) - - # The standard function callback code path pushes the FunctionCallResultFrame from the LLM - # itself, so we didn't have a chance to add the result to the AWS Nova Sonic server-side - # context. Let's push a special frame to do that. - await self.push_frame( - AWSNovaSonicFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM - ) - - -@dataclass -class AWSNovaSonicContextAggregatorPair: - """Pair of user and assistant context aggregators for AWS Nova Sonic. - - Parameters: - _user: The user context aggregator. - _assistant: The assistant context aggregator. - """ - - _user: AWSNovaSonicUserContextAggregator - _assistant: AWSNovaSonicAssistantContextAggregator - - def user(self) -> AWSNovaSonicUserContextAggregator: - """Get the user context aggregator. - - Returns: - The user context aggregator instance. - """ - return self._user - - def assistant(self) -> AWSNovaSonicAssistantContextAggregator: - """Get the assistant context aggregator. - - Returns: - The assistant context aggregator instance. - """ - return self._assistant diff --git a/tests/test_llm_context.py b/tests/test_llm_context.py new file mode 100644 index 000000000..3dd84bd2c --- /dev/null +++ b/tests/test_llm_context.py @@ -0,0 +1,208 @@ +# +# Copyright (c) 2024-2025 Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import unittest + +from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage + + +class TestGetMessagesForPersistentStorage(unittest.TestCase): + """Test suite for LLMContext.get_messages_for_persistent_storage method.""" + + def test_no_system_instruction_returns_messages_as_is(self): + """Test that without system instruction, messages are returned unchanged.""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + context = LLMContext(messages=messages) + + result = context.get_messages_for_persistent_storage() + + self.assertEqual(result, messages) + self.assertEqual(len(result), 2) + + def test_empty_messages_with_system_instruction_adds_system_message(self): + """Test that system instruction is added when messages list is empty.""" + context = LLMContext() + system_instruction = "You are a helpful assistant." + + result = context.get_messages_for_persistent_storage(system_instruction) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["role"], "system") + self.assertEqual(result[0]["content"], system_instruction) + + def test_non_system_first_message_prepends_system_instruction(self): + """Test that system instruction is prepended when first message is not system.""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + context = LLMContext(messages=messages) + system_instruction = "You are a helpful assistant." + + result = context.get_messages_for_persistent_storage(system_instruction) + + self.assertEqual(len(result), 3) + self.assertEqual(result[0]["role"], "system") + self.assertEqual(result[0]["content"], system_instruction) + self.assertEqual(result[1], messages[0]) + self.assertEqual(result[2], messages[1]) + + def test_existing_system_message_not_duplicated(self): + """Test that system instruction is not added when first message is already system.""" + messages = [ + {"role": "system", "content": "Existing system message"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + context = LLMContext(messages=messages) + system_instruction = "You are a helpful assistant." + + result = context.get_messages_for_persistent_storage(system_instruction) + + self.assertEqual(len(result), 3) + self.assertEqual(result, messages) + self.assertEqual(result[0]["role"], "system") + self.assertEqual(result[0]["content"], "Existing system message") + + def test_empty_system_instruction_does_not_add_message(self): + """Test that empty system instruction does not add a system message.""" + messages = [{"role": "user", "content": "Hello"}] + context = LLMContext(messages=messages) + + result = context.get_messages_for_persistent_storage("") + + self.assertEqual(result, messages) + self.assertEqual(len(result), 1) + + def test_none_system_instruction_does_not_add_message(self): + """Test that None system instruction does not add a system message.""" + messages = [{"role": "user", "content": "Hello"}] + context = LLMContext(messages=messages) + + result = context.get_messages_for_persistent_storage(None) + + self.assertEqual(result, messages) + self.assertEqual(len(result), 1) + + def test_whitespace_only_system_instruction_adds_message(self): + """Test that whitespace-only system instruction still adds a system message.""" + messages = [{"role": "user", "content": "Hello"}] + context = LLMContext(messages=messages) + system_instruction = " " + + result = context.get_messages_for_persistent_storage(system_instruction) + + self.assertEqual(len(result), 2) + self.assertEqual(result[0]["role"], "system") + self.assertEqual(result[0]["content"], system_instruction) + + def test_with_llm_specific_messages(self): + """Test that method works correctly with LLMSpecificMessage objects.""" + llm_specific = LLMSpecificMessage( + llm="test-llm", message={"role": "user", "content": "Specific"} + ) + messages = [{"role": "user", "content": "Standard message"}, llm_specific] + context = LLMContext(messages=messages) + system_instruction = "You are a helpful assistant." + + result = context.get_messages_for_persistent_storage(system_instruction) + + self.assertEqual(len(result), 3) + self.assertEqual(result[0]["role"], "system") + self.assertEqual(result[0]["content"], system_instruction) + self.assertEqual(result[1], messages[0]) + self.assertEqual(result[2], llm_specific) + + def test_system_message_detection_case_sensitivity(self): + """Test that system message detection is case sensitive.""" + messages = [ + {"role": "System", "content": "Mixed case system"}, # Capital S + {"role": "user", "content": "Hello"}, + ] + context = LLMContext(messages=messages) + system_instruction = "You are a helpful assistant." + + result = context.get_messages_for_persistent_storage(system_instruction) + + # Should prepend because "System" != "system" + self.assertEqual(len(result), 3) + self.assertEqual(result[0]["role"], "system") + self.assertEqual(result[0]["content"], system_instruction) + self.assertEqual(result[1], messages[0]) + + def test_message_without_role_key_does_not_crash(self): + """Test that messages without 'role' key are handled gracefully.""" + messages = [{"content": "Message without role"}, {"role": "user", "content": "Hello"}] + context = LLMContext(messages=messages) + system_instruction = "You are a helpful assistant." + + result = context.get_messages_for_persistent_storage(system_instruction) + + # Should prepend system instruction since first message doesn't have role="system" + self.assertEqual(len(result), 3) + self.assertEqual(result[0]["role"], "system") + self.assertEqual(result[0]["content"], system_instruction) + + def test_original_messages_not_modified(self): + """Test that the original messages list is not modified.""" + original_messages = [{"role": "user", "content": "Hello"}] + context = LLMContext(messages=original_messages) + system_instruction = "You are a helpful assistant." + + result = context.get_messages_for_persistent_storage(system_instruction) + + # Original messages should remain unchanged + self.assertEqual(len(original_messages), 1) + self.assertEqual(original_messages[0]["role"], "user") + + # Result should have system message prepended + self.assertEqual(len(result), 2) + self.assertEqual(result[0]["role"], "system") + self.assertEqual(result[1], original_messages[0]) + + def test_complex_message_structure_preserved(self): + """Test that complex message structures are preserved.""" + complex_message = { + "role": "user", + "content": [ + {"type": "text", "text": "Complex message"}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}, + ], + } + messages = [complex_message] + context = LLMContext(messages=messages) + system_instruction = "You are a helpful assistant." + + result = context.get_messages_for_persistent_storage(system_instruction) + + self.assertEqual(len(result), 2) + self.assertEqual(result[0]["role"], "system") + self.assertEqual(result[1], complex_message) + self.assertEqual(result[1]["content"], complex_message["content"]) + + def test_deep_copy_prevents_nested_mutation(self): + """Test that deep copy prevents mutation of nested message content.""" + nested_content = {"nested": {"data": "original"}} + complex_message = {"role": "user", "content": nested_content} + messages = [complex_message] + context = LLMContext(messages=messages) + system_instruction = "You are a helpful assistant." + + result = context.get_messages_for_persistent_storage(system_instruction) + + # Modify the nested content in the result + result[1]["content"]["nested"]["data"] = "modified" + + # Original message should remain unchanged + self.assertEqual(complex_message["content"]["nested"]["data"], "original") + self.assertEqual(context.get_messages()[0]["content"]["nested"]["data"], "original") + + +if __name__ == "__main__": + unittest.main()