Merge pull request #652 from pipecat-ai/khk/more-gemini

Gemini new context manager and rewrite to use google data structures internally
This commit is contained in:
Mark Backman
2024-10-24 13:47:38 -04:00
committed by GitHub
5 changed files with 754 additions and 8 deletions

View File

@@ -0,0 +1,173 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import aiohttp
import os
import sys
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.google import GoogleLLMService
from pipecat.services.openai import OpenAILLMContext
from pipecat.transports.services.daily import DailyParams, DailyTransport
from runner import configure
from loguru import logger
from dotenv import load_dotenv
# Load settings from a .env file; override=True lets them replace values
# already present in the environment.
load_dotenv(override=True)
# Swap the logger's handler 0 for a DEBUG-level stderr sink.
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
# Id of the participant whose video is captured. Assigned by the
# on_first_participant_joined handler inside main() and read by get_image().
video_participant_id = None
async def get_weather(function_name, tool_call_id, arguments, llm, context, result_callback):
    """Report a canned weather forecast for the requested location.

    Only ``arguments["location"]`` and ``result_callback`` are used; the other
    parameters are accepted and ignored.

    Raises:
        KeyError: if ``arguments`` has no ``"location"`` entry.
    """
    place = arguments["location"]
    forecast = f"The weather in {place} is currently 72 degrees and sunny."
    await result_callback(forecast)
async def get_image(function_name, tool_call_id, arguments, llm, context, result_callback):
    """Ask the LLM service for an image frame from the tracked participant.

    The question supplied by the model is forwarded as the text that goes with
    the requested image. ``result_callback`` is not invoked here.

    Raises:
        KeyError: if ``arguments`` has no ``"question"`` entry.
    """
    logger.debug(f"!!! IN get_image {video_participant_id}, {arguments}")
    prompt = arguments["question"]
    await llm.request_image_frame(
        user_id=video_participant_id,
        text_content=prompt,
    )
async def main():
    """Run the function-calling example bot in a Daily room.

    Builds the Daily transport, the Cartesia TTS service and the Google LLM
    service, registers the ``get_weather`` and ``get_image`` handlers, wires
    everything into a pipeline and runs it with a ``PipelineRunner``.
    """
    async with aiohttp.ClientSession() as session:
        # Room URL and token come from the local runner helper.
        (room_url, token) = await configure(session)
        transport = DailyTransport(
            room_url,
            token,
            "Respond bot",
            DailyParams(
                audio_out_enabled=True,
                transcription_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer(),
            ),
        )
        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
        )
        llm = GoogleLLMService(model="gemini-1.5-flash-latest", api_key=os.getenv("GOOGLE_API_KEY"))
        # The registered names must match the "name" fields declared in `tools`.
        llm.register_function("get_weather", get_weather)
        llm.register_function("get_image", get_image)
        # Tool declarations, grouped under a single "function_declarations" list.
        tools = [
            {
                "function_declarations": [
                    {
                        "name": "get_weather",
                        "description": "Get the current weather",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "The city and state, e.g. San Francisco, CA",
                                },
                                "format": {
                                    "type": "string",
                                    "enum": ["celsius", "fahrenheit"],
                                    "description": "The temperature unit to use. Infer this from the users location.",
                                },
                            },
                            "required": ["location", "format"],
                        },
                    },
                    {
                        "name": "get_image",
                        "description": "Get and image from the camera or video stream.",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "question": {
                                    "type": "string",
                                    "description": "The question to to use when running inference on the acquired image.",
                                },
                            },
                            "required": ["question"],
                        },
                    },
                ]
            }
        ]
        # NOTE(review): the two get_image descriptions above contain typos
        # ("Get and image", "to to use"); they are sent to the model verbatim.
        system_prompt = """\
You are a helpful assistant who converses with a user and answers questions. Respond concisely to general questions.
Your response will be turned into speech so use only simple words and punctuation.
You have access to two tools: get_weather and get_image.
You can respond to questions about the weather using the get_weather tool.
You can answer questions about the user's video stream using the get_image tool. Some examples of phrases that \
indicate you should use the get_image tool are:
- What do you see?
- What's in the video?
- Can you describe the video?
- Tell me about what you see.
- Tell me something interesting about what you see.
- What's happening in the video?
"""
        # Initial conversation: the system prompt plus one user turn asking
        # for a greeting.
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": "Say hello."},
        ]
        context = OpenAILLMContext(messages, tools)
        context_aggregator = llm.create_context_aggregator(context)
        # Processors run in list order: transport input -> user aggregator ->
        # LLM -> TTS -> transport output -> assistant aggregator.
        pipeline = Pipeline(
            [
                transport.input(),
                context_aggregator.user(),
                llm,
                tts,
                transport.output(),
                context_aggregator.assistant(),
            ]
        )
        task = PipelineTask(
            pipeline,
            PipelineParams(
                allow_interruptions=True,
                enable_metrics=True,
                enable_usage_metrics=True,
                report_only_initial_ttfb=True,
            ),
        )

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            # Remember the first participant so get_image() knows whose video
            # to request.
            global video_participant_id
            video_participant_id = participant["id"]
            transport.capture_participant_transcription(participant["id"])
            # framerate=0: presumably frames are delivered only on request
            # rather than continuously -- TODO confirm against DailyTransport.
            transport.capture_participant_video(video_participant_id, framerate=0)
            # Kick off the conversation.
            await task.queue_frames([context_aggregator.user().get_context_frame()])

        runner = PipelineRunner()
        await runner.run(task)
