Get rid of LLMContext.get_messages_for_persistent_storage().
The reason for its `system_instruction` argument was to support usage with LLMs where you might pass the system instruction as a parameter to the `LLMService` rather than specifying it in the context. But as I thought about it more, I became unconvinced that the `system_instruction` argument was really beneficial:

- If you specified your system instruction in your context in the first place, it'll still be there when you read messages for persistent storage.
- If you didn't specify your system instruction in the context and instead passed it in as an `LLMService` parameter, you most likely *don't* want it to be in the context when you read messages for persistent storage.
- ...and if you really, really do need to inject it at the start of the context, it's quite easy to do anyway.

And if we remove the `system_instruction` argument from `get_messages_for_persistent_storage()`, then it's essentially just `get_messages()` (the only remaining difference being that it returned a deep copy of the messages).
This commit is contained in:
@@ -67,11 +67,11 @@ async def save_conversation(params: FunctionCallParams):
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
|
||||
filename = f"{BASE_FILENAME}{timestamp}.json"
|
||||
logger.debug(
|
||||
f"writing conversation to {filename}\n{json.dumps(params.context.get_messages_for_persistent_storage(), indent=4)}"
|
||||
f"writing conversation to {filename}\n{json.dumps(params.context.get_messages(), indent=4)}"
|
||||
)
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
messages = params.context.get_messages_for_persistent_storage()
|
||||
messages = params.context.get_messages()
|
||||
# remove the last message, which is the instruction we just gave to save the conversation
|
||||
messages.pop()
|
||||
json.dump(messages, file, indent=2)
|
||||
|
||||
@@ -68,12 +68,11 @@ async def save_conversation(params: FunctionCallParams):
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
|
||||
filename = f"{BASE_FILENAME}{timestamp}.json"
|
||||
logger.debug(
|
||||
f"writing conversation to {filename}\n{json.dumps(params.context.get_messages_for_persistent_storage(), indent=4)}"
|
||||
f"writing conversation to {filename}\n{json.dumps(params.context.get_messages(), indent=4)}"
|
||||
)
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
# todo: extract 'system' into the first message in the list
|
||||
messages = params.context.get_messages_for_persistent_storage()
|
||||
messages = params.context.get_messages()
|
||||
# remove the last message, which is the instruction we just gave to save the conversation
|
||||
messages.pop()
|
||||
json.dump(messages, file, indent=2)
|
||||
|
||||
@@ -86,11 +86,11 @@ async def save_conversation(params: FunctionCallParams):
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
|
||||
filename = f"{BASE_FILENAME}{timestamp}.json"
|
||||
logger.debug(
|
||||
f"writing conversation to {filename}\n{json.dumps(params.context.get_messages_for_persistent_storage(), indent=4)}"
|
||||
f"writing conversation to {filename}\n{json.dumps(params.context.get_messages(), indent=4)}"
|
||||
)
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
messages = params.context.get_messages_for_persistent_storage()
|
||||
messages = params.context.get_messages()
|
||||
# remove the last message (the instruction to save the context)
|
||||
messages.pop()
|
||||
json.dump(messages, file, indent=2)
|
||||
|
||||
@@ -77,7 +77,7 @@ async def save_conversation(params: FunctionCallParams):
|
||||
filename = f"{BASE_FILENAME}{timestamp}.json"
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
messages = params.context.get_messages_for_persistent_storage()
|
||||
messages = params.context.get_messages()
|
||||
# remove the last few messages. in reverse order, they are:
|
||||
# - the in progress save tool call
|
||||
# - the invocation of the save tool call
|
||||
|
||||
@@ -131,21 +131,6 @@ class LLMContext:
|
||||
)
|
||||
return filtered_messages
|
||||
|
||||
def get_messages_for_persistent_storage(
|
||||
self, system_instruction: Optional[str] = None
|
||||
) -> List[LLMContextMessage]:
|
||||
"""Get messages formatted for persistent storage.
|
||||
|
||||
Args:
|
||||
system_instruction: Optional system instruction to ensure is
|
||||
included as the first message in the returned list, if not
|
||||
already present.
|
||||
"""
|
||||
messages = copy.deepcopy(self.get_messages())
|
||||
if system_instruction and (not messages or messages[0].get("role") != "system"):
|
||||
messages.insert(0, {"role": "system", "content": system_instruction})
|
||||
return messages
|
||||
|
||||
@property
|
||||
def tools(self) -> ToolsSchema | NotGiven:
|
||||
"""Get the tools list.
|
||||
|
||||
@@ -1,208 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
|
||||
|
||||
class TestGetMessagesForPersistentStorage(unittest.TestCase):
|
||||
"""Test suite for LLMContext.get_messages_for_persistent_storage method."""
|
||||
|
||||
def test_no_system_instruction_returns_messages_as_is(self):
|
||||
"""Test that without system instruction, messages are returned unchanged."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
result = context.get_messages_for_persistent_storage()
|
||||
|
||||
self.assertEqual(result, messages)
|
||||
self.assertEqual(len(result), 2)
|
||||
|
||||
def test_empty_messages_with_system_instruction_adds_system_message(self):
|
||||
"""Test that system instruction is added when messages list is empty."""
|
||||
context = LLMContext()
|
||||
system_instruction = "You are a helpful assistant."
|
||||
|
||||
result = context.get_messages_for_persistent_storage(system_instruction)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0]["role"], "system")
|
||||
self.assertEqual(result[0]["content"], system_instruction)
|
||||
|
||||
def test_non_system_first_message_prepends_system_instruction(self):
|
||||
"""Test that system instruction is prepended when first message is not system."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
context = LLMContext(messages=messages)
|
||||
system_instruction = "You are a helpful assistant."
|
||||
|
||||
result = context.get_messages_for_persistent_storage(system_instruction)
|
||||
|
||||
self.assertEqual(len(result), 3)
|
||||
self.assertEqual(result[0]["role"], "system")
|
||||
self.assertEqual(result[0]["content"], system_instruction)
|
||||
self.assertEqual(result[1], messages[0])
|
||||
self.assertEqual(result[2], messages[1])
|
||||
|
||||
def test_existing_system_message_not_duplicated(self):
|
||||
"""Test that system instruction is not added when first message is already system."""
|
||||
messages = [
|
||||
{"role": "system", "content": "Existing system message"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
context = LLMContext(messages=messages)
|
||||
system_instruction = "You are a helpful assistant."
|
||||
|
||||
result = context.get_messages_for_persistent_storage(system_instruction)
|
||||
|
||||
self.assertEqual(len(result), 3)
|
||||
self.assertEqual(result, messages)
|
||||
self.assertEqual(result[0]["role"], "system")
|
||||
self.assertEqual(result[0]["content"], "Existing system message")
|
||||
|
||||
def test_empty_system_instruction_does_not_add_message(self):
|
||||
"""Test that empty system instruction does not add a system message."""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
result = context.get_messages_for_persistent_storage("")
|
||||
|
||||
self.assertEqual(result, messages)
|
||||
self.assertEqual(len(result), 1)
|
||||
|
||||
def test_none_system_instruction_does_not_add_message(self):
|
||||
"""Test that None system instruction does not add a system message."""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
context = LLMContext(messages=messages)
|
||||
|
||||
result = context.get_messages_for_persistent_storage(None)
|
||||
|
||||
self.assertEqual(result, messages)
|
||||
self.assertEqual(len(result), 1)
|
||||
|
||||
def test_whitespace_only_system_instruction_adds_message(self):
|
||||
"""Test that whitespace-only system instruction still adds a system message."""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
context = LLMContext(messages=messages)
|
||||
system_instruction = " "
|
||||
|
||||
result = context.get_messages_for_persistent_storage(system_instruction)
|
||||
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0]["role"], "system")
|
||||
self.assertEqual(result[0]["content"], system_instruction)
|
||||
|
||||
def test_with_llm_specific_messages(self):
|
||||
"""Test that method works correctly with LLMSpecificMessage objects."""
|
||||
llm_specific = LLMSpecificMessage(
|
||||
llm="test-llm", message={"role": "user", "content": "Specific"}
|
||||
)
|
||||
messages = [{"role": "user", "content": "Standard message"}, llm_specific]
|
||||
context = LLMContext(messages=messages)
|
||||
system_instruction = "You are a helpful assistant."
|
||||
|
||||
result = context.get_messages_for_persistent_storage(system_instruction)
|
||||
|
||||
self.assertEqual(len(result), 3)
|
||||
self.assertEqual(result[0]["role"], "system")
|
||||
self.assertEqual(result[0]["content"], system_instruction)
|
||||
self.assertEqual(result[1], messages[0])
|
||||
self.assertEqual(result[2], llm_specific)
|
||||
|
||||
def test_system_message_detection_case_sensitivity(self):
|
||||
"""Test that system message detection is case sensitive."""
|
||||
messages = [
|
||||
{"role": "System", "content": "Mixed case system"}, # Capital S
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
context = LLMContext(messages=messages)
|
||||
system_instruction = "You are a helpful assistant."
|
||||
|
||||
result = context.get_messages_for_persistent_storage(system_instruction)
|
||||
|
||||
# Should prepend because "System" != "system"
|
||||
self.assertEqual(len(result), 3)
|
||||
self.assertEqual(result[0]["role"], "system")
|
||||
self.assertEqual(result[0]["content"], system_instruction)
|
||||
self.assertEqual(result[1], messages[0])
|
||||
|
||||
def test_message_without_role_key_does_not_crash(self):
|
||||
"""Test that messages without 'role' key are handled gracefully."""
|
||||
messages = [{"content": "Message without role"}, {"role": "user", "content": "Hello"}]
|
||||
context = LLMContext(messages=messages)
|
||||
system_instruction = "You are a helpful assistant."
|
||||
|
||||
result = context.get_messages_for_persistent_storage(system_instruction)
|
||||
|
||||
# Should prepend system instruction since first message doesn't have role="system"
|
||||
self.assertEqual(len(result), 3)
|
||||
self.assertEqual(result[0]["role"], "system")
|
||||
self.assertEqual(result[0]["content"], system_instruction)
|
||||
|
||||
def test_original_messages_not_modified(self):
|
||||
"""Test that the original messages list is not modified."""
|
||||
original_messages = [{"role": "user", "content": "Hello"}]
|
||||
context = LLMContext(messages=original_messages)
|
||||
system_instruction = "You are a helpful assistant."
|
||||
|
||||
result = context.get_messages_for_persistent_storage(system_instruction)
|
||||
|
||||
# Original messages should remain unchanged
|
||||
self.assertEqual(len(original_messages), 1)
|
||||
self.assertEqual(original_messages[0]["role"], "user")
|
||||
|
||||
# Result should have system message prepended
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0]["role"], "system")
|
||||
self.assertEqual(result[1], original_messages[0])
|
||||
|
||||
def test_complex_message_structure_preserved(self):
|
||||
"""Test that complex message structures are preserved."""
|
||||
complex_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Complex message"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
|
||||
],
|
||||
}
|
||||
messages = [complex_message]
|
||||
context = LLMContext(messages=messages)
|
||||
system_instruction = "You are a helpful assistant."
|
||||
|
||||
result = context.get_messages_for_persistent_storage(system_instruction)
|
||||
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0]["role"], "system")
|
||||
self.assertEqual(result[1], complex_message)
|
||||
self.assertEqual(result[1]["content"], complex_message["content"])
|
||||
|
||||
def test_deep_copy_prevents_nested_mutation(self):
|
||||
"""Test that deep copy prevents mutation of nested message content."""
|
||||
nested_content = {"nested": {"data": "original"}}
|
||||
complex_message = {"role": "user", "content": nested_content}
|
||||
messages = [complex_message]
|
||||
context = LLMContext(messages=messages)
|
||||
system_instruction = "You are a helpful assistant."
|
||||
|
||||
result = context.get_messages_for_persistent_storage(system_instruction)
|
||||
|
||||
# Modify the nested content in the result
|
||||
result[1]["content"]["nested"]["data"] = "modified"
|
||||
|
||||
# Original message should remain unchanged
|
||||
self.assertEqual(complex_message["content"]["nested"]["data"], "original")
|
||||
self.assertEqual(context.get_messages()[0]["content"]["nested"]["data"], "original")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user