From 3a5cd17ea3878155cb8018d35fc2b450209ae0bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 14 Aug 2024 20:23:18 -0700 Subject: [PATCH] processors(aggregators): multiple LLM aggregators updates --- .../processors/aggregators/llm_response.py | 75 +++++++++---------- .../aggregators/openai_llm_context.py | 35 ++++++--- src/pipecat/services/anthropic.py | 8 +- src/pipecat/services/cartesia.py | 2 +- src/pipecat/services/openai.py | 7 +- 5 files changed, 66 insertions(+), 61 deletions(-) diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 3c0a901af..7c38e62ad 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -82,10 +82,6 @@ class LLMResponseAggregator(FrameProcessor): # # and T2 would be dropped. - async def _set_tools(self, tools: List): - # noop in the base class - pass - async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -129,34 +125,16 @@ class LLMResponseAggregator(FrameProcessor): self._reset() await self.push_frame(frame, direction) elif isinstance(frame, LLMMessagesAppendFrame): - self._messages.extend(frame.messages) - messages_frame = LLMMessagesFrame(self._messages) - await self.push_frame(messages_frame) + self._add_messages(frame.messages) elif isinstance(frame, LLMMessagesUpdateFrame): - # We push the frame downstream so the assistant aggregator gets - # updated as well. - # TODO-CB: Now we're replacing the contents of the array so we - # don't need to push the frame here - # await self.push_frame(frame) - # We can now reset this one. - self._reset() self._set_messages(frame.messages) - # messages_frame = LLMMessagesFrame(self._messages) - # await self.push_frame(messages_frame) - await self.push_messages_frame() elif isinstance(frame, LLMSetToolsFrame): - await self.push_frame(frame) - await self._set_tools(frame.tools) + self._set_tools(frame.tools) else: await self.push_frame(frame, direction) if send_aggregation: await self._push_aggregation() - - # TODO-CB: Types - def _set_messages(self, messages): - self._messages.clear() - self._messages.extend(messages) async def _push_aggregation(self): if len(self._aggregation) > 0: @@ -169,6 +147,19 @@ class LLMResponseAggregator(FrameProcessor): frame = LLMMessagesFrame(self._messages) await self.push_frame(frame) + # TODO-CB: Types + def _add_messages(self, messages): + self._messages.extend(messages) + + def _set_messages(self, messages): + self._reset() + self._messages.clear() + self._messages.extend(messages) + + def _set_tools(self, tools): + # noop in the base class + pass + def _reset(self): self._aggregation = "" self._aggregating = False @@ -257,23 +248,29 @@ class LLMFullResponseAggregator(FrameProcessor): class LLMContextAggregator(LLMResponseAggregator): def __init__(self, *, context: OpenAILLMContext, **kwargs): - - self._context = context super().__init__(**kwargs) - # TODO-CB: thanks, I hate it - self._messages = context.messages - - - async def _set_tools(self, tools: List): - # We push the frame downstream so the assistant aggregator gets - # updated as well. - self._context.tools = tools - - # TODO-CB: Types - def _set_messages(self, messages): - self._messages.clear() - self._messages.extend(messages) + self._context = context + @property + def context(self): + return self._context + + def get_context_frame(self) -> OpenAILLMContextFrame: + return OpenAILLMContextFrame(context=self._context) + + async def push_context_frame(self): + frame = self.get_context_frame() + await self.push_frame(frame) + + # TODO-CB: Types + def _add_messages(self, messages): + self._context.add_messages(messages) + + def _set_messages(self, messages): + self._context.set_messages(messages) + + def _set_tools(self, tools: List): + self._context.set_tools(tools) async def _push_aggregation(self): if len(self._aggregation) > 0: diff --git a/src/pipecat/processors/aggregators/openai_llm_context.py b/src/pipecat/processors/aggregators/openai_llm_context.py index bd88a9580..009040996 100644 --- a/src/pipecat/processors/aggregators/openai_llm_context.py +++ b/src/pipecat/processors/aggregators/openai_llm_context.py @@ -44,10 +44,10 @@ class OpenAILLMContext: tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN ): - self.messages: List[ChatCompletionMessageParam] = messages if messages else [ + self._messages: List[ChatCompletionMessageParam] = messages if messages else [ ] - self.tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice - self.tools: List[ChatCompletionToolParam] | NotGiven = tools + self._tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice + self._tools: List[ChatCompletionToolParam] | NotGiven = tools @staticmethod def from_messages(messages: List[dict]) -> "OpenAILLMContext": @@ -84,26 +84,43 @@ class OpenAILLMContext: }) return context + @property + def messages(self) -> List[ChatCompletionMessageParam]: + return self._messages + + @property + def tools(self) -> List[ChatCompletionToolParam] | NotGiven: + return self._tools + + @property + def tool_choice(self) -> ChatCompletionToolChoiceOptionParam | NotGiven: + return self._tool_choice + def add_message(self, message: ChatCompletionMessageParam): - self.messages.append(message) + self._messages.append(message) + + def add_messages(self, messages: List[ChatCompletionMessageParam]): + self._messages.extend(messages) + + def set_messages(self, messages: List[ChatCompletionMessageParam]): + self._messages[:] = messages def get_messages(self) -> List[ChatCompletionMessageParam]: - return self.messages + return self._messages def get_messages_json(self) -> str: - return json.dumps(self.messages, cls=CustomEncoder) + return json.dumps(self._messages, cls=CustomEncoder) def set_tool_choice( self, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven ): - self.tool_choice = tool_choice + self._tool_choice = tool_choice def set_tools(self, tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN): if tools != NOT_GIVEN and len(tools) == 0: tools = NOT_GIVEN + self._tools = tools - self.tools = tools - async def call_function( self, f: callable, diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index 89f13125a..5c636b45f 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -336,10 +336,6 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator): if isinstance(context, OpenAILLMContext): self._context = AnthropicLLMContext.from_openai_context(context) - async def push_messages_frame(self): - frame = OpenAILLMContextFrame(self._context) - await self.push_frame(frame) - async def process_frame(self, frame, direction): await super().process_frame(frame, direction) # Our parent method has already called push_frame(). So we can't interrupt the @@ -415,7 +411,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): size=frame.user_image_raw_frame.size, image=frame.user_image_raw_frame.image, text=frame.text) - await self._user_context_aggregator.push_messages_frame() + await self._user_context_aggregator.push_context_frame() except Exception as e: logger.error(f"Error processing AnthropicImageMessageFrame: {e}") @@ -465,7 +461,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): self._context.add_message({"role": "assistant", "content": aggregation}) if run_llm: - await self._user_context_aggregator.push_messages_frame() + await self._user_context_aggregator.push_context_frame() except Exception as e: logger.error(f"Error processing frame: {e}") diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py index ead8a7ac4..c0f9d0e08 100644 --- a/src/pipecat/services/cartesia.py +++ b/src/pipecat/services/cartesia.py @@ -177,7 +177,7 @@ class CartesiaTTSService(TTSService): elif msg["type"] == "error": logger.error(f"{self} error: {msg}") await self.stop_all_metrics() - await self.push_frame(ErrorFrame(f'{self} error: {msg["error"]}')) + await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) else: logger.error(f"Cartesia error, unknown message type: {msg}") except asyncio.CancelledError: diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index a27d6954d..8e6ed83f9 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -355,10 +355,6 @@ class OpenAIUserContextAggregator(LLMUserContextAggregator): def __init__(self, context: OpenAILLMContext): super().__init__(context=context) - async def push_messages_frame(self): - frame = OpenAILLMContextFrame(self._context) - await self.push_frame(frame) - class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): def __init__(self, user_context_aggregator: OpenAIUserContextAggregator): @@ -426,8 +422,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): self._context.add_message({"role": "assistant", "content": aggregation}) if run_llm: - - await self._user_context_aggregator.push_messages_frame() + await self._user_context_aggregator.push_context_frame() except Exception as e: logger.error(f"Error processing frame: {e}")