Merge pull request #4231 from pipecat-ai/pk/llm-messages-transform-frame

Add an `LLMMessagesTransformFrame` to facilitate programmatically editing context in a frame-based way
This commit is contained in:
kompfner
2026-04-03 11:54:34 -04:00
committed by GitHub
6 changed files with 268 additions and 4 deletions

3
changelog/4231.added.md Normal file
View File

@@ -0,0 +1,3 @@
- Added `LLMMessagesTransformFrame` to facilitate programmatically editing context in a frame-based way.
The previous approach required the caller to directly grab a reference to the context object, grab a "snapshot" of its messages _at that point in time_, transform the messages, and then push an `LLMMessagesUpdateFrame` with the transformed messages. This approach can lead to problems: what if there had already been a change to the context queued in the pipeline? The transformed messages would simply overwrite it without consideration.

View File

@@ -38,7 +38,7 @@ from pipecat.utils.time import nanoseconds_to_str
from pipecat.utils.utils import obj_count, obj_id
if TYPE_CHECKING:
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage, NotGiven
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.services.settings import ServiceSettings
from pipecat.utils.context.llm_context_summarization import LLMContextSummaryConfig
@@ -587,6 +587,23 @@ class LLMMessagesUpdateFrame(DataFrame):
run_llm: Optional[bool] = None
@dataclass
class LLMMessagesTransformFrame(DataFrame):
    """Frame containing a transform function to modify the current context's LLM messages.

    Instead of carrying a replacement message list, this frame carries a
    callable. The callable is given the context's current list of LLM messages
    and returns the list that should replace it.

    Parameters:
        transform: A function that takes a list of messages and returns a
            modified list.
        run_llm: Whether the context update should be sent to the LLM.
    """

    # Callable mapping the current message list to its replacement.
    transform: Callable[[List["LLMContextMessage"]], List["LLMContextMessage"]]
    # Defaults to None, which is falsy: the aggregator handlers only push a
    # context frame when this value is truthy.
    run_llm: Optional[bool] = None
@dataclass
class LLMSetToolsFrame(DataFrame):
"""Frame containing tools for LLM function calling.

View File

@@ -19,7 +19,7 @@ import base64
import io
import wave
from dataclasses import dataclass
from typing import Any, List, Optional, TypeAlias, Union
from typing import Any, Callable, List, Optional, TypeAlias, Union
from loguru import logger
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
@@ -266,6 +266,17 @@ class LLMContext:
"""
self._messages[:] = messages
def transform_messages(
    self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]]
) -> None:
    """Replace the context's messages with the result of a transform function.

    The transform is handed a shallow copy of the current message list rather
    than the live list. A transform that mutates its argument and then raises
    therefore leaves the context untouched. The copy is shallow: the message
    objects themselves are still shared with the context.

    Args:
        transform: A function that takes the current list of messages and returns
            a modified list of messages to set in the context.
    """
    # Copy first so a failing transform cannot leave a half-edited context.
    snapshot = list(self._messages)
    self.set_messages(transform(snapshot))
def set_tools(self, tools: ToolsSchema | NotGiven = NOT_GIVEN):
"""Set the available tools for the LLM.

View File

@@ -16,7 +16,7 @@ import json
import warnings
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Set, Type
from typing import Any, Callable, Dict, List, Literal, Optional, Set, Type
from loguru import logger
@@ -42,6 +42,7 @@ from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
LLMMessagesTransformFrame,
LLMMessagesUpdateFrame,
LLMRunFrame,
LLMSetToolChoiceFrame,
@@ -315,6 +316,17 @@ class LLMContextAggregator(FrameProcessor):
"""
self._context.set_messages(messages)
def transform_messages(
    self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]]
):
    """Apply a transform function to the messages held by this aggregator's context.

    The work is delegated to the context object's own ``transform_messages``.

    Args:
        transform: A function that takes the current list of messages and returns
            a modified list of messages to set in the context.
    """
    context = self._context
    context.transform_messages(transform)