# Script entry point: run the bot only when executed directly, not on import.
if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,290 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import glob
import json
import os
import sys
from datetime import datetime
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
)
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.google import GoogleLLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
# Load settings from a .env file; override=True lets them replace values
# already present in the environment.
load_dotenv(override=True)
# Swap the logger's handler 0 for a DEBUG-level stderr sink.
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
# Id of the participant whose video is captured. Assigned by the
# on_first_participant_joined handler inside main() and read by get_image().
video_participant_id = None
# Path prefix shared by save_conversation() (writes "<prefix><timestamp>.json")
# and get_saved_conversation_filenames() (globs "<prefix>*.json").
BASE_FILENAME = "/tmp/pipecat_conversation_"
# Module-level handle to the TTS service; assigned in main().
tts = None
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
    """Return a fixed weather report in the requested temperature unit.

    No real API is contacted: the temperature is 75 when ``args["format"]`` is
    ``"fahrenheit"`` and 24 for any other value. The report also echoes the
    unit and carries a ``%Y%m%d_%H%M%S`` timestamp of the current local time.

    Raises:
        KeyError: if ``args`` has no ``"format"`` entry.
    """
    if args["format"] == "fahrenheit":
        degrees = 75
    else:
        degrees = 24
    report = {
        "conditions": "nice",
        "temperature": degrees,
        "format": args["format"],
        "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
    }
    await result_callback(report)
async def get_image(function_name, tool_call_id, arguments, llm, context, result_callback):
    """Ask the LLM service for an image frame from the tracked participant.

    The question supplied by the model is forwarded as the text that goes with
    the requested image. ``result_callback`` is not invoked here.

    Raises:
        KeyError: if ``arguments`` has no ``"question"`` entry.
    """
    prompt = arguments["question"]
    await llm.request_image_frame(
        user_id=video_participant_id,
        text_content=prompt,
    )
async def get_saved_conversation_filenames(
    function_name, tool_call_id, args, llm, context, result_callback
):
    """Report every saved-conversation file found on disk.

    Globs for ``<BASE_FILENAME>*.json`` and passes the matches to
    ``result_callback`` under the ``"filenames"`` key.
    """
    # Every file written by save_conversation() starts with BASE_FILENAME.
    found = glob.glob(f"{BASE_FILENAME}*.json")
    logger.debug(f"matching files: {found}")
    await result_callback({"filenames": found})
async def save_conversation(function_name, tool_call_id, args, llm, context, result_callback):
    """Persist the conversation held in ``context`` to a timestamped JSON file.

    The file is named ``<BASE_FILENAME><YYYY-mm-dd_HH:MM:SS>.json``. The last
    message (the user's request to save) is dropped before writing.

    ``result_callback`` receives ``{"success": True}`` on success, or
    ``{"success": False, "error": <message>}`` if anything in the save path
    raises.
    """
    # NOTE: the timestamp contains ':' characters, which some filesystems reject.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    filename = f"{BASE_FILENAME}{timestamp}.json"
    logger.debug(
        f"writing conversation to {filename}\n{json.dumps(context.get_messages_for_logging(), indent=4)}"
    )
    try:
        # todo: extract 'system' into the first message in the list
        messages = context.get_messages_for_persistent_storage()
        # remove the last message (the instruction to save the context)
        messages.pop()
        # Serialise before touching the disk so that a failure while building
        # the payload cannot leave an empty or truncated file behind.
        payload = json.dumps(messages, indent=2)
        with open(filename, "w") as file:
            file.write(payload)
        await result_callback({"success": True})
    except Exception as e:
        logger.error(f"error saving conversation: {e}")
        await result_callback({"success": False, "error": str(e)})
