Add an LLMMessagesTransformFrame to facilitate programmatically editing context in a frame-based way.

The previous approach required the caller to directly grab a reference to the context object, grab a "snapshot" of its messages *at that point in time*, transform the messages, and then push an `LLMMessagesUpdateFrame` with the transformed messages. This approach can lead to problems: what if there had already been a change to the context queued in the pipeline? The transformed messages would simply overwrite it without consideration.
This commit is contained in:
Paul Kompfner
2026-04-02 18:03:22 -04:00
parent 5490820338
commit 4eebfd65d9
5 changed files with 113 additions and 3 deletions

3
changelog/4231.added.md Normal file
View File

@@ -0,0 +1,3 @@
- Added `LLMMessagesTransformFrame` to facilitate programmatically editing context in a frame-based way.
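The previous approach required the caller to directly grab a reference to the context object, grab a "snapshot" of its messages _at that point in time_, transform the messages, and then push an `LLMMessagesUpdateFrame` with the transformed messages. This approach can lead to problems: what if there had already been a change to the context queued in the pipeline? The transformed messages would simply overwrite it without consideration. With `LLMMessagesTransformFrame`, the transform is instead applied to the context's messages when the frame is processed by the context aggregator, so it operates on the messages as they exist at that point in the pipeline.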

View File

@@ -38,7 +38,7 @@ from pipecat.utils.time import nanoseconds_to_str
from pipecat.utils.utils import obj_count, obj_id
if TYPE_CHECKING:
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage, NotGiven
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.services.settings import ServiceSettings
from pipecat.utils.context.llm_context_summarization import LLMContextSummaryConfig
@@ -587,6 +587,25 @@ class LLMMessagesUpdateFrame(DataFrame):
run_llm: Optional[bool] = None
@dataclass
class LLMMessagesTransformFrame(DataFrame):
    """Frame carrying a function that rewrites the current context's LLM messages.

    The transform is called with the context's current list of LLM messages
    and returns the list of messages that replaces them.

    Only compatible with LLMContext and not the deprecated OpenAILLMContext.

    Parameters:
        transform: Callable that receives a list of messages and returns the
            modified list.
        run_llm: Whether the context update should be sent to the LLM. When
            left as None (the default) or set to False, the aggregators only
            apply the transform and do not push a context frame.
    """

    transform: Callable[[List["LLMContextMessage"]], List["LLMContextMessage"]]
    run_llm: Optional[bool] = None
@dataclass
class LLMSetToolsFrame(DataFrame):
"""Frame containing tools for LLM function calling.

View File

@@ -19,7 +19,7 @@ import base64
import io
import wave
from dataclasses import dataclass
from typing import Any, List, Optional, TypeAlias, Union
from typing import Any, Callable, List, Optional, TypeAlias, Union
from loguru import logger
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
@@ -266,6 +266,17 @@ class LLMContext:
"""
self._messages[:] = messages
def transform_messages(
    self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]]
):
    """Replace the context's messages with the result of a transform function.

    The function is handed the context's own message list, and whatever it
    returns is installed through ``set_messages``.

    Args:
        transform: A function that takes the current list of messages and
            returns the list of messages to set in the context.
    """
    updated_messages = transform(self._messages)
    self.set_messages(updated_messages)
def set_tools(self, tools: ToolsSchema | NotGiven = NOT_GIVEN):
"""Set the available tools for the LLM.

View File

@@ -16,7 +16,7 @@ import json
import warnings
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Set, Type
from typing import Any, Callable, Dict, List, Literal, Optional, Set, Type
from loguru import logger
@@ -42,6 +42,7 @@ from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
LLMMessagesTransformFrame,
LLMMessagesUpdateFrame,
LLMRunFrame,
LLMSetToolChoiceFrame,
@@ -315,6 +316,17 @@ class LLMContextAggregator(FrameProcessor):
"""
self._context.set_messages(messages)
def transform_messages(
    self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]]
):
    """Apply a transform function to the messages held by the context.

    Delegates to the ``transform_messages`` method of the aggregator's
    context object.

    Args:
        transform: A function that takes the current list of messages and
            returns the list of messages to set in the context.
    """
    context = self._context
    context.transform_messages(transform)
def set_tools(self, tools: ToolsSchema | NotGiven):
"""Set tools in the context.
@@ -512,6 +524,8 @@ class LLMUserAggregator(LLMContextAggregator):
await self._handle_llm_messages_append(frame)
elif isinstance(frame, LLMMessagesUpdateFrame):
await self._handle_llm_messages_update(frame)
elif isinstance(frame, LLMMessagesTransformFrame):
await self._handle_llm_messages_transform(frame)
elif isinstance(frame, LLMSetToolsFrame):
self.set_tools(frame.tools)
# Push the LLMSetToolsFrame as well, since speech-to-speech LLM
@@ -635,6 +649,11 @@ class LLMUserAggregator(LLMContextAggregator):
if frame.run_llm:
await self.push_context_frame()
async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame):
self.transform_messages(frame.transform)
if frame.run_llm:
await self.push_context_frame()
async def _handle_transcription(self, frame: TranscriptionFrame):
text = frame.text
@@ -923,6 +942,8 @@ class LLMAssistantAggregator(LLMContextAggregator):
await self._handle_llm_messages_append(frame)
elif isinstance(frame, LLMMessagesUpdateFrame):
await self._handle_llm_messages_update(frame)
elif isinstance(frame, LLMMessagesTransformFrame):
await self._handle_llm_messages_transform(frame)
elif isinstance(frame, LLMSetToolsFrame):
self.set_tools(frame.tools)
elif isinstance(frame, LLMSetToolChoiceFrame):
@@ -982,6 +1003,11 @@ class LLMAssistantAggregator(LLMContextAggregator):
if frame.run_llm:
await self.push_context_frame(FrameDirection.UPSTREAM)
async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame):
self.transform_messages(frame.transform)
if frame.run_llm:
await self.push_context_frame(FrameDirection.UPSTREAM)
async def _handle_interruptions(self, frame: InterruptionFrame):
await self._trigger_assistant_turn_stopped()
await self.reset()

View File

@@ -22,6 +22,7 @@ from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
LLMMessagesTransformFrame,
LLMMessagesUpdateFrame,
LLMRunFrame,
LLMTextFrame,
@@ -180,6 +181,56 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
assert context.messages[0]["content"] == "You are a helpful assistant."
assert context.messages[1]["content"] == "Hello!"
async def test_llm_messages_transform(self):
    """A transform frame rewrites the context messages without running the LLM."""
    # Seed the context with a short mixed-role conversation.
    initial_messages = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "How are you?"},
    ]
    context = LLMContext()
    context.set_messages(initial_messages)

    def keep_user_messages(messages):
        # Drop every message that did not come from the user.
        return [message for message in messages if message["role"] == "user"]

    await run_test(
        Pipeline([LLMUserAggregator(context)]),
        frames_to_send=[LLMMessagesTransformFrame(transform=keep_user_messages)],
        # run_llm is left at its default of None, so no LLMContextFrame is expected.
        expected_down_frames=[SpeechControlParamsFrame],
    )

    remaining_contents = [message["content"] for message in context.messages]
    assert remaining_contents == ["Hello", "How are you?"]
async def test_llm_messages_transform_run(self):
    """A transform frame with run_llm=True also pushes an LLMContextFrame."""
    # Seed the context with a single user message.
    context = LLMContext()
    context.set_messages([{"role": "user", "content": "Hello"}])

    def uppercase_content(messages):
        # Rebuild each message with its content upper-cased.
        return [
            {"role": message["role"], "content": message["content"].upper()}
            for message in messages
        ]

    transform_frame = LLMMessagesTransformFrame(transform=uppercase_content, run_llm=True)
    await run_test(
        Pipeline([LLMUserAggregator(context)]),
        frames_to_send=[transform_frame],
        expected_down_frames=[SpeechControlParamsFrame, LLMContextFrame],
    )

    assert context.messages[0]["content"] == "HELLO"
async def test_default_user_turn_strategies(self):
context = LLMContext()
user_aggregator = LLMUserAggregator(