Compare commits

...

1 Commits

Author SHA1 Message Date
Mark Backman
d331649736 Add context aggregation to Google Gemini LLM 2024-09-30 09:29:54 -04:00
2 changed files with 225 additions and 32 deletions

View File

@@ -5,10 +5,14 @@
#
import asyncio
import aiohttp
import os
import sys
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.frames.frames import Frame, TextFrame, UserImageRequestFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
@@ -21,12 +25,6 @@ from pipecat.services.google import GoogleLLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVADAnalyzer
from runner import configure
from loguru import logger
from dotenv import load_dotenv
load_dotenv(override=True)
logger.remove(0)

View File

@@ -5,30 +5,43 @@
#
import asyncio
import base64
import io
import json
from dataclasses import dataclass
from typing import List
from loguru import logger
from PIL import Image
from pipecat.frames.frames import (
Frame,
LLMModelUpdateFrame,
TextFrame,
VisionImageRawFrame,
LLMMessagesFrame,
LLMFullResponseStartFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMModelUpdateFrame,
StartInterruptionFrame,
TextFrame,
UserImageRawFrame,
UserImageRequestFrame,
VisionImageRawFrame,
)
from pipecat.processors.aggregators.llm_response import (
LLMAssistantContextAggregator,
LLMUserContextAggregator,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from loguru import logger
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService
try:
import google.generativeai as gai
import google.ai.generativelanguage as glm
import google.generativeai as gai
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
@@ -37,6 +50,18 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
@dataclass
class GoogleContextAggregatorPair:
_user: "GoogleUserContextAggregator"
_assistant: "GoogleAssistantContextAggregator"
def user(self) -> "GoogleUserContextAggregator":
return self._user
def assistant(self) -> "GoogleAssistantContextAggregator":
return self._assistant
class GoogleLLMService(LLMService):
"""This class implements inference with Google's AI models
@@ -53,6 +78,12 @@ class GoogleLLMService(LLMService):
def can_generate_metrics(self) -> bool:
return True
@staticmethod
def create_context_aggregator(context: OpenAILLMContext) -> GoogleContextAggregatorPair:
user = GoogleUserContextAggregator(context)
assistant = GoogleAssistantContextAggregator(user)
return GoogleContextAggregatorPair(_user=user, _assistant=assistant)
def _create_client(self, model: str):
self.set_model_name(model)
self._client = gai.GenerativeModel(model)
@@ -69,16 +100,24 @@ class GoogleLLMService(LLMService):
elif role == "assistant":
role = "model"
parts = [glm.Part(text=content)]
if "mime_type" in message:
parts.append(
glm.Part(
inline_data=glm.Blob(
mime_type=message["mime_type"], data=message["data"].getvalue()
if isinstance(content, list):
parts = []
for item in content:
if item["type"] == "text":
parts.append(glm.Part(text=item["text"]))
elif item["type"] == "image_url":
image_data = item["image_url"]["url"].split(",")[1]
parts.append(
glm.Part(
inline_data=glm.Blob(
mime_type="image/jpeg", data=base64.b64decode(image_data)
)
)
)
)
)
google_messages.append({"role": role, "parts": parts})
else:
parts = [glm.Part(text=content)]
google_messages.append(glm.Content(role=role, parts=parts))
return google_messages
@@ -88,8 +127,10 @@ class GoogleLLMService(LLMService):
await asyncio.sleep(0)
async def _process_context(self, context: OpenAILLMContext):
await self.push_frame(LLMFullResponseStartFrame())
try:
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()
logger.debug(f"Generating chat: {context.get_messages_json()}")
messages = self._get_messages_from_openai_context(context)
@@ -116,19 +157,19 @@ class GoogleLLMService(LLMService):
except Exception as e:
logger.exception(f"{self} exception: {e}")
finally:
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
context = None
if isinstance(frame, OpenAILLMContextFrame):
context: OpenAILLMContext = frame.context
context = GoogleLLMContext.from_openai_context(frame.context)
elif isinstance(frame, LLMMessagesFrame):
context = OpenAILLMContext.from_messages(frame.messages)
context = GoogleLLMContext.from_messages(frame.messages)
elif isinstance(frame, VisionImageRawFrame):
context = OpenAILLMContext.from_image_frame(frame)
context = GoogleLLMContext.from_image_frame(frame)
elif isinstance(frame, LLMModelUpdateFrame):
logger.debug(f"Switching LLM model to: [{frame.model}]")
self._create_client(frame.model)
@@ -137,3 +178,157 @@ class GoogleLLMService(LLMService):
if context:
await self._process_context(context)
class GoogleLLMContext(OpenAILLMContext):
def __init__(
self,
messages: list[dict] | None = None,
tools: list[dict] | None = None,
tool_choice: dict | None = None,
):
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
self._user_image_request_context = {}
@classmethod
def from_openai_context(cls, openai_context: OpenAILLMContext):
return cls(
messages=openai_context.messages,
tools=openai_context.tools,
tool_choice=openai_context.tool_choice,
)
@classmethod
def from_messages(cls, messages: List[dict]) -> "GoogleLLMContext":
return cls(messages=messages)
@classmethod
def from_image_frame(cls, frame: VisionImageRawFrame) -> "GoogleLLMContext":
context = cls()
context.add_image_frame_message(
format=frame.format, size=frame.size, image=frame.image, text=frame.text
)
return context
def add_image_frame_message(
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
):
buffer = io.BytesIO()
Image.frombytes(format, size, image).save(buffer, format="JPEG")
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
content = [
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
]
if text:
content.append({"type": "text", "text": text})
self.add_message({"role": "user", "content": content})
class GoogleUserContextAggregator(LLMUserContextAggregator):
def __init__(self, context: OpenAILLMContext | GoogleLLMContext):
super().__init__(context=context)
if isinstance(context, OpenAILLMContext):
self._context = GoogleLLMContext.from_openai_context(context)
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
try:
if isinstance(frame, UserImageRequestFrame):
if frame.context:
if isinstance(frame.context, str):
self._context._user_image_request_context[frame.user_id] = frame.context
else:
logger.error(
f"Unexpected UserImageRequestFrame context type: {type(frame.context)}"
)
del self._context._user_image_request_context[frame.user_id]
else:
if frame.user_id in self._context._user_image_request_context:
del self._context._user_image_request_context[frame.user_id]
elif isinstance(frame, UserImageRawFrame):
text = self._context._user_image_request_context.get(frame.user_id) or ""
if text:
del self._context._user_image_request_context[frame.user_id]
# Handle the case where frame.format might be None
image_format = frame.format or "JPEG" # Default to JPEG if format is None
self._context.add_image_frame_message(
format=image_format, size=frame.size, image=frame.image, text=text
)
await self.push_context_frame()
except Exception as e:
logger.error(f"Error processing frame: {e}")
class GoogleAssistantContextAggregator(LLMAssistantContextAggregator):
def __init__(self, user_context_aggregator: GoogleUserContextAggregator):
super().__init__(context=user_context_aggregator._context)
self._user_context_aggregator = user_context_aggregator
self._function_call_in_progress = None
self._function_call_result = None
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
if isinstance(frame, StartInterruptionFrame):
self._function_call_in_progress = None
self._function_call_result = None
elif isinstance(frame, FunctionCallInProgressFrame):
self._function_call_in_progress = frame
elif isinstance(frame, FunctionCallResultFrame):
if (
self._function_call_in_progress
and self._function_call_in_progress.tool_call_id == frame.tool_call_id
):
self._function_call_in_progress = None
self._function_call_result = frame
await self._push_aggregation()
else:
logger.warning(
"FunctionCallResultFrame tool_call_id != InProgressFrame tool_call_id"
)
self._function_call_in_progress = None
self._function_call_result = None
async def _push_aggregation(self):
if not (self._aggregation or self._function_call_result):
return
run_llm = False
aggregation = self._aggregation
self._aggregation = ""
try:
if self._function_call_result:
frame = self._function_call_result
self._function_call_result = None
if frame.result:
self._context.add_message(
{
"role": "assistant",
"content": aggregation,
"function_call": {
"name": frame.function_name,
"arguments": json.dumps(frame.arguments),
},
}
)
self._context.add_message(
{
"role": "function",
"content": json.dumps(frame.result),
"name": frame.function_name,
}
)
run_llm = True
else:
self._context.add_message({"role": "assistant", "content": aggregation})
if run_llm:
await self._user_context_aggregator.push_context_frame()
except Exception as e:
logger.error(f"Error processing frame: {e}")