async def load_conversation(function_name, tool_call_id, args, llm, context, result_callback):
    """Replace the messages in ``context`` with those stored in a JSON file.

    ``args["filename"]`` names the file to read. ``result_callback`` receives
    a success payload once the messages are installed, or
    ``{"success": False, "error": <message>}`` if reading, parsing or
    installing them raises.

    Raises:
        KeyError: if ``args`` has no ``"filename"`` entry.
    """
    # NOTE(review): the filename is chosen by the model and opened as-is, so
    # any readable JSON file can be loaded -- consider restricting it to paths
    # that start with BASE_FILENAME.
    filename = args["filename"]
    logger.debug(f"loading conversation from {filename}")
    try:
        with open(filename, "r") as file:
            loaded = json.load(file)
        context.set_messages(loaded)
        await result_callback(
            {
                "success": True,
                "message": "The most recent conversation has been loaded. Awaiting further instructions.",
            }
        )
    except Exception as e:
        # Log the failure as well as reporting it to the model.
        logger.error(f"error loading conversation: {e}")
        await result_callback({"success": False, "error": str(e)})
# Initial conversation context: a single system prompt. The tool names used in
# the prompt must match the names declared in `tools` and registered in main().
# Test message munging ...
messages = [
    {
        "role": "system",
        "content": """You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your
capabilities in a succinct way. Your output will be converted to audio so don't include special
characters in your answers. Respond to what the user said in a creative and helpful way.
You have several tools you can use to help you.
You can respond to questions about the weather using the get_current_weather tool.
You can save the current conversation using the save_conversation tool. This tool allows you to save
the current conversation to external storage. If the user asks you to save the conversation, use this
save_conversation tool.
You can load a saved conversation using the load_conversation tool. This tool allows you to load a
conversation from external storage. You can get a list of conversations that have been saved using the
get_saved_conversation_filenames tool.
You can answer questions about the user's video stream using the get_image tool. Some examples of phrases that \
indicate you should use the get_image tool are:
- What do you see?
- What's in the video?
- Can you describe the video?
- Tell me about what you see.
- Tell me something interesting about what you see.
- What's happening in the video?
""",
    },
    # {"role": "user", "content": ""},
    # {"role": "assistant", "content": []},
    # {"role": "user", "content": "Tell me"},
    # {"role": "user", "content": "a joke"},
]
# Tool declarations handed to the LLM context. Each "name" must match a
# handler registered with llm.register_function() in main().
tools = [
    {
        "function_declarations": [
            {
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "format": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": "The temperature unit to use. Infer this from the user's location.",
                        },
                    },
                    "required": ["location", "format"],
                },
            },
            {
                "name": "save_conversation",
                "description": "Save the current conversation. Use this function to persist the current conversation to external storage.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "user_request_text": {
                            "type": "string",
                            "description": "The text of the user's request to save the conversation.",
                        }
                    },
                    "required": ["user_request_text"],
                },
            },
            {
                "name": "get_saved_conversation_filenames",
                "description": "Get a list of saved conversation histories. Returns a list of filenames. Each filename includes a date and timestamp. Each file is conversation history that can be loaded into this session.",
                # This tool takes no arguments.
                "parameters": None,
            },
            {
                "name": "load_conversation",
                "description": "Load a conversation history. Use this function to load a conversation history into the current session.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filename": {
                            "type": "string",
                            "description": "The filename of the conversation history to load.",
                        }
                    },
                    "required": ["filename"],
                },
            },
            {
                "name": "get_image",
                "description": "Get an image from the camera or video stream.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "question": {
                            "type": "string",
                            "description": "The question to use when running inference on the acquired image.",
                        },
                    },
                    "required": ["question"],
                },
            },
        ]
    },
]
async def main():
    """Run the save/load-conversation example bot in a Daily room.

    Builds the Daily transport, the Cartesia TTS service (stored in the
    module-level ``tts``) and the Google LLM service, registers the five tool
    handlers, wires everything into a pipeline and runs it with a
    ``PipelineRunner``.
    """
    global tts
    async with aiohttp.ClientSession() as session:
        # Room URL and token come from the local runner helper.
        (room_url, token) = await configure(session)
        transport = DailyTransport(
            room_url,
            token,
            "Respond bot",
            DailyParams(
                audio_out_enabled=True,
                transcription_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.8)),
            ),
        )
        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
        )
        llm = GoogleLLMService(model="gemini-1.5-flash-latest", api_key=os.getenv("GOOGLE_API_KEY"))
        # you can either register a single function for all function calls, or specific functions
        # llm.register_function(None, fetch_weather_from_api)
        # The registered names must match the "name" fields declared in `tools`.
        llm.register_function("get_current_weather", fetch_weather_from_api)
        llm.register_function("save_conversation", save_conversation)
        llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames)
        llm.register_function("load_conversation", load_conversation)
        llm.register_function("get_image", get_image)
        # The context is seeded with the module-level messages and tools.
        context = OpenAILLMContext(messages, tools)
        context_aggregator = llm.create_context_aggregator(context)
        # NOTE(review): the assistant aggregator sits before transport.output()
        # here -- confirm this ordering is intended.
        pipeline = Pipeline(
            [
                transport.input(),  # Transport user input
                context_aggregator.user(),
                llm,  # LLM
                tts,
                context_aggregator.assistant(),
                transport.output(),  # Transport bot output
            ]
        )
        task = PipelineTask(
            pipeline,
            PipelineParams(
                allow_interruptions=True,
                enable_metrics=True,
                enable_usage_metrics=True,
                # report_only_initial_ttfb=True,
            ),
        )

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            # Remember the first participant so get_image() knows whose video
            # to request.
            global video_participant_id
            video_participant_id = participant["id"]
            transport.capture_participant_transcription(participant["id"])
            # framerate=0: presumably frames are delivered only on request
            # rather than continuously -- TODO confirm against DailyTransport.
            transport.capture_participant_video(video_participant_id, framerate=0)
            # Kick off the conversation.
            await task.queue_frames([context_aggregator.user().get_context_frame()])

        runner = PipelineRunner()
        await runner.run(task)
