Compare commits

...

1 Commits

Author SHA1 Message Date
Mark Backman
d331649736 Add context aggregation to Google Gemini LLM 2024-09-30 09:29:54 -04:00
2 changed files with 225 additions and 32 deletions

View File

@@ -5,10 +5,14 @@
# #
import asyncio import asyncio
import aiohttp
import os import os
import sys import sys
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.frames.frames import Frame, TextFrame, UserImageRequestFrame from pipecat.frames.frames import Frame, TextFrame, UserImageRequestFrame
from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.runner import PipelineRunner
@@ -21,12 +25,6 @@ from pipecat.services.google import GoogleLLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVADAnalyzer from pipecat.vad.silero import SileroVADAnalyzer
from runner import configure
from loguru import logger
from dotenv import load_dotenv
load_dotenv(override=True) load_dotenv(override=True)
logger.remove(0) logger.remove(0)

View File

@@ -5,30 +5,43 @@
# #
import asyncio import asyncio
import base64
import io
import json
from dataclasses import dataclass
from typing import List from typing import List
from loguru import logger
from PIL import Image
from pipecat.frames.frames import ( from pipecat.frames.frames import (
Frame, Frame,
LLMModelUpdateFrame, FunctionCallInProgressFrame,
TextFrame, FunctionCallResultFrame,
VisionImageRawFrame,
LLMMessagesFrame,
LLMFullResponseStartFrame,
LLMFullResponseEndFrame, LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMModelUpdateFrame,
StartInterruptionFrame,
TextFrame,
UserImageRawFrame,
UserImageRequestFrame,
VisionImageRawFrame,
)
from pipecat.processors.aggregators.llm_response import (
LLMAssistantContextAggregator,
LLMUserContextAggregator,
) )
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService
from pipecat.processors.aggregators.openai_llm_context import ( from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext, OpenAILLMContext,
OpenAILLMContextFrame, OpenAILLMContextFrame,
) )
from pipecat.processors.frame_processor import FrameDirection
from loguru import logger from pipecat.services.ai_services import LLMService
try: try:
import google.generativeai as gai
import google.ai.generativelanguage as glm import google.ai.generativelanguage as glm
import google.generativeai as gai
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
logger.error(f"Exception: {e}") logger.error(f"Exception: {e}")
logger.error( logger.error(
@@ -37,6 +50,18 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}") raise Exception(f"Missing module: {e}")
@dataclass
class GoogleContextAggregatorPair:
    """Bundle holding one user-side and one assistant-side context aggregator.

    Both aggregators are stored in private fields and handed back through the
    ``user()`` and ``assistant()`` accessor methods.
    """

    # Aggregator returned by ``user()``.
    _user: "GoogleUserContextAggregator"
    # Aggregator returned by ``assistant()``.
    _assistant: "GoogleAssistantContextAggregator"

    def user(self) -> "GoogleUserContextAggregator":
        """Return the user context aggregator stored in this pair."""
        return self._user

    def assistant(self) -> "GoogleAssistantContextAggregator":
        """Return the assistant context aggregator stored in this pair."""
        return self._assistant
class GoogleLLMService(LLMService): class GoogleLLMService(LLMService):
"""This class implements inference with Google's AI models """This class implements inference with Google's AI models
@@ -53,6 +78,12 @@ class GoogleLLMService(LLMService):
def can_generate_metrics(self) -> bool: def can_generate_metrics(self) -> bool:
return True return True
@staticmethod
def create_context_aggregator(context: OpenAILLMContext) -> GoogleContextAggregatorPair:
    """Build the user/assistant aggregator pair for ``context``.

    The user aggregator is created from ``context`` first; the assistant
    aggregator is then constructed from that user aggregator, and both are
    returned wrapped in a ``GoogleContextAggregatorPair``.
    """
    user_aggregator = GoogleUserContextAggregator(context)
    return GoogleContextAggregatorPair(
        _user=user_aggregator,
        _assistant=GoogleAssistantContextAggregator(user_aggregator),
    )
def _create_client(self, model: str): def _create_client(self, model: str):
self.set_model_name(model) self.set_model_name(model)
self._client = gai.GenerativeModel(model) self._client = gai.GenerativeModel(model)
@@ -69,16 +100,24 @@ class GoogleLLMService(LLMService):
elif role == "assistant": elif role == "assistant":
role = "model" role = "model"
parts = [glm.Part(text=content)] if isinstance(content, list):
if "mime_type" in message: parts = []
parts.append( for item in content:
glm.Part( if item["type"] == "text":
inline_data=glm.Blob( parts.append(glm.Part(text=item["text"]))
mime_type=message["mime_type"], data=message["data"].getvalue() elif item["type"] == "image_url":
image_data = item["image_url"]["url"].split(",")[1]
parts.append(
glm.Part(
inline_data=glm.Blob(
mime_type="image/jpeg", data=base64.b64decode(image_data)
)
)
) )
) else:
) parts = [glm.Part(text=content)]
google_messages.append({"role": role, "parts": parts})
google_messages.append(glm.Content(role=role, parts=parts))
return google_messages return google_messages
@@ -88,8 +127,10 @@ class GoogleLLMService(LLMService):
await asyncio.sleep(0) await asyncio.sleep(0)
async def _process_context(self, context: OpenAILLMContext): async def _process_context(self, context: OpenAILLMContext):
await self.push_frame(LLMFullResponseStartFrame())
try: try:
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()
logger.debug(f"Generating chat: {context.get_messages_json()}") logger.debug(f"Generating chat: {context.get_messages_json()}")
messages = self._get_messages_from_openai_context(context) messages = self._get_messages_from_openai_context(context)
@@ -116,19 +157,19 @@ class GoogleLLMService(LLMService):
except Exception as e: except Exception as e:
logger.exception(f"{self} exception: {e}") logger.exception(f"{self} exception: {e}")
finally: finally:
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame()) await self.push_frame(LLMFullResponseEndFrame())
async def process_frame(self, frame: Frame, direction: FrameDirection): async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction) await super().process_frame(frame, direction)
context = None context = None
if isinstance(frame, OpenAILLMContextFrame): if isinstance(frame, OpenAILLMContextFrame):
context: OpenAILLMContext = frame.context context = GoogleLLMContext.from_openai_context(frame.context)
elif isinstance(frame, LLMMessagesFrame): elif isinstance(frame, LLMMessagesFrame):
context = OpenAILLMContext.from_messages(frame.messages) context = GoogleLLMContext.from_messages(frame.messages)
elif isinstance(frame, VisionImageRawFrame): elif isinstance(frame, VisionImageRawFrame):
context = OpenAILLMContext.from_image_frame(frame) context = GoogleLLMContext.from_image_frame(frame)
elif isinstance(frame, LLMModelUpdateFrame): elif isinstance(frame, LLMModelUpdateFrame):
logger.debug(f"Switching LLM model to: [{frame.model}]") logger.debug(f"Switching LLM model to: [{frame.model}]")
self._create_client(frame.model) self._create_client(frame.model)
@@ -137,3 +178,157 @@ class GoogleLLMService(LLMService):
if context: if context:
await self._process_context(context) await self._process_context(context)
class GoogleLLMContext(OpenAILLMContext):
    """OpenAI-style LLM context extended for the Google aggregators.

    On top of the base context it keeps ``_user_image_request_context``, a
    dict keyed by user id that the user aggregator fills with the text of a
    pending image request, and it knows how to append an image (plus
    optional text) as a user message.
    """

    # Pillow modes the JPEG writer accepts as-is; any other mode (e.g. RGBA,
    # P, LA) is converted to RGB before saving because JPEG cannot store it.
    _JPEG_WRITABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})

    def __init__(
        self,
        messages: list[dict] | None = None,
        tools: list[dict] | None = None,
        tool_choice: dict | None = None,
    ):
        """Initialise the base context and an empty image-request map."""
        super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
        self._user_image_request_context = {}

    @classmethod
    def from_openai_context(cls, openai_context: OpenAILLMContext) -> "GoogleLLMContext":
        """Create a Google context from an existing OpenAI-style context.

        The ``messages``, ``tools`` and ``tool_choice`` of ``openai_context``
        are passed straight to the constructor (they are not copied here).
        """
        return cls(
            messages=openai_context.messages,
            tools=openai_context.tools,
            tool_choice=openai_context.tool_choice,
        )

    @classmethod
    def from_messages(cls, messages: List[dict]) -> "GoogleLLMContext":
        """Create a context whose message history is ``messages``."""
        return cls(messages=messages)

    @classmethod
    def from_image_frame(cls, frame: VisionImageRawFrame) -> "GoogleLLMContext":
        """Create a context containing a single user message built from ``frame``.

        The frame's ``format``, ``size``, ``image`` and ``text`` attributes are
        forwarded to ``add_image_frame_message``.
        """
        context = cls()
        context.add_image_frame_message(
            format=frame.format, size=frame.size, image=frame.image, text=frame.text
        )
        return context

    def add_image_frame_message(
        self, *, format: str, size: tuple[int, int], image: bytes, text: str | None = None
    ):
        """Append a user message carrying ``image`` as a base64 JPEG data URL.

        Args:
            format: Pillow image mode of the raw buffer (passed to
                ``Image.frombytes`` as its mode argument).
            size: ``(width, height)`` of the raw image.
            image: Raw pixel data.
            text: Optional text appended after the image part; skipped when
                empty or ``None``.
        """
        pil_image = Image.frombytes(format, size, image)
        # Saving e.g. an RGBA image as JPEG raises in Pillow, so fall back to
        # RGB for any mode the JPEG writer cannot handle.
        if pil_image.mode not in self._JPEG_WRITABLE_MODES:
            pil_image = pil_image.convert("RGB")

        buffer = io.BytesIO()
        pil_image.save(buffer, format="JPEG")
        encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")

        content = [
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
        ]
        if text:
            content.append({"type": "text", "text": text})
        self.add_message({"role": "user", "content": content})
