Add an LLMMessagesTransformFrame to facilitate programmatically editing context in a frame-based way.
The previous approach required the caller to directly grab a reference to the context object, grab a "snapshot" of its messages *at that point in time*, transform the messages, and then push an `LLMMessagesUpdateFrame` with the transformed messages. This approach can lead to problems: what if there had already been a change to the context queued in the pipeline? The transformed messages would simply overwrite it without consideration.
This commit is contained in:
3
changelog/4231.added.md
Normal file
3
changelog/4231.added.md
Normal file
@@ -0,0 +1,3 @@
|
||||
- Added `LLMMessagesTransformFrame` to facilitate programmatically editing context in a frame-based way.
|
||||
|
||||
The previous approach required the caller to directly grab a reference to the context object, grab a "snapshot" of its messages _at that point in time_, transform the messages, and then push an `LLMMessagesUpdateFrame` with the transformed messages. This approach can lead to problems: what if there had already been a change to the context queued in the pipeline? The transformed messages would simply overwrite it without consideration.
|
||||
@@ -38,7 +38,7 @@ from pipecat.utils.time import nanoseconds_to_str
|
||||
from pipecat.utils.utils import obj_count, obj_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage, NotGiven
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.services.settings import ServiceSettings
|
||||
from pipecat.utils.context.llm_context_summarization import LLMContextSummaryConfig
|
||||
@@ -587,6 +587,25 @@ class LLMMessagesUpdateFrame(DataFrame):
|
||||
run_llm: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
class LLMMessagesTransformFrame(DataFrame):
    """Frame carrying a function that rewrites the current context's LLM messages.

    Rather than shipping a replacement message list, this frame ships a
    callable. The callable is given the context's current list of LLM messages
    and returns the list that should take its place.

    Only compatible with LLMContext and not the deprecated OpenAILLMContext.

    Parameters:
        transform: A function that takes a list of messages and returns a
            modified list.
        run_llm: Whether the context update should be sent to the LLM.
            Defaults to None.
    """

    transform: Callable[[List["LLMContextMessage"]], List["LLMContextMessage"]]
    run_llm: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMSetToolsFrame(DataFrame):
|
||||
"""Frame containing tools for LLM function calling.
|
||||
|
||||
@@ -19,7 +19,7 @@ import base64
|
||||
import io
|
||||
import wave
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, TypeAlias, Union
|
||||
from typing import Any, Callable, List, Optional, TypeAlias, Union
|
||||
|
||||
from loguru import logger
|
||||
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
|
||||
@@ -266,6 +266,17 @@ class LLMContext:
|
||||
"""
|
||||
self._messages[:] = messages
|
||||
|
||||
def transform_messages(
    self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]]
):
    """Replace the context's messages with the output of a transform function.

    The function is called with the context's current message list, and
    whatever it returns is stored through ``set_messages``.

    Args:
        transform: A function that takes the current list of messages and returns
            a modified list of messages to set in the context.
    """
    updated_messages = transform(self._messages)
    self.set_messages(updated_messages)
|
||||
|
||||
def set_tools(self, tools: ToolsSchema | NotGiven = NOT_GIVEN):
|
||||
"""Set the available tools for the LLM.
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ import json
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Type
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Set, Type
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -42,6 +42,7 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesTransformFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMRunFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
@@ -315,6 +316,17 @@ class LLMContextAggregator(FrameProcessor):
|
||||
"""
|
||||
self._context.set_messages(messages)
|
||||
|
||||
def transform_messages(
    self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]]
):
    """Apply a transform function to the messages held by this aggregator's context.

    Delegates to the context's own ``transform_messages``.

    Args:
        transform: A function that takes the current list of messages and returns
            a modified list of messages to set in the context.
    """
    context = self._context
    context.transform_messages(transform)