# Script entry point: run the bot only when executed directly, not on import.
if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -48,7 +48,7 @@ elevenlabs = [ "websockets~=13.1" ]
examples = [ "python-dotenv~=1.0.1", "flask~=3.0.3", "flask_cors~=4.0.1" ]
fal = [ "fal-client~=0.4.1" ]
gladia = [ "websockets~=13.1" ]
google = [ "google-generativeai~=0.7.2", "google-cloud-texttospeech~=2.17.2" ]
google = [ "google-generativeai~=0.8.3", "google-cloud-texttospeech~=2.17.2" ]
gstreamer = [ "pygobject~=3.48.2" ]
fireworks = [ "openai~=1.37.2" ]
langchain = [ "langchain~=0.2.14", "langchain-community~=0.2.12", "langchain-openai~=0.1.20" ]

View File

@@ -70,6 +70,8 @@ class OpenAILLMContext:
context.add_message(message)
return context
# todo: deprecate from_image_frame. It's only used to create a single-use
# context, which isn't useful for most real-world applications.
@staticmethod
def from_image_frame(frame: VisionImageRawFrame) -> "OpenAILLMContext":
"""
@@ -77,6 +79,10 @@ class OpenAILLMContext:
expects images to be base64 encoded, but other vision models may not.
So we'll store the image as bytes and do the base64 encoding as needed
in the LLM service.
NOTE: the above only applies to the deprecated use of this method. The
add_image_frame_message() below does the base64 encoding as expected
in the OpenAI format.
"""
context = OpenAILLMContext()
buffer = io.BytesIO()

View File