def set_tools(self, tools: ToolsSchema | NotGiven):
"""Set tools in the context.
@@ -512,6 +524,8 @@ class LLMUserAggregator(LLMContextAggregator):
await self._handle_llm_messages_append(frame)
elif isinstance(frame, LLMMessagesUpdateFrame):
await self._handle_llm_messages_update(frame)
elif isinstance(frame, LLMMessagesTransformFrame):
await self._handle_llm_messages_transform(frame)
elif isinstance(frame, LLMSetToolsFrame):
self.set_tools(frame.tools)
# Push the LLMSetToolsFrame as well, since speech-to-speech LLM
@@ -635,6 +649,11 @@ class LLMUserAggregator(LLMContextAggregator):
if frame.run_llm:
await self.push_context_frame()
async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame):
    """Apply the frame's transform to the context, then optionally push a context frame.

    Args:
        frame: Frame carrying the transform callable and the ``run_llm`` flag.
    """
    self.transform_messages(frame.transform)
    # An unset (None) or False run_llm means: edit the context only.
    if not frame.run_llm:
        return
    await self.push_context_frame()
async def _handle_transcription(self, frame: TranscriptionFrame):
text = frame.text
@@ -923,6 +942,8 @@ class LLMAssistantAggregator(LLMContextAggregator):
await self._handle_llm_messages_append(frame)
elif isinstance(frame, LLMMessagesUpdateFrame):
await self._handle_llm_messages_update(frame)
elif isinstance(frame, LLMMessagesTransformFrame):
await self._handle_llm_messages_transform(frame)
elif isinstance(frame, LLMSetToolsFrame):
self.set_tools(frame.tools)
elif isinstance(frame, LLMSetToolChoiceFrame):
@@ -982,6 +1003,11 @@ class LLMAssistantAggregator(LLMContextAggregator):
if frame.run_llm:
await self.push_context_frame(FrameDirection.UPSTREAM)
async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame):
    """Apply the frame's transform to the context, then optionally push a context frame upstream.

    Args:
        frame: Frame carrying the transform callable and the ``run_llm`` flag.
    """
    self.transform_messages(frame.transform)
    # An unset (None) or False run_llm means: edit the context only.
    if not frame.run_llm:
        return
    await self.push_context_frame(FrameDirection.UPSTREAM)
async def _handle_interruptions(self, frame: InterruptionFrame):
await self._trigger_assistant_turn_stopped()
await self.reset()

View File

@@ -127,6 +127,7 @@ async def run_test(
expected_down_frames: Optional[Sequence[type]] = None,
expected_up_frames: Optional[Sequence[type]] = None,
frames_to_send: Sequence[Frame],
frames_to_send_direction: FrameDirection = FrameDirection.DOWNSTREAM,
ignore_start: bool = True,
observers: Optional[List[BaseObserver]] = None,
pipeline_params: Optional[PipelineParams] = None,
@@ -144,6 +145,9 @@ async def run_test(
expected_down_frames: Expected frame types flowing downstream (optional).
expected_up_frames: Expected frame types flowing upstream (optional).
frames_to_send: Sequence of frames to send through the processor.
frames_to_send_direction: Direction to send frames_to_send. Downstream
frames are pushed from the beginning of the pipeline, upstream frames
from the end. Defaults to DOWNSTREAM.
ignore_start: Whether to ignore StartFrames in frame validation.
observers: Optional list of observers to attach to the pipeline.
pipeline_params: Optional pipeline parameters.
@@ -188,7 +192,7 @@ async def run_test(
if isinstance(frame, SleepFrame):
await asyncio.sleep(frame.sleep)
else:
await task.queue_frame(frame)
await task.queue_frame(frame, frames_to_send_direction)
if send_end_frame:
await task.queue_frame(EndFrame())

View File

@@ -22,6 +22,7 @@ from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
LLMMessagesTransformFrame,
LLMMessagesUpdateFrame,
LLMRunFrame,
LLMTextFrame,
@@ -48,6 +49,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMUserAggregator,
LLMUserAggregatorParams,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.tests.utils import SleepFrame, run_test
from pipecat.turns.user_mute import (
FirstSpeechUserMuteStrategy,
@@ -90,9 +92,13 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
]
)
]
expected_down_frames = [
SpeechControlParamsFrame # no LLMContextFrame expected, run_llm defaults to False
]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
assert context.messages[0]["content"] == "Hi there!"
@@ -133,9 +139,13 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
]
)
]
expected_down_frames = [
SpeechControlParamsFrame # no LLMContextFrame expected, run_llm defaults to False
]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
assert context.messages[0]["content"] == "Hi there!"
@@ -180,6 +190,56 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
assert context.messages[0]["content"] == "You are a helpful assistant."
assert context.messages[1]["content"] == "Hello!"
async def test_llm_messages_transform(self):
    """A transform frame without run_llm rewrites the context and expects no LLMContextFrame."""
    # Seed the context with a mixed user/assistant conversation.
    context = LLMContext()
    context.set_messages(
        [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"},
        ]
    )

    def only_user_turns(history):
        # Keep nothing but the messages whose role is "user".
        return [entry for entry in history if entry["role"] == "user"]

    await run_test(
        Pipeline([LLMUserAggregator(context)]),
        frames_to_send=[LLMMessagesTransformFrame(transform=only_user_turns)],
        # run_llm is left unset, so no LLMContextFrame is expected downstream.
        expected_down_frames=[SpeechControlParamsFrame],
    )

    remaining = [message["content"] for message in context.messages]
    assert remaining == ["Hello", "How are you?"]
async def test_llm_messages_transform_run(self):
    """A transform frame with run_llm=True rewrites the context and expects an LLMContextFrame."""
    context = LLMContext()
    # Single seed message whose content the transform will upper-case.
    context.set_messages([{"role": "user", "content": "Hello"}])

    def shout(history):
        # Rebuild every message with its content upper-cased.
        return [
            {"role": entry["role"], "content": entry["content"].upper()} for entry in history
        ]

    await run_test(
        Pipeline([LLMUserAggregator(context)]),
        frames_to_send=[LLMMessagesTransformFrame(transform=shout, run_llm=True)],
        expected_down_frames=[SpeechControlParamsFrame, LLMContextFrame],
    )

    first_message = context.messages[0]
    assert first_message["content"] == "HELLO"
async def test_default_user_turn_strategies(self):
context = LLMContext()
user_aggregator = LLMUserAggregator(
@@ -957,6 +1017,149 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
self.assertEqual(len(stop_messages), 1)
self.assertEqual(stop_messages[0].content, "")
async def test_llm_run(self):
    """Sending an LLMRunFrame upstream is expected to yield an LLMContextFrame upstream."""
    assistant = LLMAssistantAggregator(LLMContext())

    await run_test(
        assistant,
        frames_to_send=[LLMRunFrame()],
        frames_to_send_direction=FrameDirection.UPSTREAM,
        expected_up_frames=[LLMContextFrame],
    )
async def test_llm_messages_append(self):
    """An append frame without run_llm adds the message and expects no upstream frames."""
    context = LLMContext()
    assistant = LLMAssistantAggregator(context)
    append_frame = LLMMessagesAppendFrame(
        messages=[
            {
                "role": "user",
                "content": "Hi there!",
            }
        ]
    )

    await run_test(
        assistant,
        frames_to_send=[append_frame],
        frames_to_send_direction=FrameDirection.UPSTREAM,
        # run_llm is left unset, so no LLMContextFrame is expected.
        expected_up_frames=[],
    )

    first_message = context.messages[0]
    assert first_message["content"] == "Hi there!"
async def test_llm_messages_append_run(self):
    """An append frame with run_llm=True adds the message and expects an LLMContextFrame upstream."""
    context = LLMContext()
    assistant = LLMAssistantAggregator(context)
    append_frame = LLMMessagesAppendFrame(
        messages=[
            {
                "role": "user",
                "content": "Hi there!",
            }
        ],
        run_llm=True,
    )

    await run_test(
        assistant,
        frames_to_send=[append_frame],
        frames_to_send_direction=FrameDirection.UPSTREAM,
        expected_up_frames=[LLMContextFrame],
    )

    first_message = context.messages[0]
    assert first_message["content"] == "Hi there!"
async def test_llm_messages_update(self):
    """An update frame without run_llm sets the messages and expects no upstream frames."""
    context = LLMContext()
    assistant = LLMAssistantAggregator(context)
    update_frame = LLMMessagesUpdateFrame(
        messages=[
            {
                "role": "user",
                "content": "Hi there!",
            }
        ]
    )

    await run_test(
        assistant,
        frames_to_send=[update_frame],
        frames_to_send_direction=FrameDirection.UPSTREAM,
        # run_llm is left unset, so no LLMContextFrame is expected.
        expected_up_frames=[],
    )

    first_message = context.messages[0]
    assert first_message["content"] == "Hi there!"
async def test_llm_messages_update_run(self):
    """An update frame with run_llm=True sets the messages and expects an LLMContextFrame upstream.

    The upstream expectation mirrors the sibling ``append_run`` and
    ``transform_run`` tests. Without it this test would only check the context
    contents and would pass even if the run request were dropped.
    """
    context = LLMContext()
    aggregator = LLMAssistantAggregator(context)
    expected_up_frames = [LLMContextFrame]
    await run_test(
        aggregator,
        frames_to_send=[
            LLMMessagesUpdateFrame(
                messages=[
                    {
                        "role": "user",
                        "content": "Hi there!",
                    }
                ],
                run_llm=True,
            )
        ],
        frames_to_send_direction=FrameDirection.UPSTREAM,
        expected_up_frames=expected_up_frames,
    )
    assert context.messages[0]["content"] == "Hi there!"
async def test_llm_messages_transform(self):
    """A transform frame without run_llm rewrites the context and expects no upstream frames."""
    # Seed the context with a mixed user/assistant conversation.
    context = LLMContext()
    context.set_messages(
        [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"},
        ]
    )
    assistant = LLMAssistantAggregator(context)

    def only_user_turns(history):
        # Keep nothing but the messages whose role is "user".
        return [entry for entry in history if entry["role"] == "user"]

    await run_test(
        assistant,
        frames_to_send=[LLMMessagesTransformFrame(transform=only_user_turns)],
        frames_to_send_direction=FrameDirection.UPSTREAM,
        # run_llm is left unset, so no LLMContextFrame is expected.
        expected_up_frames=[],
    )

    remaining = [message["content"] for message in context.messages]
    assert remaining == ["Hello", "How are you?"]
async def test_llm_messages_transform_run(self):
    """A transform frame with run_llm=True rewrites the context and expects an LLMContextFrame upstream."""
    context = LLMContext()
    # Single seed message whose content the transform will upper-case.
    context.set_messages([{"role": "user", "content": "Hello"}])
    assistant = LLMAssistantAggregator(context)

    def shout(history):
        # Rebuild every message with its content upper-cased.
        return [
            {"role": entry["role"], "content": entry["content"].upper()} for entry in history
        ]

    await run_test(
        assistant,
        frames_to_send=[LLMMessagesTransformFrame(transform=shout, run_llm=True)],
        frames_to_send_direction=FrameDirection.UPSTREAM,
        expected_up_frames=[LLMContextFrame],
    )

    first_message = context.messages[0]
    assert first_message["content"] == "HELLO"
# Run the unittest test runner when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()