From 6c58a5b024cfe56480ae95ffbcff9aa242a69707 Mon Sep 17 00:00:00 2001 From: Kwindla Hultman Kramer Date: Sat, 7 Sep 2024 12:38:17 -0700 Subject: [PATCH] tiny bit more cleanup --- src/pipecat/services/together.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/pipecat/services/together.py b/src/pipecat/services/together.py index d8379abad..fc3e02185 100644 --- a/src/pipecat/services/together.py +++ b/src/pipecat/services/together.py @@ -18,7 +18,7 @@ from pipecat.frames.frames import ( Frame, LLMModelUpdateFrame, TextFrame, - UserImageRawFrame, + UserImageRequestFrame, LLMMessagesFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame, @@ -249,17 +249,9 @@ class TogetherUserContextAggregator(LLMUserContextAggregator): if isinstance(context, OpenAILLMContext): self._context = TogetherLLMContext.from_openai_context(context) - def get_messages_frame(self): - return OpenAILLMContextFrame(self._context) - async def push_messages_frame(self): - await self.push_frame(self.get_messages_frame()) - - def append_image_description_tool_message(self, description): - self._context.add_message({ - "role": "tool", - "content": json.dumps({"image_description": description}) - }) + frame = OpenAILLMContextFrame(self._context) + await self.push_frame(frame) async def process_frame(self, frame, direction): await super().process_frame(frame, direction) @@ -268,10 +260,20 @@ class TogetherUserContextAggregator(LLMUserContextAggregator): # to talk through (tagging @aleix). At some point we might need to refactor these # context aggregators. try: - if isinstance(frame, UserImageRawFrame): - if frame.description: - self.append_image_description_tool_message(frame.description) - await self.push_messages_frame() + if isinstance(frame, UserImageRequestFrame): + # The LLM sends a UserImageRequestFrame upstream. Cache any context provided with + # that frame so we can use it when we assemble the image message in the assistant + # context aggregator. + if (frame.context): + if isinstance(frame.context, str): + self._context._user_image_request_context[frame.user_id] = frame.context + else: + logger.error( + f"Unexpected UserImageRequestFrame context type: {type(frame.context)}") + del self._context._user_image_request_context[frame.user_id] + else: + if frame.user_id in self._context._user_image_request_context: + del self._context._user_image_request_context[frame.user_id] except Exception as e: logger.error(f"Error processing frame: {e}")