class GoogleUserContextAggregator(LLMUserContextAggregator):
    """User-side aggregator that also folds user images into the context.

    It remembers the text attached to a ``UserImageRequestFrame`` per user id
    and, when the matching ``UserImageRawFrame`` arrives, appends the image
    (with that text) to the context and pushes a context frame.
    """

    def __init__(self, context: OpenAILLMContext | GoogleLLMContext):
        """Store ``context``, converting it to a ``GoogleLLMContext`` if needed."""
        super().__init__(context=context)

        # GoogleLLMContext subclasses OpenAILLMContext, so it must be checked
        # first: re-wrapping an existing Google context would throw away its
        # pending image-request state.
        if isinstance(context, GoogleLLMContext):
            self._context = context
        elif isinstance(context, OpenAILLMContext):
            self._context = GoogleLLMContext.from_openai_context(context)

    async def process_frame(self, frame, direction):
        """Run the base processing, then handle image request / image frames.

        Any exception raised while handling the image frames is logged and
        swallowed so it does not propagate to the caller.
        """
        await super().process_frame(frame, direction)

        try:
            if isinstance(frame, UserImageRequestFrame):
                pending = self._context._user_image_request_context
                if frame.context and isinstance(frame.context, str):
                    pending[frame.user_id] = frame.context
                else:
                    if frame.context:
                        logger.error(
                            f"Unexpected UserImageRequestFrame context type: {type(frame.context)}"
                        )
                    # Drop any stale request text; pop() tolerates a missing
                    # entry, where a bare `del` would raise KeyError.
                    pending.pop(frame.user_id, None)
            elif isinstance(frame, UserImageRawFrame):
                pending = self._context._user_image_request_context
                # Consume the request text (if any) recorded for this user.
                text = pending.pop(frame.user_id, None) or ""

                # Handle the case where frame.format might be None.
                # NOTE(review): this value is passed to Image.frombytes as the
                # image *mode*; "JPEG" does not look like a valid Pillow mode,
                # so this default probably needs to be e.g. "RGB" — confirm.
                image_format = frame.format or "JPEG"

                self._context.add_image_frame_message(
                    format=image_format, size=frame.size, image=frame.image, text=text
                )
                await self.push_context_frame()
        except Exception as e:
            logger.error(f"Error processing frame: {e}")
