add new LLMFullResponseAggregator
This commit is contained in:
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added new `LLMFullResponseAggregator` to aggregate full LLM completions. The
  `on_completion` event handler is triggered each time a completion ends or is
  interrupted.
|
||||
|
||||
- Added a new frame, `RTVIServerMessageFrame`, and RTVI message
|
||||
`RTVIServerMessage` which provides a generic mechanism for sending custom
|
||||
messages from server to client. The `RTVIServerMessageFrame` is processed by
|
||||
|
||||
@@ -22,6 +22,7 @@ from pipecat.frames.frames import (
|
||||
LLMMessagesFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
@@ -36,6 +37,59 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class LLMFullResponseAggregator(FrameProcessor):
    """Collect the text of a whole LLM completion and report it via an event.

    Once an `LLMFullResponseStartFrame` has been seen, the text of every
    `LLMTextFrame` is appended to an internal buffer. The buffer is handed to
    the "on_completion" event handler when either of these frames arrives:

    - `LLMFullResponseEndFrame`: reported with `completed=True`.
    - `StartInterruptionFrame`: reported with `completed=False`.

    After the handler has been called the buffer is cleared and text is ignored
    until the next start frame. Every frame, handled or not, is pushed on in
    the direction it arrived.

    Example:

        @aggregator.event_handler("on_completion")
        async def on_completion(
            aggregator: LLMFullResponseAggregator,
            completion: str,
            completed: bool,
        ):
            ...

    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Text gathered for the completion currently in progress.
        self._aggregation = ""
        # True only between a response start frame and the next end/interruption.
        self._started = False

        self._register_event_handler("on_completion")

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Update the aggregation state for `frame`, then forward the frame."""
        await super().process_frame(frame, direction)

        # Checked in order; only the first matching entry is applied.
        dispatch = (
            (StartInterruptionFrame, self._handle_interruption),
            (LLMFullResponseStartFrame, self._handle_llm_start),
            (LLMFullResponseEndFrame, self._handle_llm_end),
            (LLMTextFrame, self._handle_llm_text),
        )
        for frame_type, handler in dispatch:
            if isinstance(frame, frame_type):
                await handler(frame)
                break

        await self.push_frame(frame, direction)

    async def _emit_completion(self, completed: bool):
        """Fire "on_completion" with the buffered text, then reset the state."""
        await self._call_event_handler("on_completion", self._aggregation, completed)
        self._aggregation = ""
        self._started = False

    async def _handle_interruption(self, _: StartInterruptionFrame):
        # Whatever was gathered so far is reported as an unfinished completion.
        await self._emit_completion(False)

    async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
        self._started = True

    async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
        await self._emit_completion(True)

    async def _handle_llm_text(self, frame: TextFrame):
        # Text arriving outside of a started response is not aggregated.
        if self._started:
            self._aggregation += frame.text
||||
|
||||
|
||||
class BaseLLMResponseAggregator(FrameProcessor):
|
||||
"""This is the base class for all LLM response aggregators. These
|
||||
aggregators process incoming frames and aggregate content until they are
|
||||
|
||||
136
tests/test_llm_response.py
Normal file
136
tests/test_llm_response.py
Normal file
@@ -0,0 +1,136 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartInterruptionFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import LLMFullResponseAggregator
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
|
||||
|
||||
class TestLLMFullResponseAggregator(unittest.IsolatedAsyncioTestCase):
    """Tests for the "on_completion" event of `LLMFullResponseAggregator`.

    Each test records every `(completion, completed)` pair the handler receives
    and compares the full sequence with the expected one. A handler that never
    fires, fires too often, or fires in the wrong order therefore fails the
    test instead of going unnoticed.
    """

    async def _collect_completions(self, frames_to_send, expected_down_frames):
        """Run frames through a fresh aggregator and return the reported events.

        Args:
            frames_to_send: Frame instances passed to `run_test`.
            expected_down_frames: Frame types passed to `run_test` as the
                expected downstream frames.

        Returns:
            A list of `(completion, completed)` tuples, one per invocation of
            the "on_completion" handler, in invocation order.
        """
        completions = []

        aggregator = LLMFullResponseAggregator()

        @aggregator.event_handler("on_completion")
        async def on_completion(aggregator, completion, completed):
            completions.append((completion, completed))

        await run_test(
            aggregator,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )
        return completions

    async def test_empty(self):
        """A response with no text frames reports one empty, completed result."""
        completions = await self._collect_completions(
            frames_to_send=[LLMFullResponseStartFrame(), LLMFullResponseEndFrame()],
            expected_down_frames=[LLMFullResponseStartFrame, LLMFullResponseEndFrame],
        )
        self.assertEqual(completions, [("", True)])

    async def test_simple(self):
        """A single text frame is reported verbatim as a completed result."""
        completions = await self._collect_completions(
            frames_to_send=[
                LLMFullResponseStartFrame(),
                LLMTextFrame("Hello from Pipecat!"),
                LLMFullResponseEndFrame(),
            ],
            expected_down_frames=[
                LLMFullResponseStartFrame,
                LLMTextFrame,
                LLMFullResponseEndFrame,
            ],
        )
        self.assertEqual(completions, [("Hello from Pipecat!", True)])

    async def test_multiple(self):
        """Several text frames are concatenated into one completed result."""
        completions = await self._collect_completions(
            frames_to_send=[
                LLMFullResponseStartFrame(),
                LLMTextFrame("Hello "),
                LLMTextFrame("from "),
                LLMTextFrame("Pipecat!"),
                LLMFullResponseEndFrame(),
            ],
            expected_down_frames=[
                LLMFullResponseStartFrame,
                LLMTextFrame,
                LLMTextFrame,
                LLMTextFrame,
                LLMFullResponseEndFrame,
            ],
        )
        self.assertEqual(completions, [("Hello from Pipecat!", True)])

    async def test_interruption(self):
        """An interruption reports the partial text, then a new response starts clean."""
        completions = await self._collect_completions(
            frames_to_send=[
                LLMFullResponseStartFrame(),
                LLMTextFrame("Hello "),
                # NOTE(review): presumably lets the frames above be processed
                # before the interruption is sent; confirm against
                # pipecat.tests.utils.
                SleepFrame(),
                StartInterruptionFrame(),
                LLMFullResponseStartFrame(),
                LLMTextFrame("Hello "),
                LLMTextFrame("there!"),
                LLMFullResponseEndFrame(),
            ],
            expected_down_frames=[
                LLMFullResponseStartFrame,
                LLMTextFrame,
                StartInterruptionFrame,
                LLMFullResponseStartFrame,
                LLMTextFrame,
                LLMTextFrame,
                LLMFullResponseEndFrame,
            ],
        )
        # Exactly two events: the interrupted partial, then the full completion.
        self.assertEqual(completions, [("Hello ", False), ("Hello there!", True)])
|
||||
Reference in New Issue
Block a user