initialize assistant aggregators with context and push upstream instead
This commit is contained in:
@@ -173,9 +173,9 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator):
|
||||
def get_context_frame(self) -> OpenAILLMContextFrame:
|
||||
return OpenAILLMContextFrame(context=self._context)
|
||||
|
||||
async def push_context_frame(self):
|
||||
async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
frame = self.get_context_frame()
|
||||
await self.push_frame(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def add_messages(self, messages):
|
||||
self._context.add_messages(messages)
|
||||
|
||||
@@ -126,9 +126,11 @@ class AnthropicLLMService(LLMService):
|
||||
def create_context_aggregator(
|
||||
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
|
||||
) -> AnthropicContextAggregatorPair:
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
context = AnthropicLLMContext.from_openai_context(context)
|
||||
user = AnthropicUserContextAggregator(context)
|
||||
assistant = AnthropicAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
context, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -654,9 +656,6 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext | AnthropicLLMContext, **kwargs):
|
||||
super().__init__(context=context, **kwargs)
|
||||
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
self._context = AnthropicLLMContext.from_openai_context(context)
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# Our parent method has already called push_frame(). So we can't interrupt the
|
||||
@@ -703,9 +702,8 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
|
||||
|
||||
|
||||
class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs):
|
||||
super().__init__(context=user_context_aggregator._context, **kwargs)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
def __init__(self, context: OpenAILLMContext | AnthropicLLMContext, **kwargs):
|
||||
super().__init__(context=context, **kwargs)
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
self._pending_image_frame_message = None
|
||||
@@ -799,7 +797,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
run_llm = True
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
# Emit the on_context_updated callback once the function call result is added to the context
|
||||
if properties and properties.on_context_updated is not None:
|
||||
|
||||
@@ -626,7 +626,7 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
run_llm = True
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
# Emit the on_context_updated callback once the function call result is added to the context
|
||||
if properties and properties.on_context_updated is not None:
|
||||
|
||||
@@ -17,6 +17,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.openai import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAILLMService,
|
||||
@@ -91,7 +92,7 @@ class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
run_llm = True
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
# Emit the on_context_updated callback once the function call result is added to the context
|
||||
if properties and properties.on_context_updated is not None:
|
||||
|
||||
@@ -355,7 +355,7 @@ class OpenAILLMService(BaseOpenAILLMService):
|
||||
) -> OpenAIContextAggregatorPair:
|
||||
user = OpenAIUserContextAggregator(context)
|
||||
assistant = OpenAIAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
context, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -592,9 +592,8 @@ class OpenAIUserContextAggregator(LLMUserContextAggregator):
|
||||
|
||||
|
||||
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs):
|
||||
super().__init__(context=user_context_aggregator._context, **kwargs)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
def __init__(self, context: OpenAILLMContext, **kwargs):
|
||||
super().__init__(context=context, **kwargs)
|
||||
self._function_calls_in_progress = {}
|
||||
self._function_call_result = None
|
||||
self._pending_image_frame_message = None
|
||||
@@ -686,7 +685,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
run_llm = True
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
# Emit the on_context_updated callback once the function call result is added to the context
|
||||
if properties and properties.on_context_updated is not None:
|
||||
|
||||
@@ -217,8 +217,8 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator)
|
||||
# The standard function callback code path pushes the FunctionCallResultFrame from the llm itself,
|
||||
# so we didn't have a chance to add the result to the openai realtime api context. Let's push a
|
||||
# special frame to do that.
|
||||
await self._user_context_aggregator.push_frame(
|
||||
RealtimeFunctionCallResultFrame(result_frame=frame)
|
||||
await self.push_frame(
|
||||
RealtimeFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM
|
||||
)
|
||||
if properties and properties.run_llm is not None:
|
||||
# If the tool call result has a run_llm property, use it
|
||||
@@ -228,7 +228,7 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator)
|
||||
run_llm = not bool(self._function_calls_in_progress)
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
# Emit the on_context_updated callback once the function call result is added to the context
|
||||
if properties and properties.on_context_updated is not None:
|
||||
|
||||
@@ -11,6 +11,7 @@ from pipecat.frames.frames import (
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -25,7 +26,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.services.openai import OpenAIUserContextAggregator
|
||||
from pipecat.services.openai import OpenAIAssistantContextAggregator, OpenAIUserContextAggregator
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
|
||||
AGGREGATION_TIMEOUT = 0.1
|
||||
@@ -37,6 +38,7 @@ BOT_INTERRUPTION_SLEEP = 0.25
|
||||
class BaseTestUserContextAggregator:
|
||||
CONTEXT_CLASS = None # To be set in subclasses
|
||||
AGGREGATOR_CLASS = None # To be set in subclasses
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
|
||||
|
||||
async def test_se(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
@@ -67,7 +69,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -92,7 +94,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -123,7 +125,7 @@ class BaseTestUserContextAggregator:
|
||||
UserStoppedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -149,7 +151,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -176,7 +178,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -200,7 +202,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -225,7 +227,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -251,8 +253,8 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContextFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -278,7 +280,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -306,7 +308,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -325,7 +327,7 @@ class BaseTestUserContextAggregator:
|
||||
TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [OpenAILLMContextFrame]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
expected_up_frames = [BotInterruptionFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -347,7 +349,7 @@ class BaseTestUserContextAggregator:
|
||||
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [OpenAILLMContextFrame]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
expected_up_frames = [BotInterruptionFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -380,7 +382,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
OpenAILLMContextFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
expected_up_frames = [BotInterruptionFrame]
|
||||
await run_test(
|
||||
@@ -395,6 +397,7 @@ class BaseTestUserContextAggregator:
|
||||
class BaseTestAssistantContextAggreagator:
|
||||
CONTEXT_CLASS = None # To be set in subclasses
|
||||
AGGREGATOR_CLASS = None # To be set in subclasses
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
|
||||
|
||||
async def test_empty(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
@@ -421,7 +424,7 @@ class BaseTestAssistantContextAggreagator:
|
||||
TextFrame(text="Hello Pipecat!"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [OpenAILLMContextFrame]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -443,7 +446,7 @@ class BaseTestAssistantContextAggreagator:
|
||||
TextFrame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [OpenAILLMContextFrame]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -465,7 +468,7 @@ class BaseTestAssistantContextAggreagator:
|
||||
TextFrame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [OpenAILLMContextFrame]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -489,7 +492,7 @@ class BaseTestAssistantContextAggreagator:
|
||||
TextFrame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [OpenAILLMContextFrame, OpenAILLMContextFrame]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES, *self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -517,9 +520,9 @@ class BaseTestAssistantContextAggreagator:
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
OpenAILLMContextFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
StartInterruptionFrame,
|
||||
OpenAILLMContextFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -563,4 +566,5 @@ class TestOpenAIAssistantContextAggregator(
|
||||
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = OpenAILLMContext
|
||||
AGGREGATOR_CLASS = LLMAssistantContextAggregator
|
||||
AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
|
||||
Reference in New Issue
Block a user