class GoogleAssistantContextAggregator(LLMAssistantContextAggregator):
    """Assistant-side aggregator that records function calls in the context.

    It shares the context object of the given user aggregator, tracks the
    function call currently in progress, and when the matching result frame
    arrives writes the call and its result into the context and asks the user
    aggregator to push a new context frame.
    """

    def __init__(self, user_context_aggregator: GoogleUserContextAggregator):
        """Attach to ``user_context_aggregator`` and reuse its context."""
        super().__init__(context=user_context_aggregator._context)
        self._user_context_aggregator = user_context_aggregator
        self._function_call_in_progress = None
        self._function_call_result = None

    async def process_frame(self, frame, direction):
        """Run the base processing, then update function-call tracking state."""
        await super().process_frame(frame, direction)

        if isinstance(frame, StartInterruptionFrame):
            # An interruption discards any pending function-call state.
            self._function_call_in_progress = None
            self._function_call_result = None
            return

        if isinstance(frame, FunctionCallInProgressFrame):
            self._function_call_in_progress = frame
            return

        if not isinstance(frame, FunctionCallResultFrame):
            return

        in_progress = self._function_call_in_progress
        is_expected_result = bool(in_progress) and in_progress.tool_call_id == frame.tool_call_id
        self._function_call_in_progress = None

        if is_expected_result:
            self._function_call_result = frame
            await self._push_aggregation()
        else:
            logger.warning(
                "FunctionCallResultFrame tool_call_id != InProgressFrame tool_call_id"
            )
            self._function_call_result = None

    async def _push_aggregation(self):
        """Flush the aggregated text and/or pending function result to the context.

        With no pending function result, the aggregated text is added as an
        assistant message. With a pending result whose ``result`` is truthy,
        an assistant message carrying the function call and a function message
        carrying the JSON-encoded result are added, and the user aggregator
        pushes a context frame. Errors are logged and swallowed.
        """
        if not self._aggregation and not self._function_call_result:
            return

        # Take ownership of the buffered text and clear the buffer.
        pending_text, self._aggregation = self._aggregation, ""

        try:
            result_frame = self._function_call_result
            if not result_frame:
                self._context.add_message({"role": "assistant", "content": pending_text})
                return

            self._function_call_result = None
            if not result_frame.result:
                # Nothing is recorded and the LLM is not re-run in this case.
                return

            self._context.add_message(
                {
                    "role": "assistant",
                    "content": pending_text,
                    "function_call": {
                        "name": result_frame.function_name,
                        "arguments": json.dumps(result_frame.arguments),
                    },
                }
            )
            self._context.add_message(
                {
                    "role": "function",
                    "content": json.dumps(result_frame.result),
                    "name": result_frame.function_name,
                }
            )
            await self._user_context_aggregator.push_context_frame()
        except Exception as e:
            logger.error(f"Error processing frame: {e}")