|
||||
|
||||
def set_tools(self, tools: ToolsSchema | NotGiven):
|
||||
"""Set tools in the context.
|
||||
|
||||
@@ -512,6 +524,8 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
await self._handle_llm_messages_append(frame)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self._handle_llm_messages_update(frame)
|
||||
elif isinstance(frame, LLMMessagesTransformFrame):
|
||||
await self._handle_llm_messages_transform(frame)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self.set_tools(frame.tools)
|
||||
# Push the LLMSetToolsFrame as well, since speech-to-speech LLM
|
||||
@@ -635,6 +649,11 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame()
|
||||
|
||||
async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame):
    """Apply the frame's transform to the context, then optionally run the LLM.

    Args:
        frame: The transform frame. Its ``transform`` is applied to the context
            messages; a context frame is pushed only when ``run_llm`` is truthy.
    """
    self.transform_messages(frame.transform)
    if not frame.run_llm:
        return
    await self.push_context_frame()
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
text = frame.text
|
||||
|
||||
@@ -923,6 +942,8 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
await self._handle_llm_messages_append(frame)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self._handle_llm_messages_update(frame)
|
||||
elif isinstance(frame, LLMMessagesTransformFrame):
|
||||
await self._handle_llm_messages_transform(frame)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self.set_tools(frame.tools)
|
||||
elif isinstance(frame, LLMSetToolChoiceFrame):
|
||||
@@ -982,6 +1003,11 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame):
    """Apply the frame's transform to the context, then optionally run the LLM.

    Args:
        frame: The transform frame. Its ``transform`` is applied to the context
            messages; when ``run_llm`` is truthy a context frame is pushed
            upstream.
    """
    self.transform_messages(frame.transform)
    if not frame.run_llm:
        return
    await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_interruptions(self, frame: InterruptionFrame):
|
||||
await self._trigger_assistant_turn_stopped()
|
||||
await self.reset()
|
||||
|
||||
@@ -22,6 +22,7 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesTransformFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMRunFrame,
|
||||
LLMTextFrame,
|
||||
@@ -180,6 +181,56 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
assert context.messages[0]["content"] == "You are a helpful assistant."
|
||||
assert context.messages[1]["content"] == "Hello!"
|
||||
|
||||
async def test_llm_messages_transform(self):
    """A transform frame rewrites the context messages without running the LLM."""
    initial_messages = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "How are you?"},
    ]
    context = LLMContext()
    context.set_messages(initial_messages)

    pipeline = Pipeline([LLMUserAggregator(context)])

    def keep_user_messages(messages):
        # Drop every message whose role is not "user".
        return [m for m in messages if m["role"] == "user"]

    # run_llm is left unset on the frame (it defaults to None), so no
    # LLMContextFrame is expected downstream.
    await run_test(
        pipeline,
        frames_to_send=[LLMMessagesTransformFrame(transform=keep_user_messages)],
        expected_down_frames=[SpeechControlParamsFrame],
    )

    remaining_contents = [message["content"] for message in context.messages]
    assert remaining_contents == ["Hello", "How are you?"]
|
||||
|
||||
async def test_llm_messages_transform_run(self):
    """A transform frame with run_llm=True rewrites the context and pushes it downstream."""
    context = LLMContext()
    context.set_messages([{"role": "user", "content": "Hello"}])

    pipeline = Pipeline([LLMUserAggregator(context)])

    def uppercase_content(messages):
        # Rebuild each message with the same role and upper-cased content.
        return [
            {"role": message["role"], "content": message["content"].upper()}
            for message in messages
        ]

    transform_frame = LLMMessagesTransformFrame(transform=uppercase_content, run_llm=True)
    await run_test(
        pipeline,
        frames_to_send=[transform_frame],
        expected_down_frames=[SpeechControlParamsFrame, LLMContextFrame],
    )

    assert context.messages[0]["content"] == "HELLO"
|
||||
|
||||
async def test_default_user_turn_strategies(self):
|
||||
context = LLMContext()
|
||||
user_aggregator = LLMUserAggregator(
|
||||
|
||||
Reference in New Issue
Block a user