class LLMFullResponseAggregator(FrameProcessor):
    """Aggregate the text of a full LLM completion.

    LLM text frames (tokens) received between `LLMFullResponseStartFrame` and
    `LLMFullResponseEndFrame` are concatenated. Every aggregated completion is
    delivered through the "on_completion" event handler:

        @aggregator.event_handler("on_completion")
        async def on_completion(
            aggregator: LLMFullResponseAggregator,
            completion: str,
            completed: bool,
        ):
            ...

    `completed` is True when the handler fires because an
    `LLMFullResponseEndFrame` arrived, and False when it fires because a
    `StartInterruptionFrame` arrived; in the latter case `completion` holds
    whatever text had been aggregated so far (possibly an empty string).

    Every frame is pushed on unchanged, in the direction it arrived.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Text collected for the response currently in progress.
        self._aggregation = ""
        # True only between a response start frame and its end/interruption.
        self._started = False

        self._register_event_handler("on_completion")

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Update the aggregation state for `frame`, then forward the frame."""
        await super().process_frame(frame, direction)

        if isinstance(frame, StartInterruptionFrame):
            await self._handle_interruption(frame)
        elif isinstance(frame, LLMFullResponseStartFrame):
            await self._handle_llm_start(frame)
        elif isinstance(frame, LLMFullResponseEndFrame):
            await self._handle_llm_end(frame)
        elif isinstance(frame, LLMTextFrame):
            await self._handle_llm_text(frame)

        await self.push_frame(frame, direction)

    async def _handle_interruption(self, _: StartInterruptionFrame):
        # Report the partial text; the handler is told it did not complete.
        await self._emit_completion(completed=False)

    async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
        self._started = True

    async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
        await self._emit_completion(completed=True)

    async def _handle_llm_text(self, frame: LLMTextFrame):
        # Text that arrives outside a start/end pair is not aggregated.
        if self._started:
            self._aggregation += frame.text

    async def _emit_completion(self, completed: bool):
        """Fire "on_completion" with the current text, then reset the state."""
        await self._call_event_handler("on_completion", self._aggregation, completed)
        self._aggregation = ""
        self._started = False
class TestLLMFullResponseAggregator(unittest.IsolatedAsyncioTestCase):
    """Tests for the "on_completion" events of `LLMFullResponseAggregator`.

    Each test records every `(completion, completed)` pair passed to the
    handler and compares the full list with the expected one, so a handler
    that never fires, fires too often, or fires with the wrong arguments
    fails the test with a readable diff.
    """

    async def _run_and_collect(self, frames_to_send, expected_down_frames):
        """Run the frames through a fresh aggregator and return its events.

        Returns the list of `(completion, completed)` tuples, in the order the
        "on_completion" handler received them.
        """
        completions = []

        aggregator = LLMFullResponseAggregator()

        @aggregator.event_handler("on_completion")
        async def on_completion(aggregator, completion, completed):
            completions.append((completion, completed))

        await run_test(
            aggregator,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )
        return completions

    async def test_empty(self):
        frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()]
        expected_down_frames = [LLMFullResponseStartFrame, LLMFullResponseEndFrame]
        completions = await self._run_and_collect(frames_to_send, expected_down_frames)
        self.assertEqual(completions, [("", True)])

    async def test_simple(self):
        frames_to_send = [
            LLMFullResponseStartFrame(),
            LLMTextFrame("Hello from Pipecat!"),
            LLMFullResponseEndFrame(),
        ]
        expected_down_frames = [LLMFullResponseStartFrame, LLMTextFrame, LLMFullResponseEndFrame]
        completions = await self._run_and_collect(frames_to_send, expected_down_frames)
        self.assertEqual(completions, [("Hello from Pipecat!", True)])

    async def test_multiple(self):
        frames_to_send = [
            LLMFullResponseStartFrame(),
            LLMTextFrame("Hello "),
            LLMTextFrame("from "),
            LLMTextFrame("Pipecat!"),
            LLMFullResponseEndFrame(),
        ]
        expected_down_frames = [
            LLMFullResponseStartFrame,
            LLMTextFrame,
            LLMTextFrame,
            LLMTextFrame,
            LLMFullResponseEndFrame,
        ]
        completions = await self._run_and_collect(frames_to_send, expected_down_frames)
        self.assertEqual(completions, [("Hello from Pipecat!", True)])

    async def test_interruption(self):
        frames_to_send = [
            LLMFullResponseStartFrame(),
            LLMTextFrame("Hello "),
            SleepFrame(),
            StartInterruptionFrame(),
            LLMFullResponseStartFrame(),
            LLMTextFrame("Hello "),
            LLMTextFrame("there!"),
            LLMFullResponseEndFrame(),
        ]
        expected_down_frames = [
            LLMFullResponseStartFrame,
            LLMTextFrame,
            StartInterruptionFrame,
            LLMFullResponseStartFrame,
            LLMTextFrame,
            LLMTextFrame,
            LLMFullResponseEndFrame,
        ]
        completions = await self._run_and_collect(frames_to_send, expected_down_frames)
        # First the interrupted partial text, then the full second response.
        self.assertEqual(completions, [("Hello ", False), ("Hello there!", True)])