diff --git a/changelog/4231.added.md b/changelog/4231.added.md new file mode 100644 index 000000000..1b47e1b78 --- /dev/null +++ b/changelog/4231.added.md @@ -0,0 +1,3 @@ +- Added `LLMMessagesTransformFrame` to facilitate programmatically editing context in a frame-based way. + + The previous approach required the caller to directly grab a reference to the context object, grab a "snapshot" of its messages _at that point in time_, transform the messages, and then push an `LLMMessagesUpdateFrame` with the transformed messages. This approach can lead to problems: what if there had already been a change to the context queued in the pipeline? The transformed messages would simply overwrite it without consideration. diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 86a93825b..548082936 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -38,7 +38,7 @@ from pipecat.utils.time import nanoseconds_to_str from pipecat.utils.utils import obj_count, obj_id if TYPE_CHECKING: - from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven + from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage, NotGiven from pipecat.processors.frame_processor import FrameProcessor from pipecat.services.settings import ServiceSettings from pipecat.utils.context.llm_context_summarization import LLMContextSummaryConfig @@ -587,6 +587,25 @@ class LLMMessagesUpdateFrame(DataFrame): run_llm: Optional[bool] = None +@dataclass +class LLMMessagesTransformFrame(DataFrame): + """Frame containing a transform function to modify the current context's LLM messages. + + A frame containing a transform function that takes the context's current list + of LLM messages and returns a modified list. + + Only compatible with LLMContext and not the deprecated OpenAILLMContext. + + Parameters: + transform: A function that takes a list of messages and returns a + modified list. + run_llm: Whether the context update should be sent to the LLM. + """ + + transform: Callable[[List["LLMContextMessage"]], List["LLMContextMessage"]] + run_llm: Optional[bool] = None + + @dataclass class LLMSetToolsFrame(DataFrame): """Frame containing tools for LLM function calling. diff --git a/src/pipecat/processors/aggregators/llm_context.py b/src/pipecat/processors/aggregators/llm_context.py index b36aad4ca..a5d96c1d0 100644 --- a/src/pipecat/processors/aggregators/llm_context.py +++ b/src/pipecat/processors/aggregators/llm_context.py @@ -19,7 +19,7 @@ import base64 import io import wave from dataclasses import dataclass -from typing import Any, List, Optional, TypeAlias, Union +from typing import Any, Callable, List, Optional, TypeAlias, Union from loguru import logger from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN @@ -266,6 +266,17 @@ class LLMContext: """ self._messages[:] = messages + def transform_messages( + self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]] + ): + """Transform the current messages using the provided function. + + Args: + transform: A function that takes the current list of messages and returns + a modified list of messages to set in the context. + """ + self.set_messages(transform(self._messages)) + def set_tools(self, tools: ToolsSchema | NotGiven = NOT_GIVEN): """Set the available tools for the LLM. diff --git a/src/pipecat/processors/aggregators/llm_response_universal.py b/src/pipecat/processors/aggregators/llm_response_universal.py index 911035fdc..e7a581155 100644 --- a/src/pipecat/processors/aggregators/llm_response_universal.py +++ b/src/pipecat/processors/aggregators/llm_response_universal.py @@ -16,7 +16,7 @@ import json import warnings from abc import abstractmethod from dataclasses import dataclass, field -from typing import Any, Dict, List, Literal, Optional, Set, Type +from typing import Any, Callable, Dict, List, Literal, Optional, Set, Type from loguru import logger @@ -42,6 +42,7 @@ from pipecat.frames.frames import ( LLMFullResponseEndFrame, LLMFullResponseStartFrame, LLMMessagesAppendFrame, + LLMMessagesTransformFrame, LLMMessagesUpdateFrame, LLMRunFrame, LLMSetToolChoiceFrame, @@ -315,6 +316,17 @@ class LLMContextAggregator(FrameProcessor): """ self._context.set_messages(messages) + def transform_messages( + self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]] + ): + """Transform the context messages using a provided function. + + Args: + transform: A function that takes the current list of messages and returns + a modified list of messages to set in the context. + """ + self._context.transform_messages(transform) + def set_tools(self, tools: ToolsSchema | NotGiven): """Set tools in the context. @@ -512,6 +524,8 @@ class LLMUserAggregator(LLMContextAggregator): await self._handle_llm_messages_append(frame) elif isinstance(frame, LLMMessagesUpdateFrame): await self._handle_llm_messages_update(frame) + elif isinstance(frame, LLMMessagesTransformFrame): + await self._handle_llm_messages_transform(frame) elif isinstance(frame, LLMSetToolsFrame): self.set_tools(frame.tools) # Push the LLMSetToolsFrame as well, since speech-to-speech LLM @@ -635,6 +649,11 @@ class LLMUserAggregator(LLMContextAggregator): if frame.run_llm: await self.push_context_frame() + async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame): + self.transform_messages(frame.transform) + if frame.run_llm: + await self.push_context_frame() + async def _handle_transcription(self, frame: TranscriptionFrame): text = frame.text @@ -923,6 +942,8 @@ class LLMAssistantAggregator(LLMContextAggregator): await self._handle_llm_messages_append(frame) elif isinstance(frame, LLMMessagesUpdateFrame): await self._handle_llm_messages_update(frame) + elif isinstance(frame, LLMMessagesTransformFrame): + await self._handle_llm_messages_transform(frame) elif isinstance(frame, LLMSetToolsFrame): self.set_tools(frame.tools) elif isinstance(frame, LLMSetToolChoiceFrame): @@ -982,6 +1003,11 @@ class LLMAssistantAggregator(LLMContextAggregator): if frame.run_llm: await self.push_context_frame(FrameDirection.UPSTREAM) + async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame): + self.transform_messages(frame.transform) + if frame.run_llm: + await self.push_context_frame(FrameDirection.UPSTREAM) + async def _handle_interruptions(self, frame: InterruptionFrame): await self._trigger_assistant_turn_stopped() await self.reset() diff --git a/tests/test_context_aggregators_universal.py b/tests/test_context_aggregators_universal.py index a0b961fa2..8f10cae0a 100644 --- a/tests/test_context_aggregators_universal.py +++ b/tests/test_context_aggregators_universal.py @@ -22,6 +22,7 @@ from pipecat.frames.frames import ( LLMFullResponseEndFrame, LLMFullResponseStartFrame, LLMMessagesAppendFrame, + LLMMessagesTransformFrame, LLMMessagesUpdateFrame, LLMRunFrame, LLMTextFrame, @@ -180,6 +181,56 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase): assert context.messages[0]["content"] == "You are a helpful assistant." assert context.messages[1]["content"] == "Hello!" + async def test_llm_messages_transform(self): + context = LLMContext() + # Set up initial messages + context.set_messages( + [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + ) + + pipeline = Pipeline([LLMUserAggregator(context)]) + + # Transform that keeps only user messages + def keep_user_messages(messages): + return [m for m in messages if m["role"] == "user"] + + frames_to_send = [LLMMessagesTransformFrame(transform=keep_user_messages)] + expected_down_frames = [ + SpeechControlParamsFrame # no LLMContextFrame expected, run_llm defaults to False + ] + await run_test( + pipeline, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert len(context.messages) == 2 + assert context.messages[0]["content"] == "Hello" + assert context.messages[1]["content"] == "How are you?" + + async def test_llm_messages_transform_run(self): + context = LLMContext() + # Set up initial messages + context.set_messages([{"role": "user", "content": "Hello"}]) + + pipeline = Pipeline([LLMUserAggregator(context)]) + + # Transform that modifies the content + def uppercase_content(messages): + return [{"role": m["role"], "content": m["content"].upper()} for m in messages] + + frames_to_send = [LLMMessagesTransformFrame(transform=uppercase_content, run_llm=True)] + expected_down_frames = [SpeechControlParamsFrame, LLMContextFrame] + await run_test( + pipeline, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert context.messages[0]["content"] == "HELLO" + async def test_default_user_turn_strategies(self): context = LLMContext() user_aggregator = LLMUserAggregator(