initialize assistant aggregators with context and push upstream instead

This commit is contained in:
Aleix Conchillo Flaqué
2025-02-13 11:48:04 -08:00
parent 84510fd521
commit 463078e375
7 changed files with 44 additions and 42 deletions

View File

@@ -173,9 +173,9 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator):
def get_context_frame(self) -> OpenAILLMContextFrame:
return OpenAILLMContextFrame(context=self._context)
async def push_context_frame(self):
async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM):
frame = self.get_context_frame()
await self.push_frame(frame)
await self.push_frame(frame, direction)
def add_messages(self, messages):
self._context.add_messages(messages)

View File

@@ -126,9 +126,11 @@ class AnthropicLLMService(LLMService):
def create_context_aggregator(
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
) -> AnthropicContextAggregatorPair:
if isinstance(context, OpenAILLMContext):
context = AnthropicLLMContext.from_openai_context(context)
user = AnthropicUserContextAggregator(context)
assistant = AnthropicAssistantContextAggregator(
user, expect_stripped_words=assistant_expect_stripped_words
context, expect_stripped_words=assistant_expect_stripped_words
)
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
@@ -654,9 +656,6 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
def __init__(self, context: OpenAILLMContext | AnthropicLLMContext, **kwargs):
super().__init__(context=context, **kwargs)
if isinstance(context, OpenAILLMContext):
self._context = AnthropicLLMContext.from_openai_context(context)
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
# Our parent method has already called push_frame(). So we can't interrupt the
@@ -703,9 +702,8 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs):
super().__init__(context=user_context_aggregator._context, **kwargs)
self._user_context_aggregator = user_context_aggregator
def __init__(self, context: OpenAILLMContext | AnthropicLLMContext, **kwargs):
super().__init__(context=context, **kwargs)
self._function_call_in_progress = None
self._function_call_result = None
self._pending_image_frame_message = None
@@ -799,7 +797,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
run_llm = True
if run_llm:
await self._user_context_aggregator.push_context_frame()
await self.push_context_frame(FrameDirection.UPSTREAM)
# Emit the on_context_updated callback once the function call result is added to the context
if properties and properties.on_context_updated is not None:

View File

@@ -626,7 +626,7 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
run_llm = True
if run_llm:
await self._user_context_aggregator.push_context_frame()
await self.push_context_frame(FrameDirection.UPSTREAM)
# Emit the on_context_updated callback once the function call result is added to the context
if properties and properties.on_context_updated is not None:

View File

@@ -17,6 +17,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.openai import (
OpenAIAssistantContextAggregator,
OpenAILLMService,
@@ -91,7 +92,7 @@ class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
run_llm = True
if run_llm:
await self._user_context_aggregator.push_context_frame()
await self.push_context_frame(FrameDirection.UPSTREAM)
# Emit the on_context_updated callback once the function call result is added to the context
if properties and properties.on_context_updated is not None:

View File

@@ -355,7 +355,7 @@ class OpenAILLMService(BaseOpenAILLMService):
) -> OpenAIContextAggregatorPair:
user = OpenAIUserContextAggregator(context)
assistant = OpenAIAssistantContextAggregator(
user, expect_stripped_words=assistant_expect_stripped_words
context, expect_stripped_words=assistant_expect_stripped_words
)
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
@@ -592,9 +592,8 @@ class OpenAIUserContextAggregator(LLMUserContextAggregator):
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs):
super().__init__(context=user_context_aggregator._context, **kwargs)
self._user_context_aggregator = user_context_aggregator
def __init__(self, context: OpenAILLMContext, **kwargs):
super().__init__(context=context, **kwargs)
self._function_calls_in_progress = {}
self._function_call_result = None
self._pending_image_frame_message = None
@@ -686,7 +685,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
run_llm = True
if run_llm:
await self._user_context_aggregator.push_context_frame()
await self.push_context_frame(FrameDirection.UPSTREAM)
# Emit the on_context_updated callback once the function call result is added to the context
if properties and properties.on_context_updated is not None:

View File

