diff --git a/changelog/4231.added.md b/changelog/4231.added.md new file mode 100644 index 000000000..1b47e1b78 --- /dev/null +++ b/changelog/4231.added.md @@ -0,0 +1,3 @@ +- Added `LLMMessagesTransformFrame` to facilitate programmatically editing context in a frame-based way. + + The previous approach required the caller to directly grab a reference to the context object, grab a "snapshot" of its messages _at that point in time_, transform the messages, and then push an `LLMMessagesUpdateFrame` with the transformed messages. This approach can lead to problems: what if there had already been a change to the context queued in the pipeline? The transformed messages would simply overwrite it without consideration. diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 86a93825b..1d3e841b2 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -38,7 +38,7 @@ from pipecat.utils.time import nanoseconds_to_str from pipecat.utils.utils import obj_count, obj_id if TYPE_CHECKING: - from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven + from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage, NotGiven from pipecat.processors.frame_processor import FrameProcessor from pipecat.services.settings import ServiceSettings from pipecat.utils.context.llm_context_summarization import LLMContextSummaryConfig @@ -587,6 +587,23 @@ class LLMMessagesUpdateFrame(DataFrame): run_llm: Optional[bool] = None +@dataclass +class LLMMessagesTransformFrame(DataFrame): + """Frame containing a transform function to modify the current context's LLM messages. + + A frame containing a transform function that takes the context's current list + of LLM messages and returns a modified list. + + Parameters: + transform: A function that takes a list of messages and returns a + modified list. + run_llm: Whether the context update should be sent to the LLM. 
+ """ + + transform: Callable[[List["LLMContextMessage"]], List["LLMContextMessage"]] + run_llm: Optional[bool] = None + + @dataclass class LLMSetToolsFrame(DataFrame): """Frame containing tools for LLM function calling. diff --git a/src/pipecat/processors/aggregators/llm_context.py b/src/pipecat/processors/aggregators/llm_context.py index b36aad4ca..a5d96c1d0 100644 --- a/src/pipecat/processors/aggregators/llm_context.py +++ b/src/pipecat/processors/aggregators/llm_context.py @@ -19,7 +19,7 @@ import base64 import io import wave from dataclasses import dataclass -from typing import Any, List, Optional, TypeAlias, Union +from typing import Any, Callable, List, Optional, TypeAlias, Union from loguru import logger from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN @@ -266,6 +266,17 @@ class LLMContext: """ self._messages[:] = messages + def transform_messages( + self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]] + ): + """Transform the current messages using the provided function. + + Args: + transform: A function that takes the current list of messages and returns + a modified list of messages to set in the context. + """ + self.set_messages(transform(self._messages)) + def set_tools(self, tools: ToolsSchema | NotGiven = NOT_GIVEN): """Set the available tools for the LLM. 
diff --git a/src/pipecat/processors/aggregators/llm_response_universal.py b/src/pipecat/processors/aggregators/llm_response_universal.py index 911035fdc..e7a581155 100644 --- a/src/pipecat/processors/aggregators/llm_response_universal.py +++ b/src/pipecat/processors/aggregators/llm_response_universal.py @@ -16,7 +16,7 @@ import json import warnings from abc import abstractmethod from dataclasses import dataclass, field -from typing import Any, Dict, List, Literal, Optional, Set, Type +from typing import Any, Callable, Dict, List, Literal, Optional, Set, Type from loguru import logger @@ -42,6 +42,7 @@ from pipecat.frames.frames import ( LLMFullResponseEndFrame, LLMFullResponseStartFrame, LLMMessagesAppendFrame, + LLMMessagesTransformFrame, LLMMessagesUpdateFrame, LLMRunFrame, LLMSetToolChoiceFrame, @@ -315,6 +316,17 @@ class LLMContextAggregator(FrameProcessor): """ self._context.set_messages(messages) + def transform_messages( + self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]] + ): + """Transform the context messages using a provided function. + + Args: + transform: A function that takes the current list of messages and returns + a modified list of messages to set in the context. + """ + self._context.transform_messages(transform) + def set_tools(self, tools: ToolsSchema | NotGiven): """Set tools in the context. 
@@ -512,6 +524,8 @@ class LLMUserAggregator(LLMContextAggregator): await self._handle_llm_messages_append(frame) elif isinstance(frame, LLMMessagesUpdateFrame): await self._handle_llm_messages_update(frame) + elif isinstance(frame, LLMMessagesTransformFrame): + await self._handle_llm_messages_transform(frame) elif isinstance(frame, LLMSetToolsFrame): self.set_tools(frame.tools) # Push the LLMSetToolsFrame as well, since speech-to-speech LLM @@ -635,6 +649,11 @@ class LLMUserAggregator(LLMContextAggregator): if frame.run_llm: await self.push_context_frame() + async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame): + self.transform_messages(frame.transform) + if frame.run_llm: + await self.push_context_frame() + async def _handle_transcription(self, frame: TranscriptionFrame): text = frame.text @@ -923,6 +942,8 @@ class LLMAssistantAggregator(LLMContextAggregator): await self._handle_llm_messages_append(frame) elif isinstance(frame, LLMMessagesUpdateFrame): await self._handle_llm_messages_update(frame) + elif isinstance(frame, LLMMessagesTransformFrame): + await self._handle_llm_messages_transform(frame) elif isinstance(frame, LLMSetToolsFrame): self.set_tools(frame.tools) elif isinstance(frame, LLMSetToolChoiceFrame): @@ -982,6 +1003,11 @@ class LLMAssistantAggregator(LLMContextAggregator): if frame.run_llm: await self.push_context_frame(FrameDirection.UPSTREAM) + async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame): + self.transform_messages(frame.transform) + if frame.run_llm: + await self.push_context_frame(FrameDirection.UPSTREAM) + async def _handle_interruptions(self, frame: InterruptionFrame): await self._trigger_assistant_turn_stopped() await self.reset() diff --git a/src/pipecat/tests/utils.py b/src/pipecat/tests/utils.py index ca18c4c4d..c837a58b7 100644 --- a/src/pipecat/tests/utils.py +++ b/src/pipecat/tests/utils.py @@ -127,6 +127,7 @@ async def run_test( expected_down_frames: 
Optional[Sequence[type]] = None, expected_up_frames: Optional[Sequence[type]] = None, frames_to_send: Sequence[Frame], + frames_to_send_direction: FrameDirection = FrameDirection.DOWNSTREAM, ignore_start: bool = True, observers: Optional[List[BaseObserver]] = None, pipeline_params: Optional[PipelineParams] = None, @@ -144,6 +145,9 @@ async def run_test( expected_down_frames: Expected frame types flowing downstream (optional). expected_up_frames: Expected frame types flowing upstream (optional). frames_to_send: Sequence of frames to send through the processor. + frames_to_send_direction: Direction to send frames_to_send. Downstream + frames are pushed from the beginning of the pipeline, upstream frames + from the end. Defaults to DOWNSTREAM. ignore_start: Whether to ignore StartFrames in frame validation. observers: Optional list of observers to attach to the pipeline. pipeline_params: Optional pipeline parameters. @@ -188,7 +192,7 @@ async def run_test( if isinstance(frame, SleepFrame): await asyncio.sleep(frame.sleep) else: - await task.queue_frame(frame) + await task.queue_frame(frame, frames_to_send_direction) if send_end_frame: await task.queue_frame(EndFrame()) diff --git a/tests/test_context_aggregators_universal.py b/tests/test_context_aggregators_universal.py index a0b961fa2..66705d090 100644 --- a/tests/test_context_aggregators_universal.py +++ b/tests/test_context_aggregators_universal.py @@ -22,6 +22,7 @@ from pipecat.frames.frames import ( LLMFullResponseEndFrame, LLMFullResponseStartFrame, LLMMessagesAppendFrame, + LLMMessagesTransformFrame, LLMMessagesUpdateFrame, LLMRunFrame, LLMTextFrame, @@ -48,6 +49,7 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMUserAggregator, LLMUserAggregatorParams, ) +from pipecat.processors.frame_processor import FrameDirection from pipecat.tests.utils import SleepFrame, run_test from pipecat.turns.user_mute import ( FirstSpeechUserMuteStrategy, @@ -90,9 +92,13 @@ class 
TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase): ] ) ] + expected_down_frames = [ + SpeechControlParamsFrame # no LLMContextFrame expected, run_llm defaults to False + ] await run_test( pipeline, frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, ) assert context.messages[0]["content"] == "Hi there!" @@ -133,9 +139,13 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase): ] ) ] + expected_down_frames = [ + SpeechControlParamsFrame # no LLMContextFrame expected, run_llm defaults to False + ] await run_test( pipeline, frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, ) assert context.messages[0]["content"] == "Hi there!" @@ -180,6 +190,56 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase): assert context.messages[0]["content"] == "You are a helpful assistant." assert context.messages[1]["content"] == "Hello!" + async def test_llm_messages_transform(self): + context = LLMContext() + # Set up initial messages + context.set_messages( + [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + ) + + pipeline = Pipeline([LLMUserAggregator(context)]) + + # Transform that keeps only user messages + def keep_user_messages(messages): + return [m for m in messages if m["role"] == "user"] + + frames_to_send = [LLMMessagesTransformFrame(transform=keep_user_messages)] + expected_down_frames = [ + SpeechControlParamsFrame # no LLMContextFrame expected, run_llm defaults to False + ] + await run_test( + pipeline, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert len(context.messages) == 2 + assert context.messages[0]["content"] == "Hello" + assert context.messages[1]["content"] == "How are you?" 
+ + async def test_llm_messages_transform_run(self): + context = LLMContext() + # Set up initial messages + context.set_messages([{"role": "user", "content": "Hello"}]) + + pipeline = Pipeline([LLMUserAggregator(context)]) + + # Transform that modifies the content + def uppercase_content(messages): + return [{"role": m["role"], "content": m["content"].upper()} for m in messages] + + frames_to_send = [LLMMessagesTransformFrame(transform=uppercase_content, run_llm=True)] + expected_down_frames = [SpeechControlParamsFrame, LLMContextFrame] + await run_test( + pipeline, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert context.messages[0]["content"] == "HELLO" + async def test_default_user_turn_strategies(self): context = LLMContext() user_aggregator = LLMUserAggregator( @@ -957,6 +1017,150 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase): self.assertEqual(len(stop_messages), 1) self.assertEqual(stop_messages[0].content, "") + async def test_llm_run(self): + context = LLMContext() + aggregator = LLMAssistantAggregator(context) + + expected_up_frames = [LLMContextFrame] + await run_test( + aggregator, + frames_to_send=[LLMRunFrame()], + frames_to_send_direction=FrameDirection.UPSTREAM, + expected_up_frames=expected_up_frames, + ) + + async def test_llm_messages_append(self): + context = LLMContext() + aggregator = LLMAssistantAggregator(context) + + await run_test( + aggregator, + frames_to_send=[ + LLMMessagesAppendFrame( + messages=[ + { + "role": "user", + "content": "Hi there!", + } + ] + ) + ], + frames_to_send_direction=FrameDirection.UPSTREAM, + expected_up_frames=[], # no LLMContextFrame expected, run_llm defaults to False + ) + assert context.messages[0]["content"] == "Hi there!" 
+ + async def test_llm_messages_append_run(self): + context = LLMContext() + aggregator = LLMAssistantAggregator(context) + + expected_up_frames = [LLMContextFrame] + await run_test( + aggregator, + frames_to_send=[ + LLMMessagesAppendFrame( + messages=[ + { + "role": "user", + "content": "Hi there!", + } + ], + run_llm=True, + ) + ], + frames_to_send_direction=FrameDirection.UPSTREAM, + expected_up_frames=expected_up_frames, + ) + assert context.messages[0]["content"] == "Hi there!" + + async def test_llm_messages_update(self): + context = LLMContext() + aggregator = LLMAssistantAggregator(context) + + await run_test( + aggregator, + frames_to_send=[ + LLMMessagesUpdateFrame( + messages=[ + { + "role": "user", + "content": "Hi there!", + } + ] + ) + ], + frames_to_send_direction=FrameDirection.UPSTREAM, + expected_up_frames=[], # no LLMContextFrame expected, run_llm defaults to False + ) + assert context.messages[0]["content"] == "Hi there!" + + async def test_llm_messages_update_run(self): + context = LLMContext() + aggregator = LLMAssistantAggregator(context) + + await run_test( + aggregator, + frames_to_send=[ + LLMMessagesUpdateFrame( + messages=[ + { + "role": "user", + "content": "Hi there!", + } + ], + run_llm=True, + ) + ], + frames_to_send_direction=FrameDirection.UPSTREAM, + expected_up_frames=[LLMContextFrame], # run_llm=True must push the context upstream + ) + assert context.messages[0]["content"] == "Hi there!" 
+ + async def test_llm_messages_transform(self): + context = LLMContext() + context.set_messages( + [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + ) + + aggregator = LLMAssistantAggregator(context) + + # Transform that keeps only user messages + def keep_user_messages(messages): + return [m for m in messages if m["role"] == "user"] + + await run_test( + aggregator, + frames_to_send=[LLMMessagesTransformFrame(transform=keep_user_messages)], + frames_to_send_direction=FrameDirection.UPSTREAM, + expected_up_frames=[], # no LLMContextFrame expected, run_llm defaults to False + ) + assert len(context.messages) == 2 + assert context.messages[0]["content"] == "Hello" + assert context.messages[1]["content"] == "How are you?" + + async def test_llm_messages_transform_run(self): + context = LLMContext() + context.set_messages([{"role": "user", "content": "Hello"}]) + + aggregator = LLMAssistantAggregator(context) + + # Transform that modifies the content + def uppercase_content(messages): + return [{"role": m["role"], "content": m["content"].upper()} for m in messages] + + expected_up_frames = [LLMContextFrame] + await run_test( + aggregator, + frames_to_send=[LLMMessagesTransformFrame(transform=uppercase_content, run_llm=True)], + frames_to_send_direction=FrameDirection.UPSTREAM, + expected_up_frames=expected_up_frames, + ) + assert context.messages[0]["content"] == "HELLO" + if __name__ == "__main__": unittest.main()