tiny bit more cleanup

This commit is contained in:
Kwindla Hultman Kramer
2024-09-07 12:38:17 -07:00
parent 37bbb687de
commit 6c58a5b024

View File

@@ -18,7 +18,7 @@ from pipecat.frames.frames import (
Frame,
LLMModelUpdateFrame,
TextFrame,
UserImageRawFrame,
UserImageRequestFrame,
LLMMessagesFrame,
LLMFullResponseStartFrame,
LLMFullResponseEndFrame,
@@ -249,17 +249,9 @@ class TogetherUserContextAggregator(LLMUserContextAggregator):
if isinstance(context, OpenAILLMContext):
self._context = TogetherLLMContext.from_openai_context(context)
def get_messages_frame(self):
return OpenAILLMContextFrame(self._context)
async def push_messages_frame(self):
await self.push_frame(self.get_messages_frame())
def append_image_description_tool_message(self, description):
self._context.add_message({
"role": "tool",
"content": json.dumps({"image_description": description})
})
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
@@ -268,10 +260,20 @@ class TogetherUserContextAggregator(LLMUserContextAggregator):
# to talk through (tagging @aleix). At some point we might need to refactor these
# context aggregators.
try:
if isinstance(frame, UserImageRawFrame):
if frame.description:
self.append_image_description_tool_message(frame.description)
await self.push_messages_frame()
if isinstance(frame, UserImageRequestFrame):
# The LLM sends a UserImageRequestFrame upstream. Cache any context provided with
# that frame so we can use it when we assemble the image message in the assistant
# context aggregator.
if (frame.context):
if isinstance(frame.context, str):
self._context._user_image_request_context[frame.user_id] = frame.context
else:
logger.error(
f"Unexpected UserImageRequestFrame context type: {type(frame.context)}")
del self._context._user_image_request_context[frame.user_id]
else:
if frame.user_id in self._context._user_image_request_context:
del self._context._user_image_request_context[frame.user_id]
except Exception as e:
logger.error(f"Error processing frame: {e}")