@@ -217,8 +217,8 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator)
# The standard function callback code path pushes the FunctionCallResultFrame from the llm itself,
# so we didn't have a chance to add the result to the openai realtime api context. Let's push a
# special frame to do that.
await self._user_context_aggregator.push_frame(
RealtimeFunctionCallResultFrame(result_frame=frame)
await self.push_frame(
RealtimeFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM
)
if properties and properties.run_llm is not None:
# If the tool call result has a run_llm property, use it
@@ -228,7 +228,7 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator)
run_llm = not bool(self._function_calls_in_progress)
if run_llm:
await self._user_context_aggregator.push_context_frame()
await self.push_context_frame(FrameDirection.UPSTREAM)
# Emit the on_context_updated callback once the function call result is added to the context
if properties and properties.on_context_updated is not None:

View File

@@ -11,6 +11,7 @@ from pipecat.frames.frames import (
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
OpenAILLMContextAssistantTimestampFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
@@ -25,7 +26,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.services.openai import OpenAIUserContextAggregator
from pipecat.services.openai import OpenAIAssistantContextAggregator, OpenAIUserContextAggregator
from pipecat.tests.utils import SleepFrame, run_test
AGGREGATION_TIMEOUT = 0.1
@@ -37,6 +38,7 @@ BOT_INTERRUPTION_SLEEP = 0.25
class BaseTestUserContextAggregator:
CONTEXT_CLASS = None # To be set in subclasses
AGGREGATOR_CLASS = None # To be set in subclasses
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
async def test_se(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
@@ -67,7 +69,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
OpenAILLMContextFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -92,7 +94,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
OpenAILLMContextFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -123,7 +125,7 @@ class BaseTestUserContextAggregator:
UserStoppedSpeakingFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
OpenAILLMContextFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -149,7 +151,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
OpenAILLMContextFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -176,7 +178,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
OpenAILLMContextFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -200,7 +202,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
OpenAILLMContextFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -225,7 +227,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
OpenAILLMContextFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -251,8 +253,8 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
OpenAILLMContextFrame,
OpenAILLMContextFrame,
*self.EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -278,7 +280,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
OpenAILLMContextFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -306,7 +308,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
OpenAILLMContextFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -325,7 +327,7 @@ class BaseTestUserContextAggregator:
TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""),
SleepFrame(sleep=AGGREGATION_SLEEP),
]
expected_down_frames = [OpenAILLMContextFrame]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
expected_up_frames = [BotInterruptionFrame]
await run_test(
aggregator,
@@ -347,7 +349,7 @@ class BaseTestUserContextAggregator:
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
SleepFrame(sleep=AGGREGATION_SLEEP),
]
expected_down_frames = [OpenAILLMContextFrame]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
expected_up_frames = [BotInterruptionFrame]
await run_test(
aggregator,
@@ -380,7 +382,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
OpenAILLMContextFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
expected_up_frames = [BotInterruptionFrame]
await run_test(
@@ -395,6 +397,7 @@ class BaseTestUserContextAggregator:
class BaseTestAssistantContextAggreagator:
CONTEXT_CLASS = None # To be set in subclasses
AGGREGATOR_CLASS = None # To be set in subclasses
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
async def test_empty(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
@@ -421,7 +424,7 @@ class BaseTestAssistantContextAggreagator:
TextFrame(text="Hello Pipecat!"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [OpenAILLMContextFrame]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
@@ -443,7 +446,7 @@ class BaseTestAssistantContextAggreagator:
TextFrame(text="you?"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [OpenAILLMContextFrame]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
@@ -465,7 +468,7 @@ class BaseTestAssistantContextAggreagator:
TextFrame(text="you?"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [OpenAILLMContextFrame]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
@@ -489,7 +492,7 @@ class BaseTestAssistantContextAggreagator:
TextFrame(text="you?"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [OpenAILLMContextFrame, OpenAILLMContextFrame]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES, *self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
@@ -517,9 +520,9 @@ class BaseTestAssistantContextAggreagator:
LLMFullResponseEndFrame(),
]
expected_down_frames = [
OpenAILLMContextFrame,
*self.EXPECTED_CONTEXT_FRAMES,
StartInterruptionFrame,
OpenAILLMContextFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -563,4 +566,5 @@ class TestOpenAIAssistantContextAggregator(
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = LLMAssistantContextAggregator
AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]