From 463078e3756c91bbc79afa9b8f51153ceb2b90d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 13 Feb 2025 11:48:04 -0800 Subject: [PATCH] initialize assistant aggregators with context and push upstream instead --- .../processors/aggregators/llm_response.py | 4 +- src/pipecat/services/anthropic.py | 14 +++--- src/pipecat/services/google/google.py | 2 +- src/pipecat/services/grok.py | 3 +- src/pipecat/services/openai.py | 9 ++-- .../services/openai_realtime_beta/context.py | 6 +-- tests/test_context_aggregators.py | 48 ++++++++++--------- 7 files changed, 44 insertions(+), 42 deletions(-) diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 0bc68efdb..bb2e28ca7 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -173,9 +173,9 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator): def get_context_frame(self) -> OpenAILLMContextFrame: return OpenAILLMContextFrame(context=self._context) - async def push_context_frame(self): + async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM): frame = self.get_context_frame() - await self.push_frame(frame) + await self.push_frame(frame, direction) def add_messages(self, messages): self._context.add_messages(messages) diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index 80b0b482e..d74820b02 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -126,9 +126,11 @@ class AnthropicLLMService(LLMService): def create_context_aggregator( context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True ) -> AnthropicContextAggregatorPair: + if isinstance(context, OpenAILLMContext): + context = AnthropicLLMContext.from_openai_context(context) user = AnthropicUserContextAggregator(context) assistant = AnthropicAssistantContextAggregator( - user, expect_stripped_words=assistant_expect_stripped_words + context, expect_stripped_words=assistant_expect_stripped_words ) return AnthropicContextAggregatorPair(_user=user, _assistant=assistant) @@ -654,9 +656,6 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator): def __init__(self, context: OpenAILLMContext | AnthropicLLMContext, **kwargs): super().__init__(context=context, **kwargs) - if isinstance(context, OpenAILLMContext): - self._context = AnthropicLLMContext.from_openai_context(context) - async def process_frame(self, frame, direction): await super().process_frame(frame, direction) # Our parent method has already called push_frame(). So we can't interrupt the @@ -703,9 +702,8 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator): class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): - def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs): - super().__init__(context=user_context_aggregator._context, **kwargs) - self._user_context_aggregator = user_context_aggregator + def __init__(self, context: OpenAILLMContext | AnthropicLLMContext, **kwargs): + super().__init__(context=context, **kwargs) self._function_call_in_progress = None self._function_call_result = None self._pending_image_frame_message = None @@ -799,7 +797,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): run_llm = True if run_llm: - await self._user_context_aggregator.push_context_frame() + await self.push_context_frame(FrameDirection.UPSTREAM) # Emit the on_context_updated callback once the function call result is added to the context if properties and properties.on_context_updated is not None: diff --git a/src/pipecat/services/google/google.py b/src/pipecat/services/google/google.py index fbfb9a0dd..de53f9972 100644 --- a/src/pipecat/services/google/google.py +++ b/src/pipecat/services/google/google.py @@ -626,7 +626,7 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): run_llm = True if run_llm: - await self._user_context_aggregator.push_context_frame() + await self.push_context_frame(FrameDirection.UPSTREAM) # Emit the on_context_updated callback once the function call result is added to the context if properties and properties.on_context_updated is not None: diff --git a/src/pipecat/services/grok.py b/src/pipecat/services/grok.py index f9abdedec..5d1a731ff 100644 --- a/src/pipecat/services/grok.py +++ b/src/pipecat/services/grok.py @@ -17,6 +17,7 @@ from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, ) +from pipecat.processors.frame_processor import FrameDirection from pipecat.services.openai import ( OpenAIAssistantContextAggregator, OpenAILLMService, @@ -91,7 +92,7 @@ class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator): run_llm = True if run_llm: - await self._user_context_aggregator.push_context_frame() + await self.push_context_frame(FrameDirection.UPSTREAM) # Emit the on_context_updated callback once the function call result is added to the context if properties and properties.on_context_updated is not None: diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index a5f52b69f..d3628f4cb 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -355,7 +355,7 @@ class OpenAILLMService(BaseOpenAILLMService): ) -> OpenAIContextAggregatorPair: user = OpenAIUserContextAggregator(context) assistant = OpenAIAssistantContextAggregator( - user, expect_stripped_words=assistant_expect_stripped_words + context, expect_stripped_words=assistant_expect_stripped_words ) return OpenAIContextAggregatorPair(_user=user, _assistant=assistant) @@ -592,9 +592,8 @@ class OpenAIUserContextAggregator(LLMUserContextAggregator): class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): - def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs): - super().__init__(context=user_context_aggregator._context, **kwargs) - self._user_context_aggregator = user_context_aggregator + def __init__(self, context: OpenAILLMContext, **kwargs): + super().__init__(context=context, **kwargs) self._function_calls_in_progress = {} self._function_call_result = None self._pending_image_frame_message = None @@ -686,7 +685,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): run_llm = True if run_llm: - await self._user_context_aggregator.push_context_frame() + await self.push_context_frame(FrameDirection.UPSTREAM) # Emit the on_context_updated callback once the function call result is added to the context if properties and properties.on_context_updated is not None: diff --git a/src/pipecat/services/openai_realtime_beta/context.py b/src/pipecat/services/openai_realtime_beta/context.py index 317817766..d88ed3314 100644 --- a/src/pipecat/services/openai_realtime_beta/context.py +++ b/src/pipecat/services/openai_realtime_beta/context.py @@ -217,8 +217,8 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator) # The standard function callback code path pushes the FunctionCallResultFrame from the llm itself, # so we didn't have a chance to add the result to the openai realtime api context. Let's push a # special frame to do that. - await self._user_context_aggregator.push_frame( - RealtimeFunctionCallResultFrame(result_frame=frame) + await self.push_frame( + RealtimeFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM ) if properties and properties.run_llm is not None: # If the tool call result has a run_llm property, use it @@ -228,7 +228,7 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator) run_llm = not bool(self._function_calls_in_progress) if run_llm: - await self._user_context_aggregator.push_context_frame() + await self.push_context_frame(FrameDirection.UPSTREAM) # Emit the on_context_updated callback once the function call result is added to the context if properties and properties.on_context_updated is not None: diff --git a/tests/test_context_aggregators.py b/tests/test_context_aggregators.py index a8e69da09..afc00abac 100644 --- a/tests/test_context_aggregators.py +++ b/tests/test_context_aggregators.py @@ -11,6 +11,7 @@ from pipecat.frames.frames import ( InterimTranscriptionFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, + OpenAILLMContextAssistantTimestampFrame, StartInterruptionFrame, TextFrame, TranscriptionFrame, @@ -25,7 +26,7 @@ from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, ) -from pipecat.services.openai import OpenAIUserContextAggregator +from pipecat.services.openai import OpenAIAssistantContextAggregator, OpenAIUserContextAggregator from pipecat.tests.utils import SleepFrame, run_test AGGREGATION_TIMEOUT = 0.1 @@ -37,6 +38,7 @@ BOT_INTERRUPTION_SLEEP = 0.25 class BaseTestUserContextAggregator: CONTEXT_CLASS = None # To be set in subclasses AGGREGATOR_CLASS = None # To be set in subclasses + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] async def test_se(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -67,7 +69,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -92,7 +94,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -123,7 +125,7 @@ class BaseTestUserContextAggregator: UserStoppedSpeakingFrame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -149,7 +151,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -176,7 +178,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -200,7 +202,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -225,7 +227,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -251,8 +253,8 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -278,7 +280,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -306,7 +308,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -325,7 +327,7 @@ class BaseTestUserContextAggregator: TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""), SleepFrame(sleep=AGGREGATION_SLEEP), ] - expected_down_frames = [OpenAILLMContextFrame] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] expected_up_frames = [BotInterruptionFrame] await run_test( aggregator, @@ -347,7 +349,7 @@ class BaseTestUserContextAggregator: TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), SleepFrame(sleep=AGGREGATION_SLEEP), ] - expected_down_frames = [OpenAILLMContextFrame] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] expected_up_frames = [BotInterruptionFrame] await run_test( aggregator, @@ -380,7 +382,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] expected_up_frames = [BotInterruptionFrame] await run_test( @@ -395,6 +397,7 @@ class BaseTestUserContextAggregator: class BaseTestAssistantContextAggreagator: CONTEXT_CLASS = None # To be set in subclasses AGGREGATOR_CLASS = None # To be set in subclasses + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] async def test_empty(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" @@ -421,7 +424,7 @@ class BaseTestAssistantContextAggreagator: TextFrame(text="Hello Pipecat!"), LLMFullResponseEndFrame(), ] - expected_down_frames = [OpenAILLMContextFrame] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] await run_test( aggregator, frames_to_send=frames_to_send, @@ -443,7 +446,7 @@ class BaseTestAssistantContextAggreagator: TextFrame(text="you?"), LLMFullResponseEndFrame(), ] - expected_down_frames = [OpenAILLMContextFrame] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] await run_test( aggregator, frames_to_send=frames_to_send, @@ -465,7 +468,7 @@ class BaseTestAssistantContextAggreagator: TextFrame(text="you?"), LLMFullResponseEndFrame(), ] - expected_down_frames = [OpenAILLMContextFrame] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] await run_test( aggregator, frames_to_send=frames_to_send, @@ -489,7 +492,7 @@ class BaseTestAssistantContextAggreagator: TextFrame(text="you?"), LLMFullResponseEndFrame(), ] - expected_down_frames = [OpenAILLMContextFrame, OpenAILLMContextFrame] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES, *self.EXPECTED_CONTEXT_FRAMES] await run_test( aggregator, frames_to_send=frames_to_send, @@ -517,9 +520,9 @@ class BaseTestAssistantContextAggreagator: LLMFullResponseEndFrame(), ] expected_down_frames = [ - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, StartInterruptionFrame, - OpenAILLMContextFrame, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -563,4 +566,5 @@ class TestOpenAIAssistantContextAggregator( BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase ): CONTEXT_CLASS = OpenAILLMContext - AGGREGATOR_CLASS = LLMAssistantContextAggregator + AGGREGATOR_CLASS = OpenAIAssistantContextAggregator + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]