add new LLMFullResponseAggregator

This commit is contained in:
Aleix Conchillo Flaqué
2025-02-27 15:23:16 -08:00
parent 1c92fab1fb
commit 8db9d16174
3 changed files with 193 additions and 0 deletions

View File

@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added new `LLMFullResponseAggregator` to aggregate full LLM completions. At
every completion the `on_completion` event handler is triggered.
- Added a new frame, `RTVIServerMessageFrame`, and RTVI message
`RTVIServerMessage` which provides a generic mechanism for sending custom
messages from server to client. The `RTVIServerMessageFrame` is processed by

View File

@@ -22,6 +22,7 @@ from pipecat.frames.frames import (
LLMMessagesFrame,
LLMMessagesUpdateFrame,
LLMSetToolsFrame,
LLMTextFrame,
StartFrame,
StartInterruptionFrame,
TextFrame,
@@ -36,6 +37,59 @@ from pipecat.processors.aggregators.openai_llm_context import (
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class LLMFullResponseAggregator(FrameProcessor):
"""This is an LLM aggregator that aggregates a full LLM completion. It
aggregates LLM text frames (tokens) received between
`LLMFullResponseStartFrame` and `LLMFullResponseEndFrame`. Every full
completion is returned via the "on_completion" event handler:
@aggregator.event_handler("on_completion")
async def on_completion(
aggregator: LLMFullResponseAggregator,
completion: str,
completed: bool,
)
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._aggregation = ""
self._started = False
self._register_event_handler("on_completion")
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, StartInterruptionFrame):
await self._call_event_handler("on_completion", self._aggregation, False)
self._aggregation = ""
self._started = False
elif isinstance(frame, LLMFullResponseStartFrame):
await self._handle_llm_start(frame)
elif isinstance(frame, LLMFullResponseEndFrame):
await self._handle_llm_end(frame)
elif isinstance(frame, LLMTextFrame):
await self._handle_llm_text(frame)
await self.push_frame(frame, direction)
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
self._started = True
async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
await self._call_event_handler("on_completion", self._aggregation, True)
self._started = False
self._aggregation = ""
async def _handle_llm_text(self, frame: TextFrame):
if not self._started:
return
self._aggregation += frame.text
class BaseLLMResponseAggregator(FrameProcessor):
"""This is the base class for all LLM response aggregators. These
aggregators process incoming frames and aggregate content until they are

136
tests/test_llm_response.py Normal file
View File

@@ -0,0 +1,136 @@
#
# Copyright (c) 2024-2025 Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
StartInterruptionFrame,
)
from pipecat.processors.aggregators.llm_response import LLMFullResponseAggregator
from pipecat.tests.utils import SleepFrame, run_test
class TestLLMFullResponseAggregator(unittest.IsolatedAsyncioTestCase):
async def test_empty(self):
completion_ok = False
aggregator = LLMFullResponseAggregator()
@aggregator.event_handler("on_completion")
async def on_completion(aggregator, completion, completed):
nonlocal completion_ok
completion_ok = completion == "" and completed
frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()]
expected_down_frames = [LLMFullResponseStartFrame, LLMFullResponseEndFrame]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
assert completion_ok
async def test_simple(self):
completion_ok = False
aggregator = LLMFullResponseAggregator()
@aggregator.event_handler("on_completion")
async def on_completion(aggregator, completion, completed):
nonlocal completion_ok
completion_ok = completion == "Hello from Pipecat!" and completed
frames_to_send = [
LLMFullResponseStartFrame(),
LLMTextFrame("Hello from Pipecat!"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [LLMFullResponseStartFrame, LLMTextFrame, LLMFullResponseEndFrame]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
assert completion_ok
async def test_multiple(self):
completion_ok = False
aggregator = LLMFullResponseAggregator()
@aggregator.event_handler("on_completion")
async def on_completion(aggregator, completion, completed):
nonlocal completion_ok
completion_ok = completion == "Hello from Pipecat!" and completed
frames_to_send = [
LLMFullResponseStartFrame(),
LLMTextFrame("Hello "),
LLMTextFrame("from "),
LLMTextFrame("Pipecat!"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [
LLMFullResponseStartFrame,
LLMTextFrame,
LLMTextFrame,
LLMTextFrame,
LLMFullResponseEndFrame,
]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
assert completion_ok
async def test_interruption(self):
completion_ok = True
completion_result = [("Hello ", False), ("Hello there!", True)]
completion_index = 0
aggregator = LLMFullResponseAggregator()
@aggregator.event_handler("on_completion")
async def on_completion(aggregator, completion, completed):
nonlocal completion_result, completion_index, completion_ok
(completion_expected, completion_completed) = completion_result[completion_index]
completion_ok = (
completion_ok
and completion == completion_expected
and completed == completion_completed
)
completion_index += 1
frames_to_send = [
LLMFullResponseStartFrame(),
LLMTextFrame("Hello "),
SleepFrame(),
StartInterruptionFrame(),
LLMFullResponseStartFrame(),
LLMTextFrame("Hello "),
LLMTextFrame("there!"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [
LLMFullResponseStartFrame,
LLMTextFrame,
StartInterruptionFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
LLMTextFrame,
LLMFullResponseEndFrame,
]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
assert completion_ok