Test LLMAssistantAggregator handling of upstream message frames
Add tests for LLMRunFrame, LLMMessagesAppendFrame, LLMMessagesUpdateFrame, and LLMMessagesTransformFrame sent upstream to LLMAssistantAggregator, mirroring the existing LLMUserAggregator downstream tests. Add frames_to_send_direction param to run_test helper to support this.
This commit is contained in:
@@ -127,6 +127,7 @@ async def run_test(
|
||||
expected_down_frames: Optional[Sequence[type]] = None,
|
||||
expected_up_frames: Optional[Sequence[type]] = None,
|
||||
frames_to_send: Sequence[Frame],
|
||||
frames_to_send_direction: FrameDirection = FrameDirection.DOWNSTREAM,
|
||||
ignore_start: bool = True,
|
||||
observers: Optional[List[BaseObserver]] = None,
|
||||
pipeline_params: Optional[PipelineParams] = None,
|
||||
@@ -144,6 +145,9 @@ async def run_test(
|
||||
expected_down_frames: Expected frame types flowing downstream (optional).
|
||||
expected_up_frames: Expected frame types flowing upstream (optional).
|
||||
frames_to_send: Sequence of frames to send through the processor.
|
||||
frames_to_send_direction: Direction to send frames_to_send. Downstream
|
||||
frames are pushed from the beginning of the pipeline, upstream frames
|
||||
from the end. Defaults to DOWNSTREAM.
|
||||
ignore_start: Whether to ignore StartFrames in frame validation.
|
||||
observers: Optional list of observers to attach to the pipeline.
|
||||
pipeline_params: Optional pipeline parameters.
|
||||
@@ -188,7 +192,7 @@ async def run_test(
|
||||
if isinstance(frame, SleepFrame):
|
||||
await asyncio.sleep(frame.sleep)
|
||||
else:
|
||||
await task.queue_frame(frame)
|
||||
await task.queue_frame(frame, frames_to_send_direction)
|
||||
|
||||
if send_end_frame:
|
||||
await task.queue_frame(EndFrame())
|
||||
|
||||
@@ -49,6 +49,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMUserAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
from pipecat.turns.user_mute import (
|
||||
FirstSpeechUserMuteStrategy,
|
||||
@@ -1008,6 +1009,146 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(len(stop_messages), 1)
|
||||
self.assertEqual(stop_messages[0].content, "")
|
||||
|
||||
async def test_llm_run(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
expected_up_frames = [LLMContextFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[LLMRunFrame()],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
|
||||
async def test_llm_messages_append(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[
|
||||
LLMMessagesAppendFrame(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi there!",
|
||||
}
|
||||
]
|
||||
)
|
||||
],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hi there!"
|
||||
|
||||
async def test_llm_messages_append_run(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
expected_up_frames = [LLMContextFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[
|
||||
LLMMessagesAppendFrame(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi there!",
|
||||
}
|
||||
],
|
||||
run_llm=True,
|
||||
)
|
||||
],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hi there!"
|
||||
|
||||
async def test_llm_messages_update(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[
|
||||
LLMMessagesUpdateFrame(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi there!",
|
||||
}
|
||||
]
|
||||
)
|
||||
],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hi there!"
|
||||
|
||||
async def test_llm_messages_update_run(self):
|
||||
context = LLMContext()
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[
|
||||
LLMMessagesUpdateFrame(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi there!",
|
||||
}
|
||||
],
|
||||
run_llm=True,
|
||||
)
|
||||
],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hi there!"
|
||||
|
||||
async def test_llm_messages_transform(self):
|
||||
context = LLMContext()
|
||||
context.set_messages(
|
||||
[
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
)
|
||||
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
# Transform that keeps only user messages
|
||||
def keep_user_messages(messages):
|
||||
return [m for m in messages if m["role"] == "user"]
|
||||
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[LLMMessagesTransformFrame(transform=keep_user_messages)],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
)
|
||||
assert len(context.messages) == 2
|
||||
assert context.messages[0]["content"] == "Hello"
|
||||
assert context.messages[1]["content"] == "How are you?"
|
||||
|
||||
async def test_llm_messages_transform_run(self):
|
||||
context = LLMContext()
|
||||
context.set_messages([{"role": "user", "content": "Hello"}])
|
||||
|
||||
aggregator = LLMAssistantAggregator(context)
|
||||
|
||||
# Transform that modifies the content
|
||||
def uppercase_content(messages):
|
||||
return [{"role": m["role"], "content": m["content"].upper()} for m in messages]
|
||||
|
||||
expected_up_frames = [LLMContextFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=[LLMMessagesTransformFrame(transform=uppercase_content, run_llm=True)],
|
||||
frames_to_send_direction=FrameDirection.UPSTREAM,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
assert context.messages[0]["content"] == "HELLO"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user