@@ -5,10 +5,15 @@
#
import asyncio
import base64
from dataclasses import dataclass
import json
import io
from typing import AsyncGenerator, List, Literal, Optional
from loguru import logger
from PIL import Image
from pydantic import BaseModel
from pipecat.frames.frames import (
@@ -28,6 +33,10 @@ from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.services.openai import (
OpenAIAssistantContextAggregator,
OpenAIUserContextAggregator,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService, TTSService
from pipecat.transcriptions.language import Language
@@ -45,6 +54,249 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
class GoogleUserContextAggregator(OpenAIUserContextAggregator):
    """User-side aggregator that records turns as Google ``glm.Content``."""

    async def _push_aggregation(self):
        """Append the aggregated user text to the context and push it on.

        Does nothing when no text has been aggregated.
        """
        if not self._aggregation:
            return

        user_turn = glm.Content(role="user", parts=[glm.Part(text=self._aggregation)])
        self._context.add_message(user_turn)

        # Clear the aggregation before pushing downstream; if the task gets
        # cancelled during the push we would otherwise be unable to clean up.
        self._aggregation = ""

        await self.push_frame(OpenAILLMContextFrame(self._context))

        # Reset our accumulator state.
        self._reset()
class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
    """Assistant-side aggregator that records turns as Google ``glm.Content``."""

    async def _push_aggregation(self):
        """Fold pending assistant output into the context and push it on.

        Handles three kinds of pending state: aggregated text, a function-call
        result frame and a pending image-frame message. Returns immediately
        when none of them is set. Any exception raised while updating the
        context is logged and swallowed.
        """
        if not (
            self._aggregation or self._function_call_result or self._pending_image_frame_message
        ):
            return
        # Set when the user aggregator should push a context frame afterwards.
        run_llm = False
        # Capture the text before _reset() runs.
        aggregation = self._aggregation
        self._reset()
        try:
            if self._function_call_result:
                frame = self._function_call_result
                self._function_call_result = None
                if frame.result:
                    # NOTE(review): the label says "result" but the value
                    # logged is frame.arguments.
                    logger.debug(f"FunctionCallResultFrame result: {frame.arguments}")
                    # Record the call itself as a "model" turn ...
                    self._context.add_message(
                        glm.Content(
                            role="model",
                            parts=[
                                glm.Part(
                                    function_call=glm.FunctionCall(
                                        name=frame.function_name, args=frame.arguments
                                    )
                                )
                            ],
                        )
                    )
                    response = frame.result
                    # Plain-string results are wrapped in a dict before being
                    # passed as the function response.
                    if isinstance(response, str):
                        response = {"response": response}
                    # ... and its response as a "user" turn.
                    self._context.add_message(
                        glm.Content(
                            role="user",
                            parts=[
                                glm.Part(
                                    function_response=glm.FunctionResponse(
                                        name=frame.function_name, response=response
                                    )
                                )
                            ],
                        )
                    )
                    # Only trigger once no other function calls are in progress.
                    run_llm = not bool(self._function_calls_in_progress)
            else:
                # No function result pending: record the text as a "model" turn.
                self._context.add_message(
                    glm.Content(role="model", parts=[glm.Part(text=aggregation)])
                )
            if self._pending_image_frame_message:
                frame = self._pending_image_frame_message
                self._pending_image_frame_message = None
                self._context.add_image_frame_message(
                    format=frame.user_image_raw_frame.format,
                    size=frame.user_image_raw_frame.size,
                    image=frame.user_image_raw_frame.image,
                    text=frame.text,
                )
                run_llm = True
            if run_llm:
                await self._user_context_aggregator.push_context_frame()
            frame = OpenAILLMContextFrame(self._context)
            await self.push_frame(frame)
        except Exception as e:
            logger.exception(f"Error processing frame: {e}")
@dataclass
class GoogleContextAggregatorPair:
    """Pairs a user context aggregator with its assistant counterpart."""

    _user: "GoogleUserContextAggregator"
    _assistant: "GoogleAssistantContextAggregator"

    def user(self) -> "GoogleUserContextAggregator":
        """Return the user-side aggregator."""
        return self._user

    def assistant(self) -> "GoogleAssistantContextAggregator":
        """Return the assistant-side aggregator."""
        return self._assistant
class GoogleLLMContext(OpenAILLMContext):
    """LLM context whose messages are stored as Google ``glm.Content`` objects.

    Provides conversion to and from the OpenAI-style ("standard") message
    dicts used by ``OpenAILLMContext``.
    """

    @staticmethod
    def upgrade_to_google(obj: OpenAILLMContext) -> "GoogleLLMContext":
        """Convert an ``OpenAILLMContext`` into a ``GoogleLLMContext`` in place.

        The object's class is swapped and its messages are converted. Objects
        that are already ``GoogleLLMContext`` instances are returned untouched.
        """
        if isinstance(obj, OpenAILLMContext) and not isinstance(obj, GoogleLLMContext):
            logger.debug(f"Upgrading to Google: {obj}")
            obj.__class__ = GoogleLLMContext
            obj._restructure_from_openai_messages()
        return obj

    def set_messages(self, messages: List):
        """Replace the stored messages with ``messages`` (standard format).

        The existing list object is reused and its contents are then converted
        to ``glm.Content``.
        """
        self._messages[:] = messages
        self._restructure_from_openai_messages()

    def get_messages_for_logging(self):
        """Return the messages as dicts with inline image data elided."""
        msgs = []
        for message in self.messages:
            obj = glm.Content.to_dict(message)
            try:
                if "parts" in obj:
                    for part in obj["parts"]:
                        if "inline_data" in part:
                            # Keep raw image bytes out of the log output.
                            part["inline_data"]["data"] = "..."
            except Exception as e:
                logger.debug(f"Error: {e}")
            msgs.append(obj)
        return msgs

    def from_standard_message(self, message):
        """Convert one standard message dict into a ``glm.Content``.

        Roles are mapped "system" -> "user" and "assistant" -> "model". The
        parts come from the first of these that applies: ``tool_calls``
        entries, a "tool" result, string content, or a list of
        text / image_url content items.
        """
        role = message["role"]
        content = message.get("content", [])
        if role == "system":
            role = "user"
        elif role == "assistant":
            role = "model"

        parts = []
        if message.get("tool_calls"):
            for tc in message["tool_calls"]:
                parts.append(
                    glm.Part(
                        function_call=glm.FunctionCall(
                            name=tc["function"]["name"],
                            args=json.loads(tc["function"]["arguments"]),
                        )
                    )
                )
        elif role == "tool":
            # NOTE(review): GoogleAssistantContextAggregator records function
            # responses with role "user"; confirm "model" is intended here.
            role = "model"
            parts.append(
                glm.Part(
                    function_response=glm.FunctionResponse(
                        name="tool_call_result",  # seems to work to hard-code the same name every time
                        response=json.loads(message["content"]),
                    )
                )
            )
        elif isinstance(content, str):
            parts.append(glm.Part(text=content))
        elif isinstance(content, list):
            for c in content:
                if c["type"] == "text":
                    parts.append(glm.Part(text=c["text"]))
                elif c["type"] == "image_url":
                    # Takes the text after the first comma as the base64
                    # payload; the mime type is always recorded as JPEG.
                    parts.append(
                        glm.Part(
                            inline_data=glm.Blob(
                                mime_type="image/jpeg",
                                data=base64.b64decode(c["image_url"]["url"].split(",")[1]),
                            )
                        )
                    )

        message = glm.Content(role=role, parts=parts)
        return message

    def add_image_frame_message(
        self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
    ):
        """Append a user message holding a JPEG-encoded image.

        The raw ``image`` bytes are interpreted using ``format`` and ``size``
        and re-encoded as JPEG. When ``text`` is given it is added as a text
        part ahead of the image.
        """
        buffer = io.BytesIO()
        Image.frombytes(format, size, image).save(buffer, format="JPEG")

        parts = []
        if text:
            parts.append(glm.Part(text=text))
        parts.append(
            glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())),
        )

        self.add_message(glm.Content(role="user", parts=parts))

    def to_standard_messages(self, obj) -> list:
        """Convert one ``glm.Content`` into a list holding one standard dict.

        The role "model" becomes "assistant". Text and inline-data parts
        become content items, function calls are collected under
        ``"tool_calls"``, and a function response turns the message into a
        "tool" message whose content is the JSON-encoded response.
        """
        msg = {"role": obj.role, "content": []}
        if msg["role"] == "model":
            msg["role"] = "assistant"

        for part in obj.parts:
            if part.text:
                msg["content"].append({"type": "text", "text": part.text})
            elif part.inline_data:
                encoded = base64.b64encode(part.inline_data.data).decode("utf-8")
                msg["content"].append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{part.inline_data.mime_type};base64,{encoded}"},
                    }
                )
            elif part.function_call:
                args = type(part.function_call).to_dict(part.function_call).get("args", {})
                # Accumulate rather than overwrite, so a message carrying
                # several function calls keeps every one of them.
                msg.setdefault("tool_calls", []).append(
                    {
                        "id": part.function_call.name,
                        "type": "function",
                        "function": {
                            "name": part.function_call.name,
                            "arguments": json.dumps(args),
                        },
                    }
                )
            elif part.function_response:
                msg["role"] = "tool"
                resp = (
                    type(part.function_response).to_dict(part.function_response).get("response", {})
                )
                msg["tool_call_id"] = part.function_response.name
                msg["content"] = json.dumps(resp)

        # there might be no content parts for tool_calls messages
        if not msg["content"]:
            del msg["content"]
        return [msg]

    def _restructure_from_openai_messages(self):
        """Convert the stored standard messages to ``glm.Content`` in place.

        Conversion errors are logged. Messages that end up with no parts are
        then dropped.
        """
        # first, map across self._messages calling self.from_standard_message(m) to modify messages in place
        try:
            self._messages[:] = [self.from_standard_message(m) for m in self._messages]
        except Exception as e:
            logger.error(f"Error mapping messages: {e}")
        # iterate over messages and remove any messages that have an empty content list.
        # Slice-assign (as above and in set_messages) so the list object keeps
        # its identity for anything else holding a reference to it.
        self._messages[:] = [m for m in self._messages if m.parts]
class GoogleLLMService(LLMService):
"""This class implements inference with Google's AI models
@@ -98,20 +350,34 @@ class GoogleLLMService(LLMService):
async def _process_context(self, context: OpenAILLMContext):
await self.push_frame(LLMFullResponseStartFrame())
try:
logger.debug(f"Generating chat: {context.get_messages_json()}")
logger.debug(f"Generating chat: {context.get_messages_for_logging()}")
messages = self._get_messages_from_openai_context(context)
# todo: move this into the new context code structure, convert from openai context one time
# todo: add system instructions
# messages = self._get_messages_from_openai_context(context)
messages = context.messages
await self.start_ttfb_metrics()
response = self._client.generate_content(messages, stream=True)
tools = context.tools if context.tools else []
response = self._client.generate_content(contents=messages, tools=tools, stream=True)
await self.stop_ttfb_metrics()
async for chunk in self._async_generator_wrapper(response):
# todo: usage
try:
text = chunk.text
await self.push_frame(TextFrame(text))
for c in chunk.parts:
if c.text:
await self.push_frame(TextFrame(c.text))
elif c.function_call:
args = type(c.function_call).to_dict(c.function_call).get("args", {})
await self.call_function(
context=context,
tool_call_id="what_should_this_be",
function_name=c.function_call.name,
arguments=args,
)
except Exception as e:
# Google LLMs seem to flag safety issues a lot!
if chunk.candidates[0].finish_reason == 3:
@@ -132,10 +398,11 @@ class GoogleLLMService(LLMService):
context = None
if isinstance(frame, OpenAILLMContextFrame):
context: OpenAILLMContext = frame.context
context: GoogleLLMContext = GoogleLLMContext.upgrade_to_google(frame.context)
elif isinstance(frame, LLMMessagesFrame):
context = OpenAILLMContext.from_messages(frame.messages)
context = GoogleLLMContext(frame.messages)
elif isinstance(frame, VisionImageRawFrame):
# todo: fix this
context = OpenAILLMContext.from_image_frame(frame)
elif isinstance(frame, LLMUpdateSettingsFrame):
await self._update_settings(frame.settings)
@@ -145,6 +412,16 @@ class GoogleLLMService(LLMService):
if context:
await self._process_context(context)
@staticmethod
def create_context_aggregator(
    context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
) -> GoogleContextAggregatorPair:
    """Build the user/assistant aggregator pair for ``context``.

    The assistant aggregator is constructed from the user aggregator, and
    ``assistant_expect_stripped_words`` is forwarded to it as
    ``expect_stripped_words``.
    """
    user_side = GoogleUserContextAggregator(context)
    return GoogleContextAggregatorPair(
        _user=user_side,
        _assistant=GoogleAssistantContextAggregator(
            user_side, expect_stripped_words=assistant_expect_stripped_words
        ),
    )
class GoogleTTSService(TTSService):
class InputParams(BaseModel):