Merge pull request #4231 from pipecat-ai/pk/llm-messages-transform-frame
Add a `LLMMessagesTransformFrame` to facilitate programmatically edit…
This commit is contained in:
3
changelog/4231.added.md
Normal file
3
changelog/4231.added.md
Normal file
@@ -0,0 +1,3 @@
|
||||
- Added `LLMMessagesTransformFrame` to facilitate programmatically editing context in a frame-based way.
|
||||
|
||||
The previous approach required the caller to directly grab a reference to the context object, grab a "snapshot" of its messages _at that point in time_, transform the messages, and then push an `LLMMessagesUpdateFrame` with the transformed messages. This approach can lead to problems: what if there had already been a change to the context queued in the pipeline? The transformed messages would simply overwrite it without consideration.
|
||||
@@ -38,7 +38,7 @@ from pipecat.utils.time import nanoseconds_to_str
|
||||
from pipecat.utils.utils import obj_count, obj_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage, NotGiven
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.services.settings import ServiceSettings
|
||||
from pipecat.utils.context.llm_context_summarization import LLMContextSummaryConfig
|
||||
@@ -587,6 +587,23 @@ class LLMMessagesUpdateFrame(DataFrame):
|
||||
run_llm: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMMessagesTransformFrame(DataFrame):
|
||||
"""Frame containing a transform function to modify the current context's LLM messages.
|
||||
|
||||
A frame containing a transform function that takes the context's current list
|
||||
of LLM messages and returns a modified list.
|
||||
|
||||
Parameters:
|
||||
transform: A function that takes a list of messages and returns a
|
||||
modified list.
|
||||
run_llm: Whether the context update should be sent to the LLM.
|
||||
"""
|
||||
|
||||
transform: Callable[[List["LLMContextMessage"]], List["LLMContextMessage"]]
|
||||
run_llm: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMSetToolsFrame(DataFrame):
|
||||
"""Frame containing tools for LLM function calling.
|
||||
|
||||
@@ -19,7 +19,7 @@ import base64
|
||||
import io
|
||||
import wave
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, TypeAlias, Union
|
||||
from typing import Any, Callable, List, Optional, TypeAlias, Union
|
||||
|
||||
from loguru import logger
|
||||
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
|
||||
@@ -266,6 +266,17 @@ class LLMContext:
|
||||
"""
|
||||
self._messages[:] = messages
|
||||
|
||||
def transform_messages(
|
||||
self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]]
|
||||
):
|
||||
"""Transform the current messages using the provided function.
|
||||
|
||||
Args:
|
||||
transform: A function that takes the current list of messages and returns
|
||||
a modified list of messages to set in the context.
|
||||
"""
|
||||
self.set_messages(transform(self._messages))
|
||||
|
||||
def set_tools(self, tools: ToolsSchema | NotGiven = NOT_GIVEN):
|
||||
"""Set the available tools for the LLM.
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ import json
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Type
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Set, Type
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -42,6 +42,7 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesTransformFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMRunFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
@@ -315,6 +316,17 @@ class LLMContextAggregator(FrameProcessor):
|
||||
"""
|
||||
self._context.set_messages(messages)
|
||||
|
||||
def transform_messages(
|
||||
self, transform: Callable[[List[LLMContextMessage]], List[LLMContextMessage]]
|
||||
):
|
||||
"""Transform the context messages using a provided function.
|
||||
|
||||
Args:
|
||||
transform: A function that takes the current list of messages and returns
|
||||
a modified list of messages to set in the context.
|
||||
"""
|
||||
self._context.transform_messages(transform)
|
||||
|
||||
def set_tools(self, tools: ToolsSchema | NotGiven):
|
||||
"""Set tools in the context.
|
||||
|
||||
@@ -512,6 +524,8 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
await self._handle_llm_messages_append(frame)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self._handle_llm_messages_update(frame)
|
||||
elif isinstance(frame, LLMMessagesTransformFrame):
|
||||
await self._handle_llm_messages_transform(frame)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self.set_tools(frame.tools)
|
||||
# Push the LLMSetToolsFrame as well, since speech-to-speech LLM
|
||||
@@ -635,6 +649,11 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame()
|
||||
|
||||
async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame):
|
||||
self.transform_messages(frame.transform)
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame()
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
text = frame.text
|
||||
|
||||
@@ -923,6 +942,8 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
await self._handle_llm_messages_append(frame)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self._handle_llm_messages_update(frame)
|
||||
elif isinstance(frame, LLMMessagesTransformFrame):
|
||||
await self._handle_llm_messages_transform(frame)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self.set_tools(frame.tools)
|
||||
elif isinstance(frame, LLMSetToolChoiceFrame):
|
||||
@@ -982,6 +1003,11 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame):
|
||||
self.transform_messages(frame.transform)
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_interruptions(self, frame: InterruptionFrame):
|
||||
await self._trigger_assistant_turn_stopped()
|
||||
await self.reset()
|
||||
|
||||
@@ -127,6 +127,7 @@ async def run_test(
|
||||
expected_down_frames: Optional[Sequence[type]] = None,
|
||||
expected_up_frames: Optional[Sequence[type]] = None,
|
||||
frames_to_send: Sequence[Frame],
|
||||
frames_to_send_direction: FrameDirection = FrameDirection.DOWNSTREAM,
|
||||
ignore_start: bool = True,
|
||||
observers: Optional[List[BaseObserver]] = None,
|
||||
pipeline_params: Optional[PipelineParams] = None,
|
||||
@@ -144,6 +145,9 @@ async def run_test(
|
||||
expected_down_frames: Expected frame types flowing downstream (optional).
|
||||
expected_up_frames: Expected frame types flowing upstream (optional).
|
||||
frames_to_send: Sequence of frames to send through the processor.
|
||||
frames_to_send_direction: Direction to send frames_to_send. Downstream
|
||||
frames are pushed from the beginning of the pipeline, upstream frames
|
||||
from the end. Defaults to DOWNSTREAM.
|
||||
ignore_start: Whether to ignore StartFrames in frame validation.
|
||||
observers: Optional list of observers to attach to the pipeline.
|
||||
pipeline_params: Optional pipeline parameters.
|
||||
@@ -188,7 +192,7 @@ async def run_test(
|
||||
if isinstance(frame, SleepFrame):
|
||||
await asyncio.sleep(frame.sleep)
|
||||
else:
|
||||
await task.queue_frame(frame)
|
||||
await task.queue_frame(frame, frames_to_send_direction)
|
||||
|
||||
if send_end_frame:
|
||||
await task.queue_frame(EndFrame())
|
||||
|
||||
@@ -22,6 +22,7 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesTransformFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMRunFrame,
|
||||
LLMTextFrame,
|
||||
@@ -48,6 +49,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMUserAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
from pipecat.turns.user_mute import (
|
||||
FirstSpeechUserMuteStrategy,
|
||||
@@ -90,9 +92,13 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
)
|
||||
]
|
||||
expected_down_frames = [
|
||||
SpeechControlParamsFrame # no LLMContextFrame expected, run_llm defaults to False
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hi there!"
|
||||
|
||||
@@ -133,9 +139,13 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
)
|
||||
]
|
||||
expected_down_frames = [
|
||||
SpeechControlParamsFrame # no LLMContextFrame expected, run_llm defaults to False
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hi there!"
|
||||
|
||||
@@ -180,6 +190,56 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
assert context.messages[0]["content"] == "You are a helpful assistant."
|
||||
assert context.messages[1]["content"] == "Hello!"
|
||||
|
||||
async def test_llm_messages_transform(self):
|
||||
context = LLMContext()
|
||||
# Set up initial messages
|
||||
context.set_messages(
|
||||
[
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
)
|
||||
|
||||
pipeline = Pipeline([LLMUserAggregator(context)])
|
||||
|
||||
# Transform that keeps only user messages
|
||||
def keep_user_messages(messages):
|
||||
return [m for m in messages if m["role"] == "user"]
|
||||
|
||||
frames_to_send = [LLMMessagesTransformFrame(transform=keep_user_messages)]
|
||||
expected_down_frames = [
|
||||
SpeechControlParamsFrame # no LLMContextFrame expected, run_llm defaults to False
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
assert len(context.messages) == 2
|
||||
assert context.messages[0]["content"] == "Hello"
|
||||
assert context.messages[1]["content"] == "How are you?"
|
||||
|
||||
async def test_llm_messages_transform_run(self):
|
||||
context = LLMContext()
|
||||
# Set up initial messages
|
||||
context.set_messages([{"role": "user", "content": "Hello"}])
|
||||
|
||||
pipeline = Pipeline([LLMUserAggregator(context)])
|
||||
|
||||
# Transform that modifies the content
|
||||
def uppercase_content(messages):
|
||||
return [{"role": m["role"], "content": m["content"].upper()} for m in messages]
|
||||
|
||||
frames_to_send = [LLMMessagesTransformFrame(transform=uppercase_content, run_llm=True)]
|
||||
expected_down_frames = [SpeechControlParamsFrame, LLMContextFrame]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
assert context.messages[0]["content"] == "HELLO"
|
||||
|
||||
async def test_default_user_turn_strategies(self):
|
||||
context = LLMContext()
|
||||
user_aggregator = LLMUserAggregator(
|
||||
@@ -957,6 +1017,149 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(len(stop_messages), 1)
|
||||
self.assertEqual(stop_messages[0].content, "")
|
||||
|
||||
async def test_llm_run(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
expected_up_frames = [LLMContextFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[LLMRunFrame()],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
|
||||
async def test_llm_messages_append(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[
|
||||
LLMMessagesAppendFrame(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi there!",
|
||||
}
|
||||
]
|
||||
)
|
||||
],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
expected_up_frames=[], # no LLMContextFrame expected, run_llm defaults to False
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hi there!"
|
||||
|
||||
async def test_llm_messages_append_run(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
expected_up_frames = [LLMContextFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[
|
||||
LLMMessagesAppendFrame(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi there!",
|
||||
}
|
||||
],
|
||||
run_llm=True,
|
||||
)
|
||||
],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hi there!"
|
||||
|
||||
async def test_llm_messages_update(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[
|
||||
LLMMessagesUpdateFrame(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi there!",
|
||||
}
|
||||
]
|
||||
)
|
||||
],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
expected_up_frames=[], # no LLMContextFrame expected, run_llm defaults to False
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hi there!"
|
||||
|
||||
async def test_llm_messages_update_run(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[
|
||||
LLMMessagesUpdateFrame(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi there!",
|
||||
}
|
||||
],
|
||||
run_llm=True,
|
||||
)
|
||||
],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hi there!"
|
||||
|
||||
async def test_llm_messages_transform(self):
|
||||
context = LLMContext()
|
||||
context.set_messages(
|
||||
[
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
)
|
||||
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
# Transform that keeps only user messages
|
||||
def keep_user_messages(messages):
|
||||
return [m for m in messages if m["role"] == "user"]
|
||||
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[LLMMessagesTransformFrame(transform=keep_user_messages)],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
expected_up_frames=[], # no LLMContextFrame expected, run_llm defaults to False
|
||||
)
|
||||
assert len(context.messages) == 2
|
||||
assert context.messages[0]["content"] == "Hello"
|
||||
assert context.messages[1]["content"] == "How are you?"
|
||||
|
||||
async def test_llm_messages_transform_run(self):
|
||||
context = LLMContext()
|
||||
context.set_messages([{"role": "user", "content": "Hello"}])
|
||||
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
# Transform that modifies the content
|
||||
def uppercase_content(messages):
|
||||
return [{"role": m["role"], "content": m["content"].upper()} for m in messages]
|
||||
|
||||
expected_up_frames = [LLMContextFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[LLMMessagesTransformFrame(transform=uppercase_content, run_llm=True)],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
assert context.messages[0]["content"] == "HELLO"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user