From d1ee851a65a7be3c4a5a6f22b8c0f83b70a1e8bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Tue, 11 Feb 2025 22:12:48 -0800 Subject: [PATCH 01/22] tests: rename some variables to make things clearer --- tests/test_aggregators.py | 8 ++++---- tests/test_filters.py | 32 ++++++++++++++++---------------- tests/test_pipeline.py | 20 ++++++++++---------- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/tests/test_aggregators.py b/tests/test_aggregators.py index ab7210e49..48650977e 100644 --- a/tests/test_aggregators.py +++ b/tests/test_aggregators.py @@ -29,12 +29,12 @@ class TestSentenceAggregator(unittest.IsolatedAsyncioTestCase): for word in sentence.split(" "): frames_to_send.append(TextFrame(text=word + " ")) - expected_returned_frames = [TextFrame, TextFrame, TextFrame] + expected_down_frames = [TextFrame, TextFrame, TextFrame] (received_down, _) = await run_test( aggregator, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) assert received_down[-3].text == "Hello, world. " assert received_down[-2].text == "How are you? 
" @@ -59,7 +59,7 @@ class TestGatedAggregator(unittest.IsolatedAsyncioTestCase): LLMFullResponseEndFrame(), ] - expected_returned_frames = [ + expected_down_frames = [ OutputImageRawFrame, LLMFullResponseStartFrame, TextFrame, @@ -72,5 +72,5 @@ class TestGatedAggregator(unittest.IsolatedAsyncioTestCase): (received_down, _) = await run_test( gated_aggregator, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) diff --git a/tests/test_filters.py b/tests/test_filters.py index 9d3b0f003..a47903232 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -25,11 +25,11 @@ class TestIdentifyFilter(unittest.IsolatedAsyncioTestCase): async def test_identity(self): filter = IdentityFilter() frames_to_send = [UserStartedSpeakingFrame(), UserStoppedSpeakingFrame()] - expected_returned_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame] + expected_down_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame] await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) @@ -37,32 +37,32 @@ class TestFrameFilter(unittest.IsolatedAsyncioTestCase): async def test_text_frame(self): filter = FrameFilter(types=(TextFrame,)) frames_to_send = [TextFrame(text="Hello Pipecat!")] - expected_returned_frames = [TextFrame] + expected_down_frames = [TextFrame] await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) async def test_end_frame(self): filter = FrameFilter(types=(EndFrame,)) frames_to_send = [EndFrame()] - expected_returned_frames = [EndFrame] + expected_down_frames = [EndFrame] await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, send_end_frame=False, ) async def test_system_frame(self): filter = 
FrameFilter(types=()) frames_to_send = [UserStartedSpeakingFrame()] - expected_returned_frames = [UserStartedSpeakingFrame] + expected_down_frames = [UserStartedSpeakingFrame] await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) @@ -73,11 +73,11 @@ class TestFunctionFilter(unittest.IsolatedAsyncioTestCase): filter = FunctionFilter(filter=passthrough) frames_to_send = [TextFrame(text="Hello Pipecat!")] - expected_returned_frames = [TextFrame] + expected_down_frames = [TextFrame] await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) async def test_no_passthrough(self): @@ -86,11 +86,11 @@ class TestFunctionFilter(unittest.IsolatedAsyncioTestCase): filter = FunctionFilter(filter=no_passthrough) frames_to_send = [TextFrame(text="Hello Pipecat!")] - expected_returned_frames = [] + expected_down_frames = [] await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) @@ -98,11 +98,11 @@ class TestWakeCheckFilter(unittest.IsolatedAsyncioTestCase): async def test_no_wake_word(self): filter = WakeCheckFilter(wake_phrases=["Hey, Pipecat"]) frames_to_send = [TranscriptionFrame(user_id="test", text="Phrase 1", timestamp="")] - expected_returned_frames = [] + expected_down_frames = [] await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) async def test_wake_word(self): @@ -111,10 +111,10 @@ class TestWakeCheckFilter(unittest.IsolatedAsyncioTestCase): TranscriptionFrame(user_id="test", text="Hey, Pipecat", timestamp=""), TranscriptionFrame(user_id="test", text="Phrase 1", timestamp=""), ] - expected_returned_frames = [TranscriptionFrame, TranscriptionFrame] + expected_down_frames = [TranscriptionFrame, 
TranscriptionFrame] (received_down, _) = await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) assert received_down[-1].text == "Phrase 1" diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index caac89c11..0aff922b2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -21,11 +21,11 @@ class TestPipeline(unittest.IsolatedAsyncioTestCase): pipeline = Pipeline([IdentityFilter()]) frames_to_send = [TextFrame(text="Hello from Pipecat!")] - expected_returned_frames = [TextFrame] + expected_down_frames = [TextFrame] await run_test( pipeline, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) async def test_pipeline_multiple(self): @@ -36,22 +36,22 @@ class TestPipeline(unittest.IsolatedAsyncioTestCase): pipeline = Pipeline([identity1, identity2, identity3]) frames_to_send = [TextFrame(text="Hello from Pipecat!")] - expected_returned_frames = [TextFrame] + expected_down_frames = [TextFrame] await run_test( pipeline, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) async def test_pipeline_start_metadata(self): pipeline = Pipeline([IdentityFilter()]) frames_to_send = [] - expected_returned_frames = [StartFrame] + expected_down_frames = [StartFrame] (received_down, _) = await run_test( pipeline, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ignore_start=False, start_metadata={"foo": "bar"}, ) @@ -63,11 +63,11 @@ class TestParallelPipeline(unittest.IsolatedAsyncioTestCase): pipeline = ParallelPipeline([IdentityFilter()]) frames_to_send = [TextFrame(text="Hello from Pipecat!")] - expected_returned_frames = [TextFrame] + expected_down_frames = [TextFrame] await run_test( pipeline, frames_to_send=frames_to_send, - 
expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) async def test_parallel_multiple(self): @@ -75,11 +75,11 @@ class TestParallelPipeline(unittest.IsolatedAsyncioTestCase): pipeline = ParallelPipeline([IdentityFilter()], [IdentityFilter()]) frames_to_send = [TextFrame(text="Hello from Pipecat!")] - expected_returned_frames = [TextFrame] + expected_down_frames = [TextFrame] await run_test( pipeline, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) From 1b7dfe81260aa220749a0cef94fe1f7fe3be71ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Tue, 11 Feb 2025 22:13:07 -0800 Subject: [PATCH 02/22] tests: add a new SleepFrame The new SleepFrame allows us to control when system frames are pushed to the pipeline. --- src/pipecat/tests/utils.py | 45 +++++++++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/src/pipecat/tests/utils.py b/src/pipecat/tests/utils.py index ff92164a4..2b78f2bef 100644 --- a/src/pipecat/tests/utils.py +++ b/src/pipecat/tests/utils.py @@ -5,6 +5,7 @@ # import asyncio +from dataclasses import dataclass from typing import Any, Awaitable, Callable, Dict, Sequence, Tuple from pipecat.frames.frames import ( @@ -12,6 +13,7 @@ from pipecat.frames.frames import ( Frame, HeartbeatFrame, StartFrame, + SystemFrame, ) from pipecat.observers.base_observer import BaseObserver from pipecat.pipeline.pipeline import Pipeline @@ -20,6 +22,15 @@ from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +@dataclass +class SleepFrame(SystemFrame): + """This frame is used by test framework to introduce some sleep time before + the next frame is pushed. 
This is useful to control system frames vs data or + control frames.""" + + sleep: float = 0.1 + + class HeartbeatsObserver(BaseObserver): def __init__( self, @@ -44,7 +55,11 @@ class HeartbeatsObserver(BaseObserver): class QueuedFrameProcessor(FrameProcessor): def __init__( - self, queue: asyncio.Queue, queue_direction: FrameDirection, ignore_start: bool = True + self, + *, + queue: asyncio.Queue, + queue_direction: FrameDirection, + ignore_start: bool = True, ): super().__init__() self._queue = queue @@ -72,21 +87,35 @@ async def run_test( ) -> Tuple[Sequence[Frame], Sequence[Frame]]: received_up = asyncio.Queue() received_down = asyncio.Queue() - source = QueuedFrameProcessor(received_up, FrameDirection.UPSTREAM, ignore_start) - sink = QueuedFrameProcessor(received_down, FrameDirection.DOWNSTREAM, ignore_start) + source = QueuedFrameProcessor( + queue=received_up, + queue_direction=FrameDirection.UPSTREAM, + ignore_start=ignore_start, + ) + sink = QueuedFrameProcessor( + queue=received_down, + queue_direction=FrameDirection.DOWNSTREAM, + ignore_start=ignore_start, + ) pipeline = Pipeline([source, processor, sink]) task = PipelineTask(pipeline, params=PipelineParams(start_metadata=start_metadata)) - for frame in frames_to_send: - await task.queue_frame(frame) + async def push_frames(): + # Just give a little head start to the runner. 
+ await asyncio.sleep(0.01) + for frame in frames_to_send: + if isinstance(frame, SleepFrame): + await asyncio.sleep(frame.sleep) + else: + await task.queue_frame(frame) - if send_end_frame: - await task.queue_frame(EndFrame()) + if send_end_frame: + await task.queue_frame(EndFrame()) runner = PipelineRunner() - await runner.run(task) + await asyncio.gather(runner.run(task), push_frames()) # # Down frames From 8bdd7ed0ed2f3ceae354c290dab2d19d6b41ba41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Tue, 11 Feb 2025 22:14:27 -0800 Subject: [PATCH 03/22] tests: implement langchain tests with run_test() --- tests/test_langchain.py | 48 ++++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 27 deletions(-) diff --git a/tests/test_langchain.py b/tests/test_langchain.py index c94bc1c01..6534f6bb0 100644 --- a/tests/test_langchain.py +++ b/tests/test_langchain.py @@ -10,23 +10,22 @@ from langchain.prompts import ChatPromptTemplate from langchain_core.language_models import FakeStreamingListLLM from pipecat.frames.frames import ( - EndFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, + LLMMessagesFrame, TextFrame, TranscriptionFrame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame, ) from pipecat.pipeline.pipeline import Pipeline -from pipecat.pipeline.runner import PipelineRunner -from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.aggregators.llm_response import ( LLMAssistantResponseAggregator, LLMUserResponseAggregator, ) from pipecat.processors.frame_processor import FrameProcessor from pipecat.processors.frameworks.langchain import LangchainProcessor +from pipecat.tests.utils import SleepFrame, run_test class TestLangchain(unittest.IsolatedAsyncioTestCase): @@ -64,31 +63,26 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase): self.mock_proc = self.MockProcessor("token_collector") tma_in = LLMUserResponseAggregator(messages) - tma_out = 
LLMAssistantResponseAggregator(messages) + tma_out = LLMAssistantResponseAggregator(messages, expect_stripped_words=False) - pipeline = Pipeline( - [ - tma_in, - proc, - self.mock_proc, - tma_out, - ] + pipeline = Pipeline([tma_in, proc, self.mock_proc, tma_out]) + + frames_to_send = [ + UserStartedSpeakingFrame(), + TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"), + SleepFrame(), + UserStoppedSpeakingFrame(), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + LLMMessagesFrame, + ] + await run_test( + pipeline, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, ) - task = PipelineTask(pipeline, PipelineParams(allow_interruptions=False)) - await task.queue_frames( - [ - UserStartedSpeakingFrame(), - TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"), - UserStoppedSpeakingFrame(), - EndFrame(), - ] - ) - - runner = PipelineRunner() - await runner.run(task) self.assertEqual("".join(self.mock_proc.token), self.expected_response) - # TODO: Address this issue - # This next one would fail with: - # AssertionError: ' H e l l o d e a r h u m a n' != 'Hello dear human' - # self.assertEqual(tma_out.messages[-1]["content"], self.expected_response) + self.assertEqual(tma_out.messages[-1]["content"], self.expected_response) From e1f2bbceb38b0cb23869684f50dcd0c1ef935ce7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Tue, 11 Feb 2025 22:16:10 -0800 Subject: [PATCH 04/22] reimplement LLM response aggregators --- CHANGELOG.md | 17 +- .../22d-natural-conversation-gemini-audio.py | 2 +- .../processors/aggregators/llm_response.py | 469 +++++++++--------- src/pipecat/services/anthropic.py | 8 +- .../services/gemini_multimodal_live/gemini.py | 4 +- src/pipecat/services/google/google.py | 8 +- src/pipecat/services/grok.py | 4 +- src/pipecat/services/openai.py | 8 +- .../services/openai_realtime_beta/context.py | 6 +- 9 files changed, 275 insertions(+), 
251 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f9cd7e0c..634a740e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,10 +38,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 OpenAI-compatible interface. Also, added foundational example `14n-function-calling-perplexity.py`. -- Added `DailyTransport.update_remote_participants()`. This allows you to - update remote participant's settings, like their permissions or which of - their devices are enabled. Requires that the local participant have - participant admin permission. +- Added `DailyTransport.update_remote_participants()`. This allows you to update + remote participant's settings, like their permissions or which of their + devices are enabled. Requires that the local participant have participant + admin permission. ### Changed @@ -91,6 +91,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed multiple issues where user transcriptions were not being handled + properly. It was possible for short utterances to not trigger VAD which would + cause user transcriptions to be ignored. It was also possible for one or more + transcriptions to be generated after VAD in which case they would also be + ignored. + - Fixed an issue that was causing `BotStoppedSpeakingFrame` to be generated too late. This could then cause issues unblocking `STTMuteFilter` later than desired. @@ -283,7 +289,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `enable_recording` and `geo` parameters to `DailyRoomProperties`. -- Added `RecordingsBucketConfig` to `DailyRoomProperties` to upload recordings to a custom AWS bucket. +- Added `RecordingsBucketConfig` to `DailyRoomProperties` to upload recordings + to a custom AWS bucket. 
### Changed diff --git a/examples/foundational/22d-natural-conversation-gemini-audio.py b/examples/foundational/22d-natural-conversation-gemini-audio.py index 860f1f574..15ee835e2 100644 --- a/examples/foundational/22d-natural-conversation-gemini-audio.py +++ b/examples/foundational/22d-natural-conversation-gemini-audio.py @@ -497,7 +497,7 @@ class UserAggregatorBuffer(LLMResponseAggregator): if isinstance(frame, UserStartedSpeakingFrame): self._transcription = "" - async def _push_aggregation(self): + async def push_aggregation(self): if self._aggregation: self._transcription = self._aggregation self._aggregation = "" diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 78db72dfb..0422524f8 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -4,9 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # -from typing import List, Optional, Type +import asyncio +from abc import abstractmethod +from typing import List from pipecat.frames.frames import ( + CancelFrame, + EndFrame, Frame, InterimTranscriptionFrame, LLMFullResponseEndFrame, @@ -15,6 +19,7 @@ from pipecat.frames.frames import ( LLMMessagesFrame, LLMMessagesUpdateFrame, LLMSetToolsFrame, + StartFrame, StartInterruptionFrame, TextFrame, TranscriptionFrame, @@ -28,121 +33,81 @@ from pipecat.processors.aggregators.openai_llm_context import ( from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -class LLMResponseAggregator(FrameProcessor): +class BaseLLMResponseAggregator(FrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + @property + @abstractmethod + def messages(self) -> List[dict]: + pass + + @property + @abstractmethod + def role(self) -> str: + pass + + @abstractmethod + def add_messages(self, messages): + pass + + @abstractmethod + def set_messages(self, messages): + pass + + @abstractmethod + def 
set_tools(self, tools): + pass + + @abstractmethod + def reset(self): + pass + + @abstractmethod + async def push_aggregation(self): + pass + + +class LLMResponseAggregator(BaseLLMResponseAggregator): def __init__( self, *, messages: List[dict], - role: str, - start_frame, - end_frame, - accumulator_frame: Type[TextFrame], - interim_accumulator_frame: Optional[Type[TextFrame]] = None, - handle_interruptions: bool = False, - expect_stripped_words: bool = True, # if True, need to add spaces between words + role: str = "user", + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self._messages = messages self._role = role - self._start_frame = start_frame - self._end_frame = end_frame - self._accumulator_frame = accumulator_frame - self._interim_accumulator_frame = interim_accumulator_frame - self._handle_interruptions = handle_interruptions - self._expect_stripped_words = expect_stripped_words - # Reset our accumulator state. - self._reset() + self._aggregation = "" + + self.reset() @property - def messages(self): + def messages(self) -> List[dict]: return self._messages @property - def role(self): + def role(self) -> str: return self._role - # - # Frame processor - # + def add_messages(self, messages): + self._messages.extend(messages) - # Use cases implemented: - # - # S: Start, E: End, T: Transcription, I: Interim, X: Text - # - # S E -> None - # S T E -> X - # S I T E -> X - # S I E T -> X - # S I E I T -> X - # S E T -> X - # S E I T -> X - # - # The following case would not be supported: - # - # S I E T1 I T2 -> X - # - # and T2 would be dropped. 
+ def set_messages(self, messages): + self.reset() + self._messages.clear() + self._messages.extend(messages) - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) + def set_tools(self, tools): + pass - send_aggregation = False + def reset(self): + self._aggregation = "" - if isinstance(frame, self._start_frame): - self._aggregation = "" - self._aggregating = True - self._seen_start_frame = True - self._seen_end_frame = False - self._seen_interim_results = False - await self.push_frame(frame, direction) - elif isinstance(frame, self._end_frame): - self._seen_end_frame = True - self._seen_start_frame = False - - # We might have received the end frame but we might still be - # aggregating (i.e. we have seen interim results but not the final - # text). - self._aggregating = self._seen_interim_results or len(self._aggregation) == 0 - - # Send the aggregation if we are not aggregating anymore (i.e. no - # more interim results received). - send_aggregation = not self._aggregating - await self.push_frame(frame, direction) - elif isinstance(frame, self._accumulator_frame): - if self._aggregating: - if self._expect_stripped_words: - self._aggregation += f" {frame.text}" if self._aggregation else frame.text - else: - self._aggregation += frame.text - # We have recevied a complete sentence, so if we have seen the - # end frame and we were still aggregating, it means we should - # send the aggregation. - send_aggregation = self._seen_end_frame - - # We just got our final result, so let's reset interim results. 
- self._seen_interim_results = False - elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): - self._seen_interim_results = True - elif self._handle_interruptions and isinstance(frame, StartInterruptionFrame): - await self._push_aggregation() - # Reset anyways - self._reset() - await self.push_frame(frame, direction) - elif isinstance(frame, LLMMessagesAppendFrame): - self._add_messages(frame.messages) - elif isinstance(frame, LLMMessagesUpdateFrame): - self._set_messages(frame.messages) - elif isinstance(frame, LLMSetToolsFrame): - self._set_tools(frame.tools) - else: - await self.push_frame(frame, direction) - - if send_aggregation: - await self._push_aggregation() - - async def _push_aggregation(self): + async def push_aggregation(self): if len(self._aggregation) > 0: self._messages.append({"role": self._role, "content": self._aggregation}) @@ -153,109 +118,22 @@ class LLMResponseAggregator(FrameProcessor): frame = LLMMessagesFrame(self._messages) await self.push_frame(frame) - # TODO-CB: Types - def _add_messages(self, messages): - self._messages.extend(messages) - def _set_messages(self, messages): - self._reset() - self._messages.clear() - self._messages.extend(messages) - - def _set_tools(self, tools): - # noop in the base class - pass - - def _reset(self): - self._aggregation = "" - self._aggregating = False - self._seen_start_frame = False - self._seen_end_frame = False - self._seen_interim_results = False - - -class LLMAssistantResponseAggregator(LLMResponseAggregator): - def __init__(self, messages: List[dict] = []): - super().__init__( - messages=messages, - role="assistant", - start_frame=LLMFullResponseStartFrame, - end_frame=LLMFullResponseEndFrame, - accumulator_frame=TextFrame, - handle_interruptions=True, - ) - - -class LLMUserResponseAggregator(LLMResponseAggregator): - def __init__(self, messages: List[dict] = []): - super().__init__( - messages=messages, - role="user", - 
start_frame=UserStartedSpeakingFrame, - end_frame=UserStoppedSpeakingFrame, - accumulator_frame=TranscriptionFrame, - interim_accumulator_frame=InterimTranscriptionFrame, - ) - - -class LLMFullResponseAggregator(FrameProcessor): - """This class aggregates Text frames until it receives a - LLMFullResponseEndFrame, then emits the concatenated text as - a single text frame. - - given the following frames: - - TextFrame("Hello,") - TextFrame(" world.") - TextFrame(" I am") - TextFrame(" an LLM.") - LLMFullResponseEndFrame()] - - this processor will yield nothing for the first 4 frames, then - - TextFrame("Hello, world. I am an LLM.") - LLMFullResponseEndFrame() - - when passed the last frame. - - >>> async def print_frames(aggregator, frame): - ... async for frame in aggregator.process_frame(frame): - ... if isinstance(frame, TextFrame): - ... print(frame.text) - ... else: - ... print(frame.__class__.__name__) - - >>> aggregator = LLMFullResponseAggregator() - >>> asyncio.run(print_frames(aggregator, TextFrame("Hello,"))) - >>> asyncio.run(print_frames(aggregator, TextFrame(" world."))) - >>> asyncio.run(print_frames(aggregator, TextFrame(" I am"))) - >>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM."))) - >>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame())) - Hello, world. I am an LLM. 
- LLMFullResponseEndFrame - """ - - def __init__(self): - super().__init__() - self._aggregation = "" - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - - if isinstance(frame, TextFrame): - self._aggregation += frame.text - elif isinstance(frame, LLMFullResponseEndFrame): - await self.push_frame(TextFrame(self._aggregation)) - await self.push_frame(frame) - self._aggregation = "" - else: - await self.push_frame(frame, direction) - - -class LLMContextAggregator(LLMResponseAggregator): - def __init__(self, *, context: OpenAILLMContext, **kwargs): +class LLMContextResponseAggregator(BaseLLMResponseAggregator): + def __init__(self, *, context: OpenAILLMContext, role: str, **kwargs): super().__init__(**kwargs) self._context = context + self._role = role + + self._aggregation = "" + + @property + def messages(self) -> List[dict]: + return self._context.get_messages() + + @property + def role(self) -> str: + return self._role @property def context(self): @@ -268,19 +146,18 @@ class LLMContextAggregator(LLMResponseAggregator): frame = self.get_context_frame() await self.push_frame(frame) - # TODO-CB: Types - def _add_messages(self, messages): + def add_messages(self, messages): self._context.add_messages(messages) - def _set_messages(self, messages): + def set_messages(self, messages): self._context.set_messages(messages) - def _set_tools(self, tools: List): + def set_tools(self, tools: List): self._context.set_tools(tools) - async def _push_aggregation(self): + async def push_aggregation(self): if len(self._aggregation) > 0: - self._context.add_message({"role": self._role, "content": self._aggregation}) + self._context.add_message({"role": self.role, "content": self._aggregation}) # Reset the aggregation. Reset it before pushing it down, otherwise # if the tasks gets cancelled we won't be able to clear things up. 
@@ -290,31 +167,171 @@ class LLMContextAggregator(LLMResponseAggregator): await self.push_frame(frame) # Reset our accumulator state. - self._reset() + self.reset() -class LLMAssistantContextAggregator(LLMContextAggregator): - def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True): - super().__init__( - messages=[], - context=context, - role="assistant", - start_frame=LLMFullResponseStartFrame, - end_frame=LLMFullResponseEndFrame, - accumulator_frame=TextFrame, - handle_interruptions=True, - expect_stripped_words=expect_stripped_words, - ) +class LLMUserContextAggregator(LLMContextResponseAggregator): + def __init__(self, context: OpenAILLMContext, aggregation_timeout: float = 1.0, **kwargs): + super().__init__(context=context, role="user", **kwargs) + self._aggregation_timeout = aggregation_timeout + + self._seen_interim_results = False + self._user_speaking = False + + self._aggregation_event = asyncio.Event() + self._aggregation_task = None + + self.reset() + + def reset(self): + super().reset() + self._seen_interim_results = False + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, StartFrame): + await self._start(frame) + await self.push_frame(frame, direction) + elif isinstance(frame, EndFrame): + await self._stop(frame) + await self.push_frame(frame, direction) + elif isinstance(frame, CancelFrame): + await self._cancel(frame) + await self.push_frame(frame, direction) + elif isinstance(frame, UserStartedSpeakingFrame): + await self._handle_user_started_speaking(frame) + await self.push_frame(frame, direction) + elif isinstance(frame, UserStoppedSpeakingFrame): + await self._handle_user_stopped_speaking(frame) + await self.push_frame(frame, direction) + elif isinstance(frame, TranscriptionFrame): + await self._handle_transcription(frame) + elif isinstance(frame, InterimTranscriptionFrame): + await 
self._handle_interim_transcription(frame) + elif isinstance(frame, LLMMessagesAppendFrame): + self.add_messages(frame.messages) + elif isinstance(frame, LLMMessagesUpdateFrame): + self.set_messages(frame.messages) + elif isinstance(frame, LLMSetToolsFrame): + self.set_tools(frame.tools) + else: + await self.push_frame(frame, direction) + + async def _start(self, frame: StartFrame): + self._aggregation_task = self.create_task(self._aggregation_task_handler()) + + async def _stop(self, frame: EndFrame): + if self._aggregation_task: + await self.cancel_task(self._aggregation_task) + self._aggregation_task = None + + async def _cancel(self, frame: CancelFrame): + if self._aggregation_task: + await self.cancel_task(self._aggregation_task) + self._aggregation_task = None + + async def _handle_user_started_speaking(self, _: UserStartedSpeakingFrame): + self._user_speaking = True + + async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame): + self._user_speaking = False + if not self._seen_interim_results: + await self.push_aggregation() + + async def _handle_transcription(self, frame: TranscriptionFrame): + self._aggregation += frame.text + # We just got our final result, so let's reset interim results. + self._seen_interim_results = False + # Wakeup our task. 
+ self._aggregation_event.set() + + async def _handle_interim_transcription(self, _: InterimTranscriptionFrame): + self._seen_interim_results = True + + async def _aggregation_task_handler(self): + while True: + await self._aggregation_event.wait() + await asyncio.sleep(self._aggregation_timeout) + if not self._user_speaking: + await self.push_aggregation() + self._aggregation_event.clear() -class LLMUserContextAggregator(LLMContextAggregator): - def __init__(self, context: OpenAILLMContext): - super().__init__( - messages=[], - context=context, - role="user", - start_frame=UserStartedSpeakingFrame, - end_frame=UserStoppedSpeakingFrame, - accumulator_frame=TranscriptionFrame, - interim_accumulator_frame=InterimTranscriptionFrame, - ) +class LLMAssistantContextAggregator(LLMContextResponseAggregator): + def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True, **kwargs): + super().__init__(context=context, role="assistant", **kwargs) + self._expect_stripped_words = expect_stripped_words + + self.reset() + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, StartInterruptionFrame): + await self.push_aggregation() + # Reset anyways + self.reset() + await self.push_frame(frame, direction) + elif isinstance(frame, LLMFullResponseStartFrame): + await self._handle_llm_start(frame) + elif isinstance(frame, LLMFullResponseEndFrame): + await self._handle_llm_end(frame) + elif isinstance(frame, TextFrame): + await self._handle_text(frame) + else: + await self.push_frame(frame, direction) + + async def _handle_llm_start(self, _: LLMFullResponseStartFrame): + self._started = True + + async def _handle_llm_end(self, _: LLMFullResponseEndFrame): + self._started = False + await self.push_aggregation() + + async def _handle_text(self, frame: TextFrame): + if not self._started: + return + + if self._expect_stripped_words: + self._aggregation += f" {frame.text}" 
if self._aggregation else frame.text + else: + self._aggregation += frame.text + + +class LLMUserResponseAggregator(LLMUserContextAggregator): + def __init__(self, messages: List[dict] = [], **kwargs): + super().__init__(context=OpenAILLMContext(messages), **kwargs) + + async def push_aggregation(self): + if len(self._aggregation) > 0: + self._context.add_message({"role": self.role, "content": self._aggregation}) + + # Reset the aggregation. Reset it before pushing it down, otherwise + # if the tasks gets cancelled we won't be able to clear things up. + self._aggregation = "" + + frame = LLMMessagesFrame(self._context.messages) + await self.push_frame(frame) + + # Reset our accumulator state. + self.reset() + + +class LLMAssistantResponseAggregator(LLMAssistantContextAggregator): + def __init__(self, messages: List[dict], **kwargs): + super().__init__(context=OpenAILLMContext(messages), **kwargs) + + async def push_aggregation(self): + if len(self._aggregation) > 0: + self._context.add_message({"role": self.role, "content": self._aggregation}) + + # Reset the aggregation. Reset it before pushing it down, otherwise + # if the tasks gets cancelled we won't be able to clear things up. + self._aggregation = "" + + frame = LLMMessagesFrame(self._context.messages) + await self.push_frame(frame) + + # Reset our accumulator state. 
+ self.reset() diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index a593ced89..5a4799960 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -725,7 +725,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): ): self._function_call_in_progress = None self._function_call_result = frame - await self._push_aggregation() + await self.push_aggregation() else: logger.warning( "FunctionCallResultFrame tool_call_id != InProgressFrame tool_call_id" @@ -734,9 +734,9 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): self._function_call_result = None elif isinstance(frame, AnthropicImageMessageFrame): self._pending_image_frame_message = frame - await self._push_aggregation() + await self.push_aggregation() - async def _push_aggregation(self): + async def push_aggregation(self): if not ( self._aggregation or self._function_call_result or self._pending_image_frame_message ): @@ -746,7 +746,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): properties: Optional[FunctionCallResultProperties] = None aggregation = self._aggregation - self._reset() + self.reset() try: if self._function_call_result: diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py index 4f26b6e9f..8479a4e0a 100644 --- a/src/pipecat/services/gemini_multimodal_live/gemini.py +++ b/src/pipecat/services/gemini_multimodal_live/gemini.py @@ -115,10 +115,10 @@ class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator): class GeminiMultimodalLiveAssistantContextAggregator(OpenAIAssistantContextAggregator): - async def _push_aggregation(self): + async def push_aggregation(self): # We don't want to store any images in the context. Revisit this later when the API evolves. 
self._pending_image_frame_message = None - await super()._push_aggregation() + await super().push_aggregation() @dataclass diff --git a/src/pipecat/services/google/google.py b/src/pipecat/services/google/google.py index 28cd0d421..fbfb9a0dd 100644 --- a/src/pipecat/services/google/google.py +++ b/src/pipecat/services/google/google.py @@ -537,7 +537,7 @@ def language_to_google_stt_language(language: Language) -> Optional[str]: class GoogleUserContextAggregator(OpenAIUserContextAggregator): - async def _push_aggregation(self): + async def push_aggregation(self): if len(self._aggregation) > 0: self._context.add_message( glm.Content(role="user", parts=[glm.Part(text=self._aggregation)]) @@ -552,11 +552,11 @@ class GoogleUserContextAggregator(OpenAIUserContextAggregator): await self.push_frame(frame) # Reset our accumulator state. - self._reset() + self.reset() class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): - async def _push_aggregation(self): + async def push_aggregation(self): if not ( self._aggregation or self._function_call_result or self._pending_image_frame_message ): @@ -566,7 +566,7 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): properties: Optional[FunctionCallResultProperties] = None aggregation = self._aggregation - self._reset() + self.reset() try: if self._function_call_result: diff --git a/src/pipecat/services/grok.py b/src/pipecat/services/grok.py index 7221cc09e..f9abdedec 100644 --- a/src/pipecat/services/grok.py +++ b/src/pipecat/services/grok.py @@ -27,7 +27,7 @@ from pipecat.services.openai import ( class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator): """Custom assistant context aggregator for Grok that handles empty content requirement.""" - async def _push_aggregation(self): + async def push_aggregation(self): if not ( self._aggregation or self._function_call_result or self._pending_image_frame_message ): @@ -37,7 +37,7 @@ class 
GrokAssistantContextAggregator(OpenAIAssistantContextAggregator): properties: Optional[FunctionCallResultProperties] = None aggregation = self._aggregation - self._reset() + self.reset() try: if self._function_call_result: diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index bc251025e..0cd5fc255 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -614,7 +614,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): del self._function_calls_in_progress[frame.tool_call_id] self._function_call_result = frame # TODO-CB: Kwin wants us to refactor this out of here but I REFUSE - await self._push_aggregation() + await self.push_aggregation() else: logger.warning( "FunctionCallResultFrame tool_call_id does not match any function call in progress" @@ -622,9 +622,9 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): self._function_call_result = None elif isinstance(frame, OpenAIImageMessageFrame): self._pending_image_frame_message = frame - await self._push_aggregation() + await self.push_aggregation() - async def _push_aggregation(self): + async def push_aggregation(self): if not ( self._aggregation or self._function_call_result or self._pending_image_frame_message ): @@ -634,7 +634,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): properties: Optional[FunctionCallResultProperties] = None aggregation = self._aggregation - self._reset() + self.reset() try: if self._function_call_result: diff --git a/src/pipecat/services/openai_realtime_beta/context.py b/src/pipecat/services/openai_realtime_beta/context.py index da287194a..317817766 100644 --- a/src/pipecat/services/openai_realtime_beta/context.py +++ b/src/pipecat/services/openai_realtime_beta/context.py @@ -166,7 +166,7 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator): if isinstance(frame, LLMSetToolsFrame): await self.push_frame(frame, direction) - async def 
_push_aggregation(self): + async def push_aggregation(self): # for the moment, ignore all user input coming into the pipeline. # todo: think about whether/how to fix this to allow for text input from # upstream (transport/transcription, or other sources) @@ -174,7 +174,7 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator): class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator): - async def _push_aggregation(self): + async def push_aggregation(self): # the only thing we implement here is function calling. in all other cases, messages # are added to the context when we receive openai realtime api events if not self._function_call_result: @@ -182,7 +182,7 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator) properties: Optional[FunctionCallResultProperties] = None - self._reset() + self.reset() try: run_llm = True frame = self._function_call_result From 50288eeaaae09647048f94af82b4a7fef904ff5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Tue, 11 Feb 2025 22:16:20 -0800 Subject: [PATCH 05/22] tests: add LLM response aggregators tests --- tests/test_llm_response.py | 342 +++++++++++++++++++++++++++++++++++++ 1 file changed, 342 insertions(+) create mode 100644 tests/test_llm_response.py diff --git a/tests/test_llm_response.py b/tests/test_llm_response.py new file mode 100644 index 000000000..cb4620633 --- /dev/null +++ b/tests/test_llm_response.py @@ -0,0 +1,342 @@ +# +# Copyright (c) 2024-2025 Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import unittest + +from pipecat.frames.frames import ( + InterimTranscriptionFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + TextFrame, + TranscriptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, +) +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantContextAggregator, + LLMUserContextAggregator, +) +from 
pipecat.processors.aggregators.openai_llm_context import ( + OpenAILLMContext, + OpenAILLMContextFrame, +) +from pipecat.tests.utils import SleepFrame, run_test + +AGGREGATION_TIMEOUT = 0.1 +AGGREGATION_SLEEP = 0.15 + + +class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): + async def test_se(self): + context = OpenAILLMContext() + aggregator = LLMUserContextAggregator(context) + frames_to_send = [UserStartedSpeakingFrame(), UserStoppedSpeakingFrame()] + expected_down_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + async def test_ste(self): + context = OpenAILLMContext() + aggregator = LLMUserContextAggregator(context) + frames_to_send = [ + UserStartedSpeakingFrame(), + TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "Hello!" + + async def test_site(self): + context = OpenAILLMContext() + aggregator = LLMUserContextAggregator(context) + frames_to_send = [ + UserStartedSpeakingFrame(), + InterimTranscriptionFrame(text="Hello", user_id="cat", timestamp=""), + TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat!" 
+ + async def test_st1iest2e(self): + context = OpenAILLMContext() + aggregator = LLMUserContextAggregator(context) + frames_to_send = [ + UserStartedSpeakingFrame(), + TranscriptionFrame(text="Hello Pipecat! ", user_id="cat", timestamp=""), + InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + UserStartedSpeakingFrame(), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat! How are you?" + + async def test_siet(self): + context = OpenAILLMContext() + aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + UserStartedSpeakingFrame(), + InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "How are you?" 
+ + async def test_sieit(self): + context = OpenAILLMContext() + aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + UserStartedSpeakingFrame(), + InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + InterimTranscriptionFrame(text="are you?", user_id="cat", timestamp=""), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "How are you?" + + async def test_set(self): + context = OpenAILLMContext() + aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + UserStartedSpeakingFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "How are you?" 
+ + async def test_seit(self): + context = OpenAILLMContext() + aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + UserStartedSpeakingFrame(), + UserStoppedSpeakingFrame(), + InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "How are you?" + + async def test_st1et2(self): + context = OpenAILLMContext() + aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + UserStartedSpeakingFrame(), + TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + OpenAILLMContextFrame, + ] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat!" + assert received_down[-1].context.messages[1]["content"] == "How are you?" + + async def test_set1t2(self): + context = OpenAILLMContext() + aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + UserStartedSpeakingFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame(text="Hello Pipecat! 
", user_id="cat", timestamp=""), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat! How are you?" + + async def test_siet1it2(self): + context = OpenAILLMContext() + aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + UserStartedSpeakingFrame(), + InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame(text="Hello Pipecat! ", user_id="cat", timestamp=""), + InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat! How are you?" 
+ + +class TestLLMAssistantContextAggreagator(unittest.IsolatedAsyncioTestCase): + async def test_empty(self): + context = OpenAILLMContext() + aggregator = LLMAssistantContextAggregator(context) + frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()] + expected_down_frames = [] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + async def test_single(self): + context = OpenAILLMContext() + aggregator = LLMAssistantContextAggregator(context) + frames_to_send = [ + LLMFullResponseStartFrame(), + TextFrame(text="Hello Pipecat!"), + LLMFullResponseEndFrame(), + ] + expected_down_frames = [OpenAILLMContextFrame] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat!" + + async def test_multiple(self): + context = OpenAILLMContext() + aggregator = LLMAssistantContextAggregator(context, expect_stripped_words=False) + frames_to_send = [ + LLMFullResponseStartFrame(), + TextFrame(text="Hello "), + TextFrame(text="Pipecat. "), + TextFrame(text="How are "), + TextFrame(text="you?"), + LLMFullResponseEndFrame(), + ] + expected_down_frames = [OpenAILLMContextFrame] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat. How are you?" 
+ + async def test_multiple_stripped(self): + context = OpenAILLMContext() + aggregator = LLMAssistantContextAggregator(context) + frames_to_send = [ + LLMFullResponseStartFrame(), + TextFrame(text="Hello"), + TextFrame(text="Pipecat."), + TextFrame(text="How are"), + TextFrame(text="you?"), + LLMFullResponseEndFrame(), + ] + expected_down_frames = [OpenAILLMContextFrame] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat. How are you?" From 91a628d1ba8db75cf390c452805b1359da78ba18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 12 Feb 2025 13:53:47 -0800 Subject: [PATCH 06/22] UserResponseAggregator: implement on top of LLMUserResponseAggregator --- .../processors/aggregators/user_response.py | 147 +----------------- 1 file changed, 7 insertions(+), 140 deletions(-) diff --git a/src/pipecat/processors/aggregators/user_response.py b/src/pipecat/processors/aggregators/user_response.py index 0e2c09ebc..6998fe200 100644 --- a/src/pipecat/processors/aggregators/user_response.py +++ b/src/pipecat/processors/aggregators/user_response.py @@ -4,131 +4,15 @@ # SPDX-License-Identifier: BSD 2-Clause License # -from typing import Optional - -from pipecat.frames.frames import ( - Frame, - InterimTranscriptionFrame, - StartInterruptionFrame, - TextFrame, - TranscriptionFrame, - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, -) -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.frames.frames import TextFrame +from pipecat.processors.aggregators.llm_response import LLMUserResponseAggregator -class ResponseAggregator(FrameProcessor): - """This frame processor aggregates frames between a start and an end frame - into complete text frame sentences. 
+class UserResponseAggregator(LLMUserResponseAggregator): + def __init__(self, **kwargs): + super().__init__(**kwargs) - For example, frame input/output: - UserStartedSpeakingFrame() -> None - TranscriptionFrame("Hello,") -> None - TranscriptionFrame(" world.") -> None - UserStoppedSpeakingFrame() -> TextFrame("Hello world.") - - Doctest: FIXME to work with asyncio - >>> async def print_frames(aggregator, frame): - ... async for frame in aggregator.process_frame(frame): - ... if isinstance(frame, TextFrame): - ... print(frame.text) - - >>> aggregator = ResponseAggregator(start_frame = UserStartedSpeakingFrame, - ... end_frame=UserStoppedSpeakingFrame, - ... accumulator_frame=TranscriptionFrame, - ... pass_through=False) - >>> asyncio.run(print_frames(aggregator, UserStartedSpeakingFrame())) - >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("Hello,", 1, 1))) - >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("world.", 1, 2))) - >>> asyncio.run(print_frames(aggregator, UserStoppedSpeakingFrame())) - Hello, world. - - """ - - def __init__( - self, - *, - start_frame, - end_frame, - accumulator_frame: TextFrame, - interim_accumulator_frame: Optional[TextFrame] = None, - ): - super().__init__() - - self._start_frame = start_frame - self._end_frame = end_frame - self._accumulator_frame = accumulator_frame - self._interim_accumulator_frame = interim_accumulator_frame - - # Reset our accumulator state. - self._reset() - - # - # Frame processor - # - - # Use cases implemented: - # - # S: Start, E: End, T: Transcription, I: Interim, X: Text - # - # S E -> None - # S T E -> X - # S I T E -> X - # S I E T -> X - # S I E I T -> X - # S E T -> X - # S E I T -> X - # - # The following case would not be supported: - # - # S I E T1 I T2 -> X - # - # and T2 would be dropped. 
- - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - - send_aggregation = False - - if isinstance(frame, self._start_frame): - self._aggregating = True - self._seen_start_frame = True - self._seen_end_frame = False - self._seen_interim_results = False - await self.push_frame(frame, direction) - elif isinstance(frame, self._end_frame): - self._seen_end_frame = True - self._seen_start_frame = False - - # We might have received the end frame but we might still be - # aggregating (i.e. we have seen interim results but not the final - # text). - self._aggregating = self._seen_interim_results or len(self._aggregation) == 0 - - # Send the aggregation if we are not aggregating anymore (i.e. no - # more interim results received). - send_aggregation = not self._aggregating - await self.push_frame(frame, direction) - elif isinstance(frame, self._accumulator_frame): - if self._aggregating: - self._aggregation += f" {frame.text}" - # We have recevied a complete sentence, so if we have seen the - # end frame and we were still aggregating, it means we should - # send the aggregation. - send_aggregation = self._seen_end_frame - - # We just got our final result, so let's reset interim results. - self._seen_interim_results = False - elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): - self._seen_interim_results = True - else: - await self.push_frame(frame, direction) - - if send_aggregation: - await self._push_aggregation() - - async def _push_aggregation(self): + async def push_aggregation(self): if len(self._aggregation) > 0: frame = TextFrame(self._aggregation.strip()) @@ -139,21 +23,4 @@ class ResponseAggregator(FrameProcessor): await self.push_frame(frame) # Reset our accumulator state. 
- self._reset() - - def _reset(self): - self._aggregation = "" - self._aggregating = False - self._seen_start_frame = False - self._seen_end_frame = False - self._seen_interim_results = False - - -class UserResponseAggregator(ResponseAggregator): - def __init__(self): - super().__init__( - start_frame=UserStartedSpeakingFrame, - end_frame=UserStoppedSpeakingFrame, - accumulator_frame=TranscriptionFrame, - interim_accumulator_frame=InterimTranscriptionFrame, - ) + self.reset() From 4cbcfe2b0b049dbf133feedfeb0af1c414e87509 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 12 Feb 2025 18:53:51 -0800 Subject: [PATCH 07/22] LLMUserContextAggregator: interrupt the bot if VAD happened a while back --- .../processors/aggregators/llm_response.py | 66 ++++++++++++++---- tests/test_llm_response.py | 69 +++++++++++++++++++ 2 files changed, 120 insertions(+), 15 deletions(-) diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 0422524f8..0ed6f2f0e 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -5,10 +5,12 @@ # import asyncio +import time from abc import abstractmethod from typing import List from pipecat.frames.frames import ( + BotInterruptionFrame, CancelFrame, EndFrame, Frame, @@ -171,12 +173,20 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator): class LLMUserContextAggregator(LLMContextResponseAggregator): - def __init__(self, context: OpenAILLMContext, aggregation_timeout: float = 1.0, **kwargs): + def __init__( + self, + context: OpenAILLMContext, + aggregation_timeout: float = 1.0, + bot_interruption_timeout: float = 2.0, + **kwargs, + ): super().__init__(context=context, role="user", **kwargs) self._aggregation_timeout = aggregation_timeout + self._bot_interruption_timeout = bot_interruption_timeout self._seen_interim_results = False self._user_speaking = False + 
self._last_user_speaking_time = 0 self._aggregation_event = asyncio.Event() self._aggregation_task = None @@ -219,43 +229,63 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): await self.push_frame(frame, direction) async def _start(self, frame: StartFrame): - self._aggregation_task = self.create_task(self._aggregation_task_handler()) + self._create_aggregation_task() async def _stop(self, frame: EndFrame): - if self._aggregation_task: - await self.cancel_task(self._aggregation_task) - self._aggregation_task = None + await self._cancel_aggregation_task() async def _cancel(self, frame: CancelFrame): - if self._aggregation_task: - await self.cancel_task(self._aggregation_task) - self._aggregation_task = None + await self._cancel_aggregation_task() async def _handle_user_started_speaking(self, _: UserStartedSpeakingFrame): self._user_speaking = True async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame): + self._last_user_speaking_time = time.time() self._user_speaking = False if not self._seen_interim_results: await self.push_aggregation() async def _handle_transcription(self, frame: TranscriptionFrame): self._aggregation += frame.text - # We just got our final result, so let's reset interim results. + # We just got a final result, so let's reset interim results. self._seen_interim_results = False - # Wakeup our task. + # Reset aggregation timer. self._aggregation_event.set() async def _handle_interim_transcription(self, _: InterimTranscriptionFrame): self._seen_interim_results = True + # Reset aggregation timer. 
+ self._aggregation_event.set() + + def _create_aggregation_task(self): + self._aggregation_task = self.create_task(self._aggregation_task_handler()) + + async def _cancel_aggregation_task(self): + if self._aggregation_task: + await self.cancel_task(self._aggregation_task) + self._aggregation_task = None async def _aggregation_task_handler(self): while True: - await self._aggregation_event.wait() - await asyncio.sleep(self._aggregation_timeout) - if not self._user_speaking: - await self.push_aggregation() - self._aggregation_event.clear() + try: + await asyncio.wait_for(self._aggregation_event.wait(), self._aggregation_timeout) + await self._maybe_push_bot_interruption() + except asyncio.TimeoutError: + if not self._user_speaking: + await self.push_aggregation() + finally: + self._aggregation_event.clear() + + async def _maybe_push_bot_interruption(self): + """If the user stopped speaking a while back and we got a transcription + frame we might want to interrupt the bot. + + """ + if not self._user_speaking: + diff_time = time.time() - self._last_user_speaking_time + if diff_time > self._bot_interruption_timeout: + await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM) class LLMAssistantContextAggregator(LLMContextResponseAggregator): @@ -279,6 +309,12 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator): await self._handle_llm_end(frame) elif isinstance(frame, TextFrame): await self._handle_text(frame) + elif isinstance(frame, LLMMessagesAppendFrame): + self.add_messages(frame.messages) + elif isinstance(frame, LLMMessagesUpdateFrame): + self.set_messages(frame.messages) + elif isinstance(frame, LLMSetToolsFrame): + self.set_tools(frame.tools) else: await self.push_frame(frame, direction) diff --git a/tests/test_llm_response.py b/tests/test_llm_response.py index cb4620633..e8ec92b0d 100644 --- a/tests/test_llm_response.py +++ b/tests/test_llm_response.py @@ -7,6 +7,7 @@ import unittest from pipecat.frames.frames import ( + 
BotInterruptionFrame, InterimTranscriptionFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, @@ -27,6 +28,8 @@ from pipecat.tests.utils import SleepFrame, run_test AGGREGATION_TIMEOUT = 0.1 AGGREGATION_SLEEP = 0.15 +BOT_INTERRUPTION_TIMEOUT = 0.2 +BOT_INTERRUPTION_SLEEP = 0.25 class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): @@ -274,6 +277,72 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): ) assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat! How are you?" + async def test_t(self): + context = OpenAILLMContext() + aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [OpenAILLMContextFrame] + expected_up_frames = [BotInterruptionFrame] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + expected_up_frames=expected_up_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "Hello!" + + async def test_it(self): + context = OpenAILLMContext() + aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""), + TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [OpenAILLMContextFrame] + expected_up_frames = [BotInterruptionFrame] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + expected_up_frames=expected_up_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat!" 
+ + async def test_sie_delay_it(self): + context = OpenAILLMContext() + aggregator = LLMUserContextAggregator( + context, + aggregation_timeout=AGGREGATION_TIMEOUT, + bot_interruption_timeout=BOT_INTERRUPTION_TIMEOUT, + ) + frames_to_send = [ + UserStartedSpeakingFrame(), + InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + SleepFrame(BOT_INTERRUPTION_SLEEP), + InterimTranscriptionFrame(text="are you?", user_id="cat", timestamp=""), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + expected_up_frames = [BotInterruptionFrame] + (received_down, _) = await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + expected_up_frames=expected_up_frames, + ) + assert received_down[-1].context.messages[0]["content"] == "How are you?" + class TestLLMAssistantContextAggreagator(unittest.IsolatedAsyncioTestCase): async def test_empty(self): From 839aa7d93572ab420e9a1aa8f20295197f514fbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 12 Feb 2025 19:14:24 -0800 Subject: [PATCH 08/22] llm_response: add some initial docstrings to LLM aggregators --- .../processors/aggregators/llm_response.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 0ed6f2f0e..2b7a12ff8 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -36,33 +36,51 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor class BaseLLMResponseAggregator(FrameProcessor): + """This is the base class for all LLM response aggregators. 
These + aggregators process incoming frames and aggregate content until they are + ready to push the aggregation. In the case of a user, an aggregation might + be a full transcription received from the STT service. + + The LLM response aggregators also keep a store (e.g. a message list or an + LLM context) of the current conversation, that is, it stores the messages + said by the user or by the bot. + + """ + def __init__(self, **kwargs): super().__init__(**kwargs) @property @abstractmethod def messages(self) -> List[dict]: + """Returns the messages from the current conversation.""" pass @property @abstractmethod def role(self) -> str: + """Returns the role (e.g. user, assistant...) for this aggregator.""" pass @abstractmethod def add_messages(self, messages): + """Add the given messages to the conversation.""" pass @abstractmethod def set_messages(self, messages): + """Reset the conversation with the given messages.""" pass @abstractmethod def set_tools(self, tools): + """Set LLM tools to be used in the current conversation.""" pass @abstractmethod def reset(self): + """Reset the internals of this aggregator. This should not modify the + internal messages.""" pass @abstractmethod @@ -71,6 +89,12 @@ class BaseLLMResponseAggregator(FrameProcessor): class LLMResponseAggregator(BaseLLMResponseAggregator): + """This is a base LLM aggregator that uses a simple list of messages to + store the conversation. It pushes `LLMMessagesFrame` as an aggregation + frame. + + """ + def __init__( self, *, @@ -122,6 +146,11 @@ class LLMResponseAggregator(BaseLLMResponseAggregator): class LLMContextResponseAggregator(BaseLLMResponseAggregator): + """This is a base LLM aggregator that uses an LLM context to store the + conversation. It pushes `OpenAILLMContextFrame` as an aggregation frame. 
+ + """ + def __init__(self, *, context: OpenAILLMContext, role: str, **kwargs): super().__init__(**kwargs) self._context = context @@ -173,6 +202,14 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator): class LLMUserContextAggregator(LLMContextResponseAggregator): + """This is a user LLM aggregator that uses an LLM context to store the + conversation. It aggregates transcriptions from the STT service and it has + logic to handle multiple scenarios where transcriptions are received between + VAD events (`UserStartedSpeakingFrame` and `UserStoppedSpeakingFrame`) or + even outside or no VAD events at all. + + """ + def __init__( self, context: OpenAILLMContext, @@ -289,6 +326,12 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): class LLMAssistantContextAggregator(LLMContextResponseAggregator): + """This is an assistant LLM aggregator that uses an LLM context to store the + conversation. It aggregates text frames received between + `LLMFullResponseStartFrame` and `LLMFullResponseEndFrame`. 
+ + """ + def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True, **kwargs): super().__init__(context=context, role="assistant", **kwargs) self._expect_stripped_words = expect_stripped_words From 16a107948b4d30a285d61ba73e555fe995b4302a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 12 Feb 2025 23:28:36 -0800 Subject: [PATCH 09/22] services: missing kwargs in anthropic/openai user context aggregator --- src/pipecat/services/anthropic.py | 4 ++-- src/pipecat/services/openai.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index 5a4799960..80b0b482e 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -651,8 +651,8 @@ class AnthropicLLMContext(OpenAILLMContext): class AnthropicUserContextAggregator(LLMUserContextAggregator): - def __init__(self, context: OpenAILLMContext | AnthropicLLMContext): - super().__init__(context=context) + def __init__(self, context: OpenAILLMContext | AnthropicLLMContext, **kwargs): + super().__init__(context=context, **kwargs) if isinstance(context, OpenAILLMContext): self._context = AnthropicLLMContext.from_openai_context(context) diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index 0cd5fc255..a5f52b69f 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -555,8 +555,8 @@ class OpenAIImageMessageFrame(Frame): class OpenAIUserContextAggregator(LLMUserContextAggregator): - def __init__(self, context: OpenAILLMContext): - super().__init__(context=context) + def __init__(self, context: OpenAILLMContext, **kwargs): + super().__init__(context=context, **kwargs) async def process_frame(self, frame, direction): await super().process_frame(frame, direction) From 7c815121eab2adee596be25efa412ea1ccceaf60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 12 
Feb 2025 23:29:19 -0800 Subject: [PATCH 10/22] LLMContextResponseAggregator: add missing reset() implementation --- src/pipecat/processors/aggregators/llm_response.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 2b7a12ff8..850c0ca27 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -186,6 +186,9 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator): def set_tools(self, tools: List): self._context.set_tools(tools) + def reset(self): + self._aggregation = "" + async def push_aggregation(self): if len(self._aggregation) > 0: self._context.add_message({"role": self.role, "content": self._aggregation}) From b602e78625aa8b1d856d2b8ad8f842ecb0a48eaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 12 Feb 2025 23:29:45 -0800 Subject: [PATCH 11/22] tests: add OpenAI context aggregator tests --- tests/test_llm_response.py | 302 ++++++++++++++++++++++++++++--------- 1 file changed, 228 insertions(+), 74 deletions(-) diff --git a/tests/test_llm_response.py b/tests/test_llm_response.py index e8ec92b0d..1d71217a2 100644 --- a/tests/test_llm_response.py +++ b/tests/test_llm_response.py @@ -11,6 +11,7 @@ from pipecat.frames.frames import ( InterimTranscriptionFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, + StartInterruptionFrame, TextFrame, TranscriptionFrame, UserStartedSpeakingFrame, @@ -24,6 +25,7 @@ from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, ) +from pipecat.services.openai import OpenAIUserContextAggregator from pipecat.tests.utils import SleepFrame, run_test AGGREGATION_TIMEOUT = 0.1 @@ -32,10 +34,16 @@ BOT_INTERRUPTION_TIMEOUT = 0.2 BOT_INTERRUPTION_SLEEP = 0.25 -class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): +class 
BaseTestUserContextAggregator: + CONTEXT_CLASS = None # To be set in subclasses + AGGREGATOR_CLASS = None # To be set in subclasses + async def test_se(self): - context = OpenAILLMContext() - aggregator = LLMUserContextAggregator(context) + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [UserStartedSpeakingFrame(), UserStoppedSpeakingFrame()] expected_down_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame] await run_test( @@ -45,8 +53,11 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): ) async def test_ste(self): - context = OpenAILLMContext() - aggregator = LLMUserContextAggregator(context) + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [ UserStartedSpeakingFrame(), TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""), @@ -58,16 +69,19 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): UserStoppedSpeakingFrame, OpenAILLMContextFrame, ] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert received_down[-1].context.messages[0]["content"] == "Hello!" + assert context.messages[0]["content"] == "Hello!" 
async def test_site(self): - context = OpenAILLMContext() - aggregator = LLMUserContextAggregator(context) + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [ UserStartedSpeakingFrame(), InterimTranscriptionFrame(text="Hello", user_id="cat", timestamp=""), @@ -80,16 +94,19 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): UserStoppedSpeakingFrame, OpenAILLMContextFrame, ] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat!" + assert context.messages[0]["content"] == "Hello Pipecat!" async def test_st1iest2e(self): - context = OpenAILLMContext() - aggregator = LLMUserContextAggregator(context) + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [ UserStartedSpeakingFrame(), TranscriptionFrame(text="Hello Pipecat! ", user_id="cat", timestamp=""), @@ -108,16 +125,19 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): UserStoppedSpeakingFrame, OpenAILLMContextFrame, ] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat! How are you?" + assert context.messages[0]["content"] == "Hello Pipecat! How are you?" 
async def test_siet(self): - context = OpenAILLMContext() - aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) frames_to_send = [ UserStartedSpeakingFrame(), InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), @@ -131,16 +151,19 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): UserStoppedSpeakingFrame, OpenAILLMContextFrame, ] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert received_down[-1].context.messages[0]["content"] == "How are you?" + assert context.messages[0]["content"] == "How are you?" async def test_sieit(self): - context = OpenAILLMContext() - aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) frames_to_send = [ UserStartedSpeakingFrame(), InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), @@ -155,16 +178,19 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): UserStoppedSpeakingFrame, OpenAILLMContextFrame, ] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert received_down[-1].context.messages[0]["content"] == "How are you?" + assert context.messages[0]["content"] == "How are you?" 
async def test_set(self): - context = OpenAILLMContext() - aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) frames_to_send = [ UserStartedSpeakingFrame(), UserStoppedSpeakingFrame(), @@ -176,16 +202,19 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): UserStoppedSpeakingFrame, OpenAILLMContextFrame, ] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert received_down[-1].context.messages[0]["content"] == "How are you?" + assert context.messages[0]["content"] == "How are you?" async def test_seit(self): - context = OpenAILLMContext() - aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) frames_to_send = [ UserStartedSpeakingFrame(), UserStoppedSpeakingFrame(), @@ -198,16 +227,19 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): UserStoppedSpeakingFrame, OpenAILLMContextFrame, ] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert received_down[-1].context.messages[0]["content"] == "How are you?" + assert context.messages[0]["content"] == "How are you?" 
async def test_st1et2(self): - context = OpenAILLMContext() - aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) frames_to_send = [ UserStartedSpeakingFrame(), TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), @@ -222,17 +254,20 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): OpenAILLMContextFrame, OpenAILLMContextFrame, ] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat!" - assert received_down[-1].context.messages[1]["content"] == "How are you?" + assert context.messages[0]["content"] == "Hello Pipecat!" + assert context.messages[1]["content"] == "How are you?" 
async def test_set1t2(self): - context = OpenAILLMContext() - aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) frames_to_send = [ UserStartedSpeakingFrame(), UserStoppedSpeakingFrame(), @@ -245,16 +280,19 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): UserStoppedSpeakingFrame, OpenAILLMContextFrame, ] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat! How are you?" + assert context.messages[0]["content"] == "Hello Pipecat! How are you?" async def test_siet1it2(self): - context = OpenAILLMContext() - aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) frames_to_send = [ UserStartedSpeakingFrame(), InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""), @@ -270,33 +308,39 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): UserStoppedSpeakingFrame, OpenAILLMContextFrame, ] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat! How are you?" + assert context.messages[0]["content"] == "Hello Pipecat! How are you?" 
async def test_t(self): - context = OpenAILLMContext() - aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) frames_to_send = [ TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""), SleepFrame(sleep=AGGREGATION_SLEEP), ] expected_down_frames = [OpenAILLMContextFrame] expected_up_frames = [BotInterruptionFrame] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, expected_up_frames=expected_up_frames, ) - assert received_down[-1].context.messages[0]["content"] == "Hello!" + assert context.messages[0]["content"] == "Hello!" async def test_it(self): - context = OpenAILLMContext() - aggregator = LLMUserContextAggregator(context, aggregation_timeout=AGGREGATION_TIMEOUT) + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) frames_to_send = [ InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""), TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), @@ -304,17 +348,20 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): ] expected_down_frames = [OpenAILLMContextFrame] expected_up_frames = [BotInterruptionFrame] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, expected_up_frames=expected_up_frames, ) - assert received_down[-1].context.messages[0]["content"] == "Hello 
Pipecat!" + assert context.messages[0]["content"] == "Hello Pipecat!" async def test_sie_delay_it(self): - context = OpenAILLMContext() - aggregator = LLMUserContextAggregator( + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS( context, aggregation_timeout=AGGREGATION_TIMEOUT, bot_interruption_timeout=BOT_INTERRUPTION_TIMEOUT, @@ -335,19 +382,25 @@ class TestLLMUserContextAggreagator(unittest.IsolatedAsyncioTestCase): OpenAILLMContextFrame, ] expected_up_frames = [BotInterruptionFrame] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, expected_up_frames=expected_up_frames, ) - assert received_down[-1].context.messages[0]["content"] == "How are you?" + assert context.messages[0]["content"] == "How are you?" 
-class TestLLMAssistantContextAggreagator(unittest.IsolatedAsyncioTestCase): +class BaseTestAssistantContextAggreagator: + CONTEXT_CLASS = None # To be set in subclasses + AGGREGATOR_CLASS = None # To be set in subclasses + async def test_empty(self): - context = OpenAILLMContext() - aggregator = LLMAssistantContextAggregator(context) + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()] expected_down_frames = [] await run_test( @@ -356,25 +409,31 @@ class TestLLMAssistantContextAggreagator(unittest.IsolatedAsyncioTestCase): expected_down_frames=expected_down_frames, ) - async def test_single(self): - context = OpenAILLMContext() - aggregator = LLMAssistantContextAggregator(context) + async def test_single_text(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [ LLMFullResponseStartFrame(), TextFrame(text="Hello Pipecat!"), LLMFullResponseEndFrame(), ] expected_down_frames = [OpenAILLMContextFrame] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat!" + assert context.messages[0]["content"] == "Hello Pipecat!" 
- async def test_multiple(self): - context = OpenAILLMContext() - aggregator = LLMAssistantContextAggregator(context, expect_stripped_words=False) + async def test_multiple_text(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False) frames_to_send = [ LLMFullResponseStartFrame(), TextFrame(text="Hello "), @@ -384,16 +443,19 @@ class TestLLMAssistantContextAggreagator(unittest.IsolatedAsyncioTestCase): LLMFullResponseEndFrame(), ] expected_down_frames = [OpenAILLMContextFrame] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat. How are you?" + assert context.messages[0]["content"] == "Hello Pipecat. How are you?" - async def test_multiple_stripped(self): - context = OpenAILLMContext() - aggregator = LLMAssistantContextAggregator(context) + async def test_multiple_text_stripped(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [ LLMFullResponseStartFrame(), TextFrame(text="Hello"), @@ -403,9 +465,101 @@ class TestLLMAssistantContextAggreagator(unittest.IsolatedAsyncioTestCase): LLMFullResponseEndFrame(), ] expected_down_frames = [OpenAILLMContextFrame] - (received_down, _) = await run_test( + await run_test( aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert received_down[-1].context.messages[0]["content"] == "Hello Pipecat. How are you?" + assert context.messages[0]["content"] == "Hello Pipecat. 
How are you?" + + async def test_multiple_llm_responses(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False) + frames_to_send = [ + LLMFullResponseStartFrame(), + TextFrame(text="Hello "), + TextFrame(text="Pipecat."), + LLMFullResponseEndFrame(), + LLMFullResponseStartFrame(), + TextFrame(text="How are "), + TextFrame(text="you?"), + LLMFullResponseEndFrame(), + ] + expected_down_frames = [OpenAILLMContextFrame, OpenAILLMContextFrame] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert context.messages[0]["content"] == "Hello Pipecat." + assert context.messages[1]["content"] == "How are you?" + + async def test_multiple_llm_responses_interruption(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False) + frames_to_send = [ + LLMFullResponseStartFrame(), + TextFrame(text="Hello "), + TextFrame(text="Pipecat."), + LLMFullResponseEndFrame(), + SleepFrame(AGGREGATION_SLEEP), + StartInterruptionFrame(), + LLMFullResponseStartFrame(), + TextFrame(text="How are "), + TextFrame(text="you?"), + LLMFullResponseEndFrame(), + ] + expected_down_frames = [ + OpenAILLMContextFrame, + StartInterruptionFrame, + OpenAILLMContextFrame, + ] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert context.messages[0]["content"] == "Hello Pipecat." + assert context.messages[1]["content"] == "How are you?" 
+ + +# +# LLMUserContextAggregator, LLMAssistantContextAggregator +# + + +class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase): + CONTEXT_CLASS = OpenAILLMContext + AGGREGATOR_CLASS = LLMUserContextAggregator + + +class TestLLMAssistantContextAggregator( + BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = OpenAILLMContext + AGGREGATOR_CLASS = LLMAssistantContextAggregator + + +# +# OpenAI +# + + +class TestOpenAIUserContextAggregator( + BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = OpenAILLMContext + AGGREGATOR_CLASS = OpenAIUserContextAggregator + + +class TestOpenAIAssistantContextAggregator( + BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = OpenAILLMContext + AGGREGATOR_CLASS = LLMAssistantContextAggregator From 9f6a1c093a5114967f702a235ae6ed9c03d6fa91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 12 Feb 2025 23:54:59 -0800 Subject: [PATCH 12/22] LLMUserContextAggregator: reset user speaking time after bot interruption --- src/pipecat/processors/aggregators/llm_response.py | 2 ++ tests/test_llm_response.py | 1 + 2 files changed, 3 insertions(+) diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 850c0ca27..1f50119fd 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -326,6 +326,8 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): diff_time = time.time() - self._last_user_speaking_time if diff_time > self._bot_interruption_timeout: await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM) + # Reset time so we don't interrupt again right away. 
+ self._last_user_speaking_time = time.time() class LLMAssistantContextAggregator(LLMContextResponseAggregator): diff --git a/tests/test_llm_response.py b/tests/test_llm_response.py index 1d71217a2..e0026f5b5 100644 --- a/tests/test_llm_response.py +++ b/tests/test_llm_response.py @@ -343,6 +343,7 @@ class BaseTestUserContextAggregator: aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) frames_to_send = [ InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""), + SleepFrame(), TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), SleepFrame(sleep=AGGREGATION_SLEEP), ] From 84510fd5216870ad47fe25e17fb5694df5596841 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 13 Feb 2025 07:24:05 -0800 Subject: [PATCH 13/22] LLMUserContextAggregator: add space between transcriptions --- src/pipecat/processors/aggregators/llm_response.py | 2 +- tests/{test_llm_response.py => test_context_aggregators.py} | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) rename tests/{test_llm_response.py => test_context_aggregators.py} (98%) diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 1f50119fd..0bc68efdb 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -287,7 +287,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): await self.push_aggregation() async def _handle_transcription(self, frame: TranscriptionFrame): - self._aggregation += frame.text + self._aggregation += f" {frame.text}" if self._aggregation else frame.text # We just got a final result, so let's reset interim results. self._seen_interim_results = False # Reset aggregation timer. 
diff --git a/tests/test_llm_response.py b/tests/test_context_aggregators.py similarity index 98% rename from tests/test_llm_response.py rename to tests/test_context_aggregators.py index e0026f5b5..a8e69da09 100644 --- a/tests/test_llm_response.py +++ b/tests/test_context_aggregators.py @@ -109,7 +109,7 @@ class BaseTestUserContextAggregator: aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [ UserStartedSpeakingFrame(), - TranscriptionFrame(text="Hello Pipecat! ", user_id="cat", timestamp=""), + TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), SleepFrame(), UserStoppedSpeakingFrame(), @@ -271,7 +271,7 @@ class BaseTestUserContextAggregator: frames_to_send = [ UserStartedSpeakingFrame(), UserStoppedSpeakingFrame(), - TranscriptionFrame(text="Hello Pipecat! ", user_id="cat", timestamp=""), + TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), SleepFrame(sleep=AGGREGATION_SLEEP), ] @@ -298,7 +298,7 @@ class BaseTestUserContextAggregator: InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""), SleepFrame(), UserStoppedSpeakingFrame(), - TranscriptionFrame(text="Hello Pipecat! 
", user_id="cat", timestamp=""), + TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), SleepFrame(sleep=AGGREGATION_SLEEP), From 463078e3756c91bbc79afa9b8f51153ceb2b90d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 13 Feb 2025 11:48:04 -0800 Subject: [PATCH 14/22] initialize assistant aggregators with context and push upstream instead --- .../processors/aggregators/llm_response.py | 4 +- src/pipecat/services/anthropic.py | 14 +++--- src/pipecat/services/google/google.py | 2 +- src/pipecat/services/grok.py | 3 +- src/pipecat/services/openai.py | 9 ++-- .../services/openai_realtime_beta/context.py | 6 +-- tests/test_context_aggregators.py | 48 ++++++++++--------- 7 files changed, 44 insertions(+), 42 deletions(-) diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 0bc68efdb..bb2e28ca7 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -173,9 +173,9 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator): def get_context_frame(self) -> OpenAILLMContextFrame: return OpenAILLMContextFrame(context=self._context) - async def push_context_frame(self): + async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM): frame = self.get_context_frame() - await self.push_frame(frame) + await self.push_frame(frame, direction) def add_messages(self, messages): self._context.add_messages(messages) diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index 80b0b482e..d74820b02 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -126,9 +126,11 @@ class AnthropicLLMService(LLMService): def create_context_aggregator( context: OpenAILLMContext, 
*, assistant_expect_stripped_words: bool = True ) -> AnthropicContextAggregatorPair: + if isinstance(context, OpenAILLMContext): + context = AnthropicLLMContext.from_openai_context(context) user = AnthropicUserContextAggregator(context) assistant = AnthropicAssistantContextAggregator( - user, expect_stripped_words=assistant_expect_stripped_words + context, expect_stripped_words=assistant_expect_stripped_words ) return AnthropicContextAggregatorPair(_user=user, _assistant=assistant) @@ -654,9 +656,6 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator): def __init__(self, context: OpenAILLMContext | AnthropicLLMContext, **kwargs): super().__init__(context=context, **kwargs) - if isinstance(context, OpenAILLMContext): - self._context = AnthropicLLMContext.from_openai_context(context) - async def process_frame(self, frame, direction): await super().process_frame(frame, direction) # Our parent method has already called push_frame(). So we can't interrupt the @@ -703,9 +702,8 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator): class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): - def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs): - super().__init__(context=user_context_aggregator._context, **kwargs) - self._user_context_aggregator = user_context_aggregator + def __init__(self, context: OpenAILLMContext | AnthropicLLMContext, **kwargs): + super().__init__(context=context, **kwargs) self._function_call_in_progress = None self._function_call_result = None self._pending_image_frame_message = None @@ -799,7 +797,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): run_llm = True if run_llm: - await self._user_context_aggregator.push_context_frame() + await self.push_context_frame(FrameDirection.UPSTREAM) # Emit the on_context_updated callback once the function call result is added to the context if properties and properties.on_context_updated is not None: diff --git 
a/src/pipecat/services/google/google.py b/src/pipecat/services/google/google.py index fbfb9a0dd..de53f9972 100644 --- a/src/pipecat/services/google/google.py +++ b/src/pipecat/services/google/google.py @@ -626,7 +626,7 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): run_llm = True if run_llm: - await self._user_context_aggregator.push_context_frame() + await self.push_context_frame(FrameDirection.UPSTREAM) # Emit the on_context_updated callback once the function call result is added to the context if properties and properties.on_context_updated is not None: diff --git a/src/pipecat/services/grok.py b/src/pipecat/services/grok.py index f9abdedec..5d1a731ff 100644 --- a/src/pipecat/services/grok.py +++ b/src/pipecat/services/grok.py @@ -17,6 +17,7 @@ from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, ) +from pipecat.processors.frame_processor import FrameDirection from pipecat.services.openai import ( OpenAIAssistantContextAggregator, OpenAILLMService, @@ -91,7 +92,7 @@ class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator): run_llm = True if run_llm: - await self._user_context_aggregator.push_context_frame() + await self.push_context_frame(FrameDirection.UPSTREAM) # Emit the on_context_updated callback once the function call result is added to the context if properties and properties.on_context_updated is not None: diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index a5f52b69f..d3628f4cb 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -355,7 +355,7 @@ class OpenAILLMService(BaseOpenAILLMService): ) -> OpenAIContextAggregatorPair: user = OpenAIUserContextAggregator(context) assistant = OpenAIAssistantContextAggregator( - user, expect_stripped_words=assistant_expect_stripped_words + context, expect_stripped_words=assistant_expect_stripped_words ) return OpenAIContextAggregatorPair(_user=user, 
_assistant=assistant) @@ -592,9 +592,8 @@ class OpenAIUserContextAggregator(LLMUserContextAggregator): class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): - def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs): - super().__init__(context=user_context_aggregator._context, **kwargs) - self._user_context_aggregator = user_context_aggregator + def __init__(self, context: OpenAILLMContext, **kwargs): + super().__init__(context=context, **kwargs) self._function_calls_in_progress = {} self._function_call_result = None self._pending_image_frame_message = None @@ -686,7 +685,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): run_llm = True if run_llm: - await self._user_context_aggregator.push_context_frame() + await self.push_context_frame(FrameDirection.UPSTREAM) # Emit the on_context_updated callback once the function call result is added to the context if properties and properties.on_context_updated is not None: diff --git a/src/pipecat/services/openai_realtime_beta/context.py b/src/pipecat/services/openai_realtime_beta/context.py index 317817766..d88ed3314 100644 --- a/src/pipecat/services/openai_realtime_beta/context.py +++ b/src/pipecat/services/openai_realtime_beta/context.py @@ -217,8 +217,8 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator) # The standard function callback code path pushes the FunctionCallResultFrame from the llm itself, # so we didn't have a chance to add the result to the openai realtime api context. Let's push a # special frame to do that. 
- await self._user_context_aggregator.push_frame( - RealtimeFunctionCallResultFrame(result_frame=frame) + await self.push_frame( + RealtimeFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM ) if properties and properties.run_llm is not None: # If the tool call result has a run_llm property, use it @@ -228,7 +228,7 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator) run_llm = not bool(self._function_calls_in_progress) if run_llm: - await self._user_context_aggregator.push_context_frame() + await self.push_context_frame(FrameDirection.UPSTREAM) # Emit the on_context_updated callback once the function call result is added to the context if properties and properties.on_context_updated is not None: diff --git a/tests/test_context_aggregators.py b/tests/test_context_aggregators.py index a8e69da09..afc00abac 100644 --- a/tests/test_context_aggregators.py +++ b/tests/test_context_aggregators.py @@ -11,6 +11,7 @@ from pipecat.frames.frames import ( InterimTranscriptionFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, + OpenAILLMContextAssistantTimestampFrame, StartInterruptionFrame, TextFrame, TranscriptionFrame, @@ -25,7 +26,7 @@ from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, ) -from pipecat.services.openai import OpenAIUserContextAggregator +from pipecat.services.openai import OpenAIAssistantContextAggregator, OpenAIUserContextAggregator from pipecat.tests.utils import SleepFrame, run_test AGGREGATION_TIMEOUT = 0.1 @@ -37,6 +38,7 @@ BOT_INTERRUPTION_SLEEP = 0.25 class BaseTestUserContextAggregator: CONTEXT_CLASS = None # To be set in subclasses AGGREGATOR_CLASS = None # To be set in subclasses + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] async def test_se(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -67,7 +69,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, 
UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -92,7 +94,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -123,7 +125,7 @@ class BaseTestUserContextAggregator: UserStoppedSpeakingFrame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -149,7 +151,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -176,7 +178,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -200,7 +202,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -225,7 +227,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -251,8 +253,8 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -278,7 +280,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -306,7 +308,7 @@ class 
BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -325,7 +327,7 @@ class BaseTestUserContextAggregator: TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""), SleepFrame(sleep=AGGREGATION_SLEEP), ] - expected_down_frames = [OpenAILLMContextFrame] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] expected_up_frames = [BotInterruptionFrame] await run_test( aggregator, @@ -347,7 +349,7 @@ class BaseTestUserContextAggregator: TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), SleepFrame(sleep=AGGREGATION_SLEEP), ] - expected_down_frames = [OpenAILLMContextFrame] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] expected_up_frames = [BotInterruptionFrame] await run_test( aggregator, @@ -380,7 +382,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] expected_up_frames = [BotInterruptionFrame] await run_test( @@ -395,6 +397,7 @@ class BaseTestUserContextAggregator: class BaseTestAssistantContextAggreagator: CONTEXT_CLASS = None # To be set in subclasses AGGREGATOR_CLASS = None # To be set in subclasses + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] async def test_empty(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -421,7 +424,7 @@ class BaseTestAssistantContextAggreagator: TextFrame(text="Hello Pipecat!"), LLMFullResponseEndFrame(), ] - expected_down_frames = [OpenAILLMContextFrame] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] await run_test( aggregator, frames_to_send=frames_to_send, @@ -443,7 +446,7 @@ class BaseTestAssistantContextAggreagator: TextFrame(text="you?"), LLMFullResponseEndFrame(), ] - expected_down_frames = [OpenAILLMContextFrame] + expected_down_frames = 
[*self.EXPECTED_CONTEXT_FRAMES] await run_test( aggregator, frames_to_send=frames_to_send, @@ -465,7 +468,7 @@ class BaseTestAssistantContextAggreagator: TextFrame(text="you?"), LLMFullResponseEndFrame(), ] - expected_down_frames = [OpenAILLMContextFrame] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] await run_test( aggregator, frames_to_send=frames_to_send, @@ -489,7 +492,7 @@ class BaseTestAssistantContextAggreagator: TextFrame(text="you?"), LLMFullResponseEndFrame(), ] - expected_down_frames = [OpenAILLMContextFrame, OpenAILLMContextFrame] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES, *self.EXPECTED_CONTEXT_FRAMES] await run_test( aggregator, frames_to_send=frames_to_send, @@ -517,9 +520,9 @@ class BaseTestAssistantContextAggreagator: LLMFullResponseEndFrame(), ] expected_down_frames = [ - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, StartInterruptionFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -563,4 +566,5 @@ class TestOpenAIAssistantContextAggregator( BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase ): CONTEXT_CLASS = OpenAILLMContext - AGGREGATOR_CLASS = LLMAssistantContextAggregator + AGGREGATOR_CLASS = OpenAIAssistantContextAggregator + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] From b28f752afa54b3b837b3ea77ea181d5c58c3ad61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 13 Feb 2025 13:08:34 -0800 Subject: [PATCH 15/22] tests: add anthropic and google aggregator tests --- src/pipecat/tests/utils.py | 2 + tests/test_context_aggregators.py | 142 +++++++++++++++++++++++++----- 2 files changed, 123 insertions(+), 21 deletions(-) diff --git a/src/pipecat/tests/utils.py b/src/pipecat/tests/utils.py index 2b78f2bef..55bda9cea 100644 --- a/src/pipecat/tests/utils.py +++ b/src/pipecat/tests/utils.py @@ -127,6 +127,7 @@ async def run_test( received_down_frames.append(frame) 
print("received DOWN frames =", received_down_frames) + print("expected DOWN frames =", expected_down_frames) assert len(received_down_frames) == len(expected_down_frames) @@ -142,6 +143,7 @@ async def run_test( received_up_frames.append(frame) print("received UP frames =", received_up_frames) + print("expected UP frames =", expected_up_frames) assert len(received_up_frames) == len(expected_up_frames) diff --git a/tests/test_context_aggregators.py b/tests/test_context_aggregators.py index afc00abac..7190afce2 100644 --- a/tests/test_context_aggregators.py +++ b/tests/test_context_aggregators.py @@ -6,6 +6,8 @@ import unittest +import google.ai.generativelanguage as glm + from pipecat.frames.frames import ( BotInterruptionFrame, InterimTranscriptionFrame, @@ -26,6 +28,16 @@ from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, ) +from pipecat.services.anthropic import ( + AnthropicAssistantContextAggregator, + AnthropicLLMContext, + AnthropicUserContextAggregator, +) +from pipecat.services.google.google import ( + GoogleAssistantContextAggregator, + GoogleLLMContext, + GoogleUserContextAggregator, +) from pipecat.services.openai import OpenAIAssistantContextAggregator, OpenAIUserContextAggregator from pipecat.tests.utils import SleepFrame, run_test @@ -40,6 +52,14 @@ class BaseTestUserContextAggregator: AGGREGATOR_CLASS = None # To be set in subclasses EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] + def check_message_content(self, context: OpenAILLMContext, index: int, content: str): + assert context.messages[index]["content"] == content + + def check_message_multi_content( + self, context: OpenAILLMContext, content_index: int, index: int, content: str + ): + assert context.messages[index]["content"] == content + async def test_se(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" @@ -76,7 
+96,7 @@ class BaseTestUserContextAggregator: frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert context.messages[0]["content"] == "Hello!" + self.check_message_content(context, 0, "Hello!") async def test_site(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -101,7 +121,7 @@ class BaseTestUserContextAggregator: frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert context.messages[0]["content"] == "Hello Pipecat!" + self.check_message_content(context, 0, "Hello Pipecat!") async def test_st1iest2e(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -132,7 +152,7 @@ class BaseTestUserContextAggregator: frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert context.messages[0]["content"] == "Hello Pipecat! How are you?" + self.check_message_content(context, 0, "Hello Pipecat! How are you?") async def test_siet(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -158,7 +178,7 @@ class BaseTestUserContextAggregator: frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert context.messages[0]["content"] == "How are you?" + self.check_message_content(context, 0, "How are you?") async def test_sieit(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -185,7 +205,7 @@ class BaseTestUserContextAggregator: frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert context.messages[0]["content"] == "How are you?" + self.check_message_content(context, 0, "How are you?") async def test_set(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -209,7 +229,7 @@ class BaseTestUserContextAggregator: frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert context.messages[0]["content"] == "How are you?" 
+ self.check_message_content(context, 0, "How are you?") async def test_seit(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -234,7 +254,7 @@ class BaseTestUserContextAggregator: frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert context.messages[0]["content"] == "How are you?" + self.check_message_content(context, 0, "How are you?") async def test_st1et2(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -261,8 +281,8 @@ class BaseTestUserContextAggregator: frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert context.messages[0]["content"] == "Hello Pipecat!" - assert context.messages[1]["content"] == "How are you?" + self.check_message_multi_content(context, 0, 0, "Hello Pipecat!") + self.check_message_multi_content(context, 0, 1, "How are you?") async def test_set1t2(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -287,7 +307,7 @@ class BaseTestUserContextAggregator: frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert context.messages[0]["content"] == "Hello Pipecat! How are you?" + self.check_message_content(context, 0, "Hello Pipecat! How are you?") async def test_siet1it2(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -315,7 +335,7 @@ class BaseTestUserContextAggregator: frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert context.messages[0]["content"] == "Hello Pipecat! How are you?" + self.check_message_content(context, 0, "Hello Pipecat! How are you?") async def test_t(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -335,7 +355,7 @@ class BaseTestUserContextAggregator: expected_down_frames=expected_down_frames, expected_up_frames=expected_up_frames, ) - assert context.messages[0]["content"] == "Hello!" 
+ self.check_message_content(context, 0, "Hello!") async def test_it(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -357,7 +377,7 @@ class BaseTestUserContextAggregator: expected_down_frames=expected_down_frames, expected_up_frames=expected_up_frames, ) - assert context.messages[0]["content"] == "Hello Pipecat!" + self.check_message_content(context, 0, "Hello Pipecat!") async def test_sie_delay_it(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -391,7 +411,7 @@ class BaseTestUserContextAggregator: expected_down_frames=expected_down_frames, expected_up_frames=expected_up_frames, ) - assert context.messages[0]["content"] == "How are you?" + self.check_message_content(context, 0, "How are you?") class BaseTestAssistantContextAggreagator: @@ -399,6 +419,14 @@ class BaseTestAssistantContextAggreagator: AGGREGATOR_CLASS = None # To be set in subclasses EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] + def check_message_content(self, context: OpenAILLMContext, index: int, content: str): + assert context.messages[index]["content"] == content + + def check_message_multi_content( + self, context: OpenAILLMContext, content_index: int, index: int, content: str + ): + assert context.messages[index]["content"] == content + async def test_empty(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" @@ -430,7 +458,7 @@ class BaseTestAssistantContextAggreagator: frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert context.messages[0]["content"] == "Hello Pipecat!" 
+ self.check_message_content(context, 0, "Hello Pipecat!") async def test_multiple_text(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -452,7 +480,7 @@ class BaseTestAssistantContextAggreagator: frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert context.messages[0]["content"] == "Hello Pipecat. How are you?" + self.check_message_content(context, 0, "Hello Pipecat. How are you?") async def test_multiple_text_stripped(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -474,7 +502,7 @@ class BaseTestAssistantContextAggreagator: frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert context.messages[0]["content"] == "Hello Pipecat. How are you?" + self.check_message_content(context, 0, "Hello Pipecat. How are you?") async def test_multiple_llm_responses(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -498,8 +526,8 @@ class BaseTestAssistantContextAggreagator: frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert context.messages[0]["content"] == "Hello Pipecat." - assert context.messages[1]["content"] == "How are you?" + self.check_message_multi_content(context, 0, 0, "Hello Pipecat.") + self.check_message_multi_content(context, 0, 1, "How are you?") async def test_multiple_llm_responses_interruption(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -529,8 +557,8 @@ class BaseTestAssistantContextAggreagator: frames_to_send=frames_to_send, expected_down_frames=expected_down_frames, ) - assert context.messages[0]["content"] == "Hello Pipecat." - assert context.messages[1]["content"] == "How are you?" 
+ self.check_message_multi_content(context, 0, 0, "Hello Pipecat.") + self.check_message_multi_content(context, 0, 1, "How are you?") # @@ -568,3 +596,75 @@ class TestOpenAIAssistantContextAggregator( CONTEXT_CLASS = OpenAILLMContext AGGREGATOR_CLASS = OpenAIAssistantContextAggregator EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] + + +# +# Anthropic +# + + +class TestAnthropicUserContextAggregator( + BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = AnthropicLLMContext + AGGREGATOR_CLASS = AnthropicUserContextAggregator + + def check_message_multi_content( + self, context: OpenAILLMContext, content_index: int, index: int, content: str + ): + messages = context.messages[content_index] + assert messages["content"][index]["text"] == content + + +class TestAnthropicAssistantContextAggregator( + BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = AnthropicLLMContext + AGGREGATOR_CLASS = AnthropicAssistantContextAggregator + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] + + def check_message_multi_content( + self, context: OpenAILLMContext, content_index: int, index: int, content: str + ): + messages = context.messages[content_index] + assert messages["content"][index]["text"] == content + + +# +# Google +# + + +class TestGoogleUserContextAggregator( + BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = GoogleLLMContext + AGGREGATOR_CLASS = GoogleUserContextAggregator + + def check_message_content(self, context: OpenAILLMContext, index: int, content: str): + obj = glm.Content.to_dict(context.messages[index]) + assert obj["parts"][0]["text"] == content + + def check_message_multi_content( + self, context: OpenAILLMContext, content_index: int, index: int, content: str + ): + obj = glm.Content.to_dict(context.messages[index]) + assert obj["parts"][0]["text"] == content + + 
+class TestGoogleAssistantContextAggregator( + BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = GoogleLLMContext + AGGREGATOR_CLASS = GoogleAssistantContextAggregator + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] + + def check_message_content(self, context: OpenAILLMContext, index: int, content: str): + obj = glm.Content.to_dict(context.messages[index]) + assert obj["parts"][0]["text"] == content + + def check_message_multi_content( + self, context: OpenAILLMContext, content_index: int, index: int, content: str + ): + obj = glm.Content.to_dict(context.messages[index]) + assert obj["parts"][0]["text"] == content From 67cdc0063a5bd7a7d6da76b24a09b905dce46738 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 13 Feb 2025 13:17:57 -0800 Subject: [PATCH 16/22] BaseTransportOutput: allow pushing frames upstream --- CHANGELOG.md | 3 +++ src/pipecat/transports/base_output.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 634a740e7..a0dcf9af1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,6 +91,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed a `BaseOutputTransport` issue that was causing upstream frames to not be + pushed upstream. + - Fixed multiple issue where user transcriptions where not being handled properly. It was possible for short utterances to not trigger VAD which would cause user transcriptions to be ignored. It was also possible for one or more diff --git a/src/pipecat/transports/base_output.py b/src/pipecat/transports/base_output.py index 6bf4a9ff3..1b6d1e833 100644 --- a/src/pipecat/transports/base_output.py +++ b/src/pipecat/transports/base_output.py @@ -170,6 +170,8 @@ class BaseOutputTransport(FrameProcessor): # TODO(aleix): Images and audio should support presentation timestamps. 
elif frame.pts: await self._sink_clock_queue.put((frame.pts, frame.id, frame)) + elif direction == FrameDirection.UPSTREAM: + await self.push_frame(frame, direction) else: await self._sink_queue.put(frame) From 99779046a8afba96deaf8220f87607ac602f3c7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 13 Feb 2025 13:19:46 -0800 Subject: [PATCH 17/22] services: use push_context_frame() --- src/pipecat/services/anthropic.py | 3 +-- src/pipecat/services/google/google.py | 3 +-- src/pipecat/services/grok.py | 3 +-- src/pipecat/services/openai.py | 3 +-- src/pipecat/services/openai_realtime_beta/context.py | 3 +-- 5 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index d74820b02..80235199b 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -804,8 +804,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): await properties.on_context_updated() # Push context frame - frame = OpenAILLMContextFrame(self._context) - await self.push_frame(frame) + await self.push_context_frame() # Push timestamp frame with current time timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601()) diff --git a/src/pipecat/services/google/google.py b/src/pipecat/services/google/google.py index de53f9972..de88050c0 100644 --- a/src/pipecat/services/google/google.py +++ b/src/pipecat/services/google/google.py @@ -633,8 +633,7 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): await properties.on_context_updated() # Push context frame - frame = OpenAILLMContextFrame(self._context) - await self.push_frame(frame) + await self.push_context_frame() # Push timestamp frame with current time timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601()) diff --git a/src/pipecat/services/grok.py b/src/pipecat/services/grok.py index 5d1a731ff..1d1eb40d7 
100644 --- a/src/pipecat/services/grok.py +++ b/src/pipecat/services/grok.py @@ -98,8 +98,7 @@ class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator): if properties and properties.on_context_updated is not None: await properties.on_context_updated() - frame = OpenAILLMContextFrame(self._context) - await self.push_frame(frame) + await self.push_context_frame() except Exception as e: logger.error(f"Error processing frame: {e}") diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index d3628f4cb..b0a66eba0 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -692,8 +692,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): await properties.on_context_updated() # Push context frame - frame = OpenAILLMContextFrame(self._context) - await self.push_frame(frame) + await self.push_context_frame() # Push timestamp frame with current time timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601()) diff --git a/src/pipecat/services/openai_realtime_beta/context.py b/src/pipecat/services/openai_realtime_beta/context.py index d88ed3314..31639dc6b 100644 --- a/src/pipecat/services/openai_realtime_beta/context.py +++ b/src/pipecat/services/openai_realtime_beta/context.py @@ -234,8 +234,7 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator) if properties and properties.on_context_updated is not None: await properties.on_context_updated() - frame = OpenAILLMContextFrame(self._context) - await self.push_frame(frame) + await self.push_context_frame() except Exception as e: logger.error(f"Error processing frame: {e}") From e0d24d7fc09bbe0e9617b90c976a1436f9ca4326 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 13 Feb 2025 13:21:32 -0800 Subject: [PATCH 18/22] update CHANGELOG --- CHANGELOG.md | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md 
b/CHANGELOG.md index a0dcf9af1..3b1a3bafd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,13 +10,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added a new `audio_in_stream_on_start` field to `TransportParams`. - -- Added a new method `start_audio_in_streaming` in the `BaseInputTransport`. - - This method should be used to start receiving the input audio in case the field `audio_in_stream_on_start` is set to `false`. -- Added support for the `RTVIProcessor` to handle buffered audio in `base64` format, converting it into InputAudioRawFrame for transport. +- Added a new method `start_audio_in_streaming` in the `BaseInputTransport`. -- Added support for the `RTVIProcessor` to trigger `start_audio_in_streaming` only after the `client-ready` message. + - This method should be used to start receiving the input audio in case the + field `audio_in_stream_on_start` is set to `false`. + +- Added support for the `RTVIProcessor` to handle buffered audio in `base64` + format, converting it into InputAudioRawFrame for transport. + +- Added support for the `RTVIProcessor` to trigger `start_audio_in_streaming` + only after the `client-ready` message. - Added new `MUTE_UNTIL_FIRST_BOT_COMPLETE` strategy to `STTMuteStrategy`. This strategy starts muted and remains muted until the first bot speech completes, @@ -45,11 +49,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- Updated `DailyTransport` to respect the `audio_in_stream_on_start` field, ensuring it only starts receiving the audio input if it is enabled. +- Updated `DailyTransport` to respect the `audio_in_stream_on_start` field, + ensuring it only starts receiving the audio input if it is enabled. -- Updated `FastAPIWebsocketOutputTransport` to send `TransportMessageFrame` and `TransportMessageUrgentFrame` to the serializer. 
+- Updated `FastAPIWebsocketOutputTransport` to send `TransportMessageFrame` and + `TransportMessageUrgentFrame` to the serializer. -- Updated `WebsocketServerOutputTransport` to send `TransportMessageFrame` and `TransportMessageUrgentFrame` to the serializer. +- Updated `WebsocketServerOutputTransport` to send `TransportMessageFrame` and + `TransportMessageUrgentFrame` to the serializer. - Enhanced `STTMuteConfig` to validate strategy combinations, preventing `MUTE_UNTIL_FIRST_BOT_COMPLETE` and `FIRST_SPEECH` from being used together From a6502df72c8c1e2351679b543820b97417ebfb54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 13 Feb 2025 13:50:33 -0800 Subject: [PATCH 19/22] services: forgot to pass context instead of user aggregator --- src/pipecat/services/gemini_multimodal_live/gemini.py | 2 +- src/pipecat/services/google/google.py | 2 +- src/pipecat/services/grok.py | 2 +- src/pipecat/services/openai_realtime_beta/openai.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py index 8479a4e0a..6e7a1c0fa 100644 --- a/src/pipecat/services/gemini_multimodal_live/gemini.py +++ b/src/pipecat/services/gemini_multimodal_live/gemini.py @@ -706,6 +706,6 @@ class GeminiMultimodalLiveLLMService(LLMService): GeminiMultimodalLiveContext.upgrade(context) user = GeminiMultimodalLiveUserContextAggregator(context) assistant = GeminiMultimodalLiveAssistantContextAggregator( - user, expect_stripped_words=assistant_expect_stripped_words + context, expect_stripped_words=assistant_expect_stripped_words ) return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/google/google.py b/src/pipecat/services/google/google.py index de88050c0..605c74069 100644 --- a/src/pipecat/services/google/google.py +++ b/src/pipecat/services/google/google.py @@ -1174,7 +1174,7 @@ class 
GoogleLLMService(LLMService): ) -> GoogleContextAggregatorPair: user = GoogleUserContextAggregator(context) assistant = GoogleAssistantContextAggregator( - user, expect_stripped_words=assistant_expect_stripped_words + context, expect_stripped_words=assistant_expect_stripped_words ) return GoogleContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/grok.py b/src/pipecat/services/grok.py index 1d1eb40d7..064dd8829 100644 --- a/src/pipecat/services/grok.py +++ b/src/pipecat/services/grok.py @@ -212,6 +212,6 @@ class GrokLLMService(OpenAILLMService): ) -> GrokContextAggregatorPair: user = OpenAIUserContextAggregator(context) assistant = GrokAssistantContextAggregator( - user, expect_stripped_words=assistant_expect_stripped_words + context, expect_stripped_words=assistant_expect_stripped_words ) return GrokContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py index a42ed7b9a..bc173d765 100644 --- a/src/pipecat/services/openai_realtime_beta/openai.py +++ b/src/pipecat/services/openai_realtime_beta/openai.py @@ -568,6 +568,6 @@ class OpenAIRealtimeBetaLLMService(LLMService): OpenAIRealtimeLLMContext.upgrade_to_realtime(context) user = OpenAIRealtimeUserContextAggregator(context) assistant = OpenAIRealtimeAssistantContextAggregator( - user, expect_stripped_words=assistant_expect_stripped_words + context, expect_stripped_words=assistant_expect_stripped_words ) return OpenAIContextAggregatorPair(_user=user, _assistant=assistant) From 5909dff42305ed1991feecfe40f784ff855d3c26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 13 Feb 2025 13:59:43 -0800 Subject: [PATCH 20/22] LLMContextResponseAggregator: add VAD emulation support --- CHANGELOG.md | 4 ++++ src/pipecat/frames/frames.py | 16 ++++++++++++++ .../processors/aggregators/llm_response.py | 19 +++++++++++++++- 
src/pipecat/transports/base_input.py | 22 ++++++++++++++----- tests/test_context_aggregators.py | 9 ++++---- 5 files changed, 60 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b1a3bafd..ca7d08e3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added new frames `EmulateUserStartedSpeakingFrame` and + `EmulateUserStoppedSpeakingFrame` which can be used to emulate VAD behavior + without VAD being present or not being triggered. + - Added a new `audio_in_stream_on_start` field to `TransportParams`. - Added a new method `start_audio_in_streaming` in the `BaseInputTransport`. diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 09d0a93b0..c9b812a1c 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -565,6 +565,22 @@ class UserStoppedSpeakingFrame(SystemFrame): pass +@dataclass +class EmulateUserStartedSpeakingFrame(SystemFrame): + """Emitted by internal processors upstream to emulate VAD behavior when a + user starts speaking.""" + + pass + + +@dataclass +class EmulateUserStoppedSpeakingFrame(SystemFrame): + """Emitted by internal processors upstream to emulate VAD behavior when a + user stops speaking.""" + + pass + + @dataclass class BotInterruptionFrame(SystemFrame): """Emitted by when the bot should be interrupted.
This will mainly cause the diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index bb2e28ca7..950c155e6 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -12,6 +12,8 @@ from typing import List from pipecat.frames.frames import ( BotInterruptionFrame, CancelFrame, + EmulateUserStartedSpeakingFrame, + EmulateUserStoppedSpeakingFrame, EndFrame, Frame, InterimTranscriptionFrame, @@ -227,6 +229,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): self._seen_interim_results = False self._user_speaking = False self._last_user_speaking_time = 0 + self._emulating_vad = False self._aggregation_event = asyncio.Event() self._aggregation_task = None @@ -314,6 +317,14 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): except asyncio.TimeoutError: if not self._user_speaking: await self.push_aggregation() + + # If we are emulating VAD we still need to send the user stopped + # speaking frame. + if self._emulating_vad: + await self.push_frame( + EmulateUserStoppedSpeakingFrame(), FrameDirection.UPSTREAM + ) + self._emulating_vad = False finally: self._aggregation_event.clear() @@ -325,7 +336,13 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): if not self._user_speaking: diff_time = time.time() - self._last_user_speaking_time if diff_time > self._bot_interruption_timeout: - await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM) + # If we reach this case we received a transcription but VAD was + # not able to detect voice (e.g. when you whisper a short + # utterance). So, we need to emulate VAD (i.e. user + # start/stopped speaking). + await self.push_frame(EmulateUserStartedSpeakingFrame(), FrameDirection.UPSTREAM) + self._emulating_vad = True + # Reset time so we don't interrupt again right away. 
self._last_user_speaking_time = time.time() diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 6bfe86001..42eb162da 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -14,6 +14,8 @@ from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState from pipecat.frames.frames import ( BotInterruptionFrame, CancelFrame, + EmulateUserStartedSpeakingFrame, + EmulateUserStoppedSpeakingFrame, EndFrame, FilterUpdateSettingsFrame, Frame, @@ -112,9 +114,13 @@ class BaseInputTransport(FrameProcessor): await self.cancel(frame) await self.push_frame(frame, direction) elif isinstance(frame, BotInterruptionFrame): - logger.debug("Bot interruption") - await self._start_interruption() - await self.push_frame(StartInterruptionFrame()) + await self._handle_bot_interruption(frame) + elif isinstance(frame, EmulateUserStartedSpeakingFrame): + logger.debug("Emulating user started speaking") + await self._handle_user_interruption(UserStartedSpeakingFrame()) + elif isinstance(frame, EmulateUserStoppedSpeakingFrame): + logger.debug("Emulating user stopped speaking") + await self._handle_user_interruption(UserStoppedSpeakingFrame()) # All other system frames elif isinstance(frame, SystemFrame): await self.push_frame(frame, direction) @@ -137,7 +143,13 @@ class BaseInputTransport(FrameProcessor): # Handle interruptions # - async def _handle_interruptions(self, frame: Frame): + async def _handle_bot_interruption(self, frame: BotInterruptionFrame): + logger.debug("Bot interruption") + if self.interruptions_allowed: + await self._start_interruption() + await self.push_frame(StartInterruptionFrame()) + + async def _handle_user_interruption(self, frame: Frame): if isinstance(frame, UserStartedSpeakingFrame): logger.debug("User started speaking") # Make sure we notify about interruptions quickly out-of-band. 
@@ -183,7 +195,7 @@ class BaseInputTransport(FrameProcessor): frame = UserStoppedSpeakingFrame() if frame: - await self._handle_interruptions(frame) + await self._handle_user_interruption(frame) vad_state = new_vad_state return vad_state diff --git a/tests/test_context_aggregators.py b/tests/test_context_aggregators.py index 7190afce2..d4b8c35ce 100644 --- a/tests/test_context_aggregators.py +++ b/tests/test_context_aggregators.py @@ -9,7 +9,8 @@ import unittest import google.ai.generativelanguage as glm from pipecat.frames.frames import ( - BotInterruptionFrame, + EmulateUserStartedSpeakingFrame, + EmulateUserStoppedSpeakingFrame, InterimTranscriptionFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, @@ -348,7 +349,7 @@ class BaseTestUserContextAggregator: SleepFrame(sleep=AGGREGATION_SLEEP), ] expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] - expected_up_frames = [BotInterruptionFrame] + expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame] await run_test( aggregator, frames_to_send=frames_to_send, @@ -370,7 +371,7 @@ class BaseTestUserContextAggregator: SleepFrame(sleep=AGGREGATION_SLEEP), ] expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] - expected_up_frames = [BotInterruptionFrame] + expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame] await run_test( aggregator, frames_to_send=frames_to_send, @@ -404,7 +405,7 @@ class BaseTestUserContextAggregator: UserStoppedSpeakingFrame, *self.EXPECTED_CONTEXT_FRAMES, ] - expected_up_frames = [BotInterruptionFrame] + expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame] await run_test( aggregator, frames_to_send=frames_to_send, From 7578fbeaefba744b3d3d1af6b61e267e39f05684 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 13 Feb 2025 14:07:30 -0800 Subject: [PATCH 21/22] update google requirements --- pyproject.toml | 2 +- test-requirements.txt | 6 ++++-- 2 files 
changed, 5 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0578c9622..3c5d6262b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ elevenlabs = [ "websockets~=13.1" ] fal = [ "fal-client~=0.5.6" ] fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ] gladia = [ "websockets~=13.1" ] -google = [ "google-generativeai~=0.8.3", "google-cloud-texttospeech~=2.24.0", "google-genai~=1.0.0", "google-cloud-speech~=2.30.0" ] +google = [ "google-cloud-speech~=2.31.0", "google-cloud-texttospeech~=2.25.0", "google-genai~=1.2.0", "google-generativeai~=0.8.4" ] grok = [ "openai~=1.59.6" ] groq = [ "openai~=1.59.6" ] gstreamer = [ "pygobject~=3.50.0" ] diff --git a/test-requirements.txt b/test-requirements.txt index 999c7a6ba..7566de351 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -7,8 +7,10 @@ deepgram-sdk~=3.5.0 fal-client~=0.4.1 fastapi~=0.115.0 faster-whisper~=1.0.3 -google-cloud-texttospeech~=2.21.1 -google-generativeai~=0.8.3 +google-cloud-speech~=2.31.0 +google-cloud-texttospeech~=2.25.0 +google-genai~=1.2.0 +google-generativeai~=0.8.4 langchain~=0.2.14 livekit~=0.13.1 lmnt~=1.1.4 From c74440965124ea9eecd7fce6165273f1b56a504e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 13 Feb 2025 14:49:00 -0800 Subject: [PATCH 22/22] SegmentedSTTService: fix process_audio_frame() arguments --- src/pipecat/services/ai_services.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index b8d223176..ac1c4582d 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -577,7 +577,7 @@ class SegmentedSTTService(STTService): self._smoothing_factor = 0.2 self._prev_volume = 0 - async def process_audio_frame(self, frame: AudioRawFrame): + async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection): # Try to filter out empty background noise volume 
= self._get_smoothed_volume(frame) if volume >= self._min_volume: