Compare commits

...

2 Commits

Author SHA1 Message Date
Kwindla Hultman Kramer
05f44bf4c3 Fixed up gpt4o and gemini-flash examples with interruption changes 2024-05-18 15:01:34 -07:00
Kwindla Hultman Kramer
6d8dc732e1 Initial commit of Google Gemini LLM service.
Gemini text input works. We translate from OpenAILLMContext format
on the fly in the GoogleLLMService implementation. This commit also
implements image input (vision) in both the GoogleLLMService and in
the OpenAILLMService. Image input is a hack and needs to be revisited.
OpenAI expects images to be uploaded as base64-encoded JPEGs. Google
does not require the base64 encoding. Other than for images, we use
the OpenAI format as our standard, but base64-encoding the images
and then decoding them in the GoogleLLMService feels wasteful.
2024-05-18 14:49:54 -07:00
7 changed files with 458 additions and 36 deletions

View File

@@ -0,0 +1,110 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import aiohttp
import os
import sys
from pipecat.frames.frames import Frame, TextFrame, UserImageRequestFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.user_response import UserResponseAggregator
from pipecat.processors.aggregators.vision_image_frame import VisionImageFrameAggregator
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.elevenlabs import ElevenLabsTTSService
from pipecat.services.google import GoogleLLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVADAnalyzer
from runner import configure
from loguru import logger
from dotenv import load_dotenv
# Load settings from a .env file; override=True lets values from the file
# replace variables that are already set in the process environment.
load_dotenv(override=True)
# Swap loguru's default handler (id 0) for a stderr sink at DEBUG level.
# NOTE(review): remove(0) raises if handler 0 was already removed elsewhere.
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
class UserImageRequester(FrameProcessor):
    """Frame processor that asks for a participant image whenever text arrives.

    Once a participant id is known, every ``TextFrame`` passing through causes a
    ``UserImageRequestFrame`` for that participant to be pushed upstream. The
    incoming frame itself is always forwarded unchanged in its original
    direction.
    """

    def __init__(self, participant_id: str | None = None):
        """Create the requester, optionally already bound to a participant."""
        super().__init__()
        self._participant_id = participant_id

    def set_participant_id(self, participant_id: str):
        """Bind (or rebind) the participant whose image will be requested."""
        self._participant_id = participant_id

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Request an image for text frames, then pass the frame along."""
        wants_image = isinstance(frame, TextFrame) and self._participant_id
        if wants_image:
            image_request = UserImageRequestFrame(self._participant_id)
            await self.push_frame(image_request, FrameDirection.UPSTREAM)

        # The triggering frame is forwarded regardless of whether a request was made.
        await self.push_frame(frame, direction)
async def main(room_url: str, token):
    """Run the "describe participant video" bot backed by Google Gemini.

    Joins the Daily room, greets the first participant, and then runs a
    pipeline that turns the participant's speech plus a requested video frame
    into a Gemini response which is spoken back through ElevenLabs TTS.

    Args:
        room_url: URL of the Daily room to join.
        token: Daily meeting token passed straight to ``DailyTransport``.
    """
    async with aiohttp.ClientSession() as session:
        transport = DailyTransport(
            room_url,
            token,
            "Describe participant video",
            DailyParams(
                audio_in_enabled=True,  # This is so Silero VAD can get audio data
                audio_out_enabled=True,
                transcription_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer()
            )
        )

        # A single TTS service instance is shared by the greeting in the event
        # handler below and by the pipeline itself.
        tts = ElevenLabsTTSService(
            aiohttp_session=session,
            api_key=os.getenv("ELEVENLABS_API_KEY"),
            voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
        )

        user_response = UserResponseAggregator()
        image_requester = UserImageRequester()
        vision_aggregator = VisionImageFrameAggregator()

        google = GoogleLLMService(model="gemini-1.5-flash-latest")

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            await tts.say("Hi there! Feel free to ask me what I see.")
            # NOTE(review): framerate=0 presumably means frames are only
            # delivered on explicit request - confirm against DailyTransport.
            transport.capture_participant_video(participant["id"], framerate=0)
            transport.capture_participant_transcription(participant["id"])
            # Image requests start flowing only once the participant is known.
            image_requester.set_participant_id(participant["id"])

        pipeline = Pipeline([
            transport.input(),
            user_response,
            image_requester,
            vision_aggregator,
            google,
            tts,
            transport.output()
        ])

        task = PipelineTask(pipeline, allow_interruptions=True)

        runner = PipelineRunner()

        await runner.run(task)
# Script entry point: resolve the room settings, then drive the bot to completion.
if __name__ == "__main__":
    room_url, room_token = configure()
    asyncio.run(main(room_url, room_token))

View File

@@ -0,0 +1,113 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import aiohttp
import os
import sys
from pipecat.frames.frames import Frame, TextFrame, UserImageRequestFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.user_response import UserResponseAggregator
from pipecat.processors.aggregators.vision_image_frame import VisionImageFrameAggregator
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.elevenlabs import ElevenLabsTTSService
from pipecat.services.openai import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVADAnalyzer
from runner import configure
from loguru import logger
from dotenv import load_dotenv
# Load settings from a .env file; override=True lets values from the file
# replace variables that are already set in the process environment.
load_dotenv(override=True)
# Swap loguru's default handler (id 0) for a stderr sink at DEBUG level.
# NOTE(review): remove(0) raises if handler 0 was already removed elsewhere.
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
class UserImageRequester(FrameProcessor):
    """Frame processor that asks for a participant image whenever text arrives.

    Once a participant id is known, every ``TextFrame`` passing through causes a
    ``UserImageRequestFrame`` for that participant to be pushed upstream. The
    incoming frame itself is always forwarded unchanged in its original
    direction.
    """

    def __init__(self, participant_id: str | None = None):
        """Create the requester, optionally already bound to a participant."""
        super().__init__()
        self._participant_id = participant_id

    def set_participant_id(self, participant_id: str):
        """Bind (or rebind) the participant whose image will be requested."""
        self._participant_id = participant_id

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Request an image for text frames, then pass the frame along."""
        wants_image = isinstance(frame, TextFrame) and self._participant_id
        if wants_image:
            image_request = UserImageRequestFrame(self._participant_id)
            await self.push_frame(image_request, FrameDirection.UPSTREAM)

        # The triggering frame is forwarded regardless of whether a request was made.
        await self.push_frame(frame, direction)
async def main(room_url: str, token):
    """Run the "describe participant video" bot backed by OpenAI GPT-4o.

    Joins the Daily room, greets the first participant, and then runs a
    pipeline that turns the participant's speech plus a requested video frame
    into a GPT-4o response which is spoken back through ElevenLabs TTS.

    Args:
        room_url: URL of the Daily room to join.
        token: Daily meeting token passed straight to ``DailyTransport``.
    """
    async with aiohttp.ClientSession() as session:
        transport = DailyTransport(
            room_url,
            token,
            "Describe participant video",
            DailyParams(
                audio_in_enabled=True,  # This is so Silero VAD can get audio data
                audio_out_enabled=True,
                transcription_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer()
            )
        )

        # A single TTS service instance is shared by the greeting in the event
        # handler below and by the pipeline itself.
        tts = ElevenLabsTTSService(
            aiohttp_session=session,
            api_key=os.getenv("ELEVENLABS_API_KEY"),
            voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
        )

        user_response = UserResponseAggregator()
        image_requester = UserImageRequester()
        vision_aggregator = VisionImageFrameAggregator()

        gpt4o = OpenAILLMService(
            api_key=os.getenv("OPENAI_API_KEY"),
            model="gpt-4o"
        )

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            await tts.say("Hi there! Feel free to ask me what I see.")
            # NOTE(review): framerate=0 presumably means frames are only
            # delivered on explicit request - confirm against DailyTransport.
            transport.capture_participant_video(participant["id"], framerate=0)
            transport.capture_participant_transcription(participant["id"])
            # Image requests start flowing only once the participant is known.
            image_requester.set_participant_id(participant["id"])

        pipeline = Pipeline([
            transport.input(),
            user_response,
            image_requester,
            vision_aggregator,
            gpt4o,
            tts,
            transport.output()
        ])

        task = PipelineTask(pipeline, allow_interruptions=True)

        runner = PipelineRunner()

        await runner.run(task)
# Script entry point: resolve the room settings, then drive the bot to completion.
if __name__ == "__main__":
    room_url, room_token = configure()
    asyncio.run(main(room_url, room_token))

View File

@@ -1,32 +1,33 @@
WARNING: --strip-extras is becoming the default in version 8.0.0. To silence this warning, either use --strip-extras to opt into the new default or use --no-strip-extras to retain the existing behavior.
#
# This file is autogenerated by pip-compile with Python 3.10
# This file is autogenerated by pip-compile with Python 3.11
# by the following command:
#
# pip-compile --all-extras pyproject.toml
#
aiohttp==3.9.5
# via pipecat (pyproject.toml)
# via pipecat-ai (pyproject.toml)
aiosignal==1.3.1
# via aiohttp
annotated-types==0.6.0
# via pydantic
anthropic==0.25.8
# via pipecat (pyproject.toml)
anthropic==0.25.9
# via pipecat-ai (pyproject.toml)
anyio==4.3.0
# via
# anthropic
# httpx
# openai
async-timeout==4.0.3
# via aiohttp
attrs==23.2.0
# via aiohttp
av==12.0.0
# via faster-whisper
azure-cognitiveservices-speech==1.37.0
# via pipecat (pyproject.toml)
# via pipecat-ai (pyproject.toml)
blinker==1.8.2
# via flask
cachetools==5.3.3
# via google-auth
certifi==2024.2.2
# via
# httpcore
@@ -41,19 +42,17 @@ coloredlogs==15.0.1
ctranslate2==4.2.1
# via faster-whisper
daily-python==0.7.4
# via pipecat (pyproject.toml)
# via pipecat-ai (pyproject.toml)
distro==1.9.0
# via
# anthropic
# openai
einops==0.8.0
# via pipecat (pyproject.toml)
exceptiongroup==1.2.1
# via anyio
# via pipecat-ai (pyproject.toml)
fal-client==0.4.0
# via pipecat (pyproject.toml)
# via pipecat-ai (pyproject.toml)
faster-whisper==1.0.2
# via pipecat (pyproject.toml)
# via pipecat-ai (pyproject.toml)
filelock==3.14.0
# via
# huggingface-hub
@@ -63,25 +62,58 @@ filelock==3.14.0
flask==3.0.3
# via
# flask-cors
# pipecat (pyproject.toml)
# pipecat-ai (pyproject.toml)
flask-cors==4.0.1
# via pipecat (pyproject.toml)
# via pipecat-ai (pyproject.toml)
flatbuffers==24.3.25
# via onnxruntime
frozenlist==1.4.1
# via
# aiohttp
# aiosignal
fsspec==2024.3.1
fsspec==2024.5.0
# via
# huggingface-hub
# torch
google-ai-generativelanguage==0.6.3
# via google-generativeai
google-api-core[grpc]==2.19.0
# via
# google-ai-generativelanguage
# google-api-python-client
# google-generativeai
google-api-python-client==2.129.0
# via google-generativeai
google-auth==2.29.0
# via
# google-ai-generativelanguage
# google-api-core
# google-api-python-client
# google-auth-httplib2
# google-generativeai
google-auth-httplib2==0.2.0
# via google-api-python-client
google-generativeai==0.5.3
# via pipecat-ai (pyproject.toml)
googleapis-common-protos==1.63.0
# via
# google-api-core
# grpcio-status
grpcio==1.63.0
# via pyht
# via
# google-api-core
# grpcio-status
# pyht
grpcio-status==1.62.2
# via google-api-core
h11==0.14.0
# via httpcore
httpcore==1.0.5
# via httpx
httplib2==0.22.0
# via
# google-api-python-client
# google-auth-httplib2
httpx==0.27.0
# via
# anthropic
@@ -110,7 +142,7 @@ jinja2==3.1.4
# flask
# torch
loguru==0.7.2
# via pipecat (pyproject.toml)
# via pipecat-ai (pyproject.toml)
markupsafe==2.1.5
# via
# jinja2
@@ -127,13 +159,13 @@ numpy==1.26.4
# via
# ctranslate2
# onnxruntime
# pipecat (pyproject.toml)
# pipecat-ai (pyproject.toml)
# torchvision
# transformers
onnxruntime==1.17.3
# via faster-whisper
openai==1.26.0
# via pipecat (pyproject.toml)
# via pipecat-ai (pyproject.toml)
packaging==24.0
# via
# huggingface-hub
@@ -141,37 +173,59 @@ packaging==24.0
# transformers
pillow==10.3.0
# via
# pipecat (pyproject.toml)
# pipecat-ai (pyproject.toml)
# torchvision
proto-plus==1.23.0
# via
# google-ai-generativelanguage
# google-api-core
protobuf==4.25.3
# via
# google-ai-generativelanguage
# google-api-core
# google-generativeai
# googleapis-common-protos
# grpcio-status
# onnxruntime
# proto-plus
# pyht
pyasn1==0.6.0
# via
# pyasn1-modules
# rsa
pyasn1-modules==0.4.0
# via google-auth
pyaudio==0.2.14
# via pipecat (pyproject.toml)
# via pipecat-ai (pyproject.toml)
pydantic==2.7.1
# via
# anthropic
# google-generativeai
# openai
pydantic-core==2.18.2
# via pydantic
pyht==0.0.28
# via pipecat (pyproject.toml)
# via pipecat-ai (pyproject.toml)
pyparsing==3.1.2
# via httplib2
python-dotenv==1.0.1
# via pipecat (pyproject.toml)
# via pipecat-ai (pyproject.toml)
pyyaml==6.0.1
# via
# ctranslate2
# huggingface-hub
# timm
# transformers
regex==2024.5.10
regex==2024.5.15
# via transformers
requests==2.31.0
# via
# google-api-core
# huggingface-hub
# pyht
# transformers
rsa==4.9
# via google-auth
safetensors==0.4.3
# via
# timm
@@ -187,7 +241,7 @@ sympy==1.12
# onnxruntime
# torch
timm==0.9.16
# via pipecat (pyproject.toml)
# via pipecat-ai (pyproject.toml)
tokenizers==0.19.1
# via
# anthropic
@@ -195,35 +249,38 @@ tokenizers==0.19.1
# transformers
torch==2.3.0
# via
# pipecat (pyproject.toml)
# pipecat-ai (pyproject.toml)
# timm
# torchaudio
# torchvision
torchaudio==2.3.0
# via pipecat (pyproject.toml)
# via pipecat-ai (pyproject.toml)
torchvision==0.18.0
# via timm
tqdm==4.66.4
# via
# google-generativeai
# huggingface-hub
# openai
# transformers
transformers==4.40.2
# via pipecat (pyproject.toml)
# via pipecat-ai (pyproject.toml)
typing-extensions==4.11.0
# via
# anthropic
# anyio
# google-generativeai
# huggingface-hub
# openai
# pipecat (pyproject.toml)
# pipecat-ai (pyproject.toml)
# pydantic
# pydantic-core
# torch
uritemplate==4.1.1
# via google-api-python-client
urllib3==2.2.1
# via requests
websockets==12.0
# via pipecat (pyproject.toml)
# via pipecat-ai (pyproject.toml)
werkzeug==3.0.3
# via flask
yarl==1.9.4

View File

@@ -37,6 +37,7 @@ azure = [ "azure-cognitiveservices-speech~=1.37.0" ]
daily = [ "daily-python~=0.7.4" ]
examples = [ "python-dotenv~=1.0.0", "flask~=3.0.3", "flask_cors~=4.0.1" ]
fal = [ "fal-client~=0.4.0" ]
google = [ "google-generativeai~=0.5.3" ]
fireworks = [ "openai~=1.26.0" ]
local = [ "pyaudio~=0.2.0" ]
moondream = [ "einops~=0.8.0", "timm~=0.9.16", "transformers~=4.40.2" ]

View File

@@ -5,10 +5,13 @@
#
from dataclasses import dataclass
import io
from typing import List
from pipecat.frames.frames import Frame
from PIL import Image
from pipecat.frames.frames import Frame, VisionImageRawFrame
from openai._types import NOT_GIVEN, NotGiven
@@ -43,6 +46,31 @@ class OpenAILLMContext:
})
return context
@staticmethod
def from_image_frame(frame: VisionImageRawFrame) -> "OpenAILLMContext":
    """Build a single-message context from a vision image frame.

    This intentionally deviates from the OpenAI message shape: OpenAI wants
    images base64 encoded, while other vision models may not. The image is
    therefore kept as raw JPEG bytes under ``"data"`` (with a ``"mime_type"``
    of ``image/jpeg``), and each LLM service does any base64 encoding it needs.

    Args:
        frame: Frame supplying the raw pixels (``image``, ``format``, ``size``)
            and the accompanying prompt ``text``.

    Returns:
        A new context holding one ``user`` message with the text and JPEG bytes.
    """
    # Re-encode the raw pixel buffer as an in-memory JPEG.
    picture = Image.frombytes(frame.format, frame.size, frame.image)
    jpeg_buffer = io.BytesIO()
    picture.save(jpeg_buffer, format="JPEG")

    user_message = {
        "content": frame.text,
        "role": "user",
        "data": jpeg_buffer.getvalue(),
        "mime_type": "image/jpeg"
    }

    context = OpenAILLMContext()
    context.add_message(user_message)
    return context
def add_message(self, message: ChatCompletionMessageParam):
    """Append ``message`` to the end of this context's message list."""
    self.messages.append(message)

View File

@@ -0,0 +1,96 @@
import google.generativeai as gai
import google.ai.generativelanguage as glm
import os
import asyncio
from typing import List
from pipecat.frames.frames import (
Frame,
TextFrame,
VisionImageRawFrame,
LLMMessagesFrame,
LLMResponseStartFrame,
LLMResponseEndFrame)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame
from loguru import logger
class GoogleLLMService(LLMService):
    """This class implements inference with Google's AI models

    This service translates internally from OpenAILLMContext to the messages format
    expected by the Google AI model. We are using the OpenAILLMContext as a lingua
    franca for all LLM services, so that it is easy to switch between different LLMs.
    """

    def __init__(self, model="gemini-1.5-flash-latest", api_key=None, **kwargs):
        """Configure the Google client library and create the model client.

        Args:
            model: Name of the Google generative model to use.
            api_key: Google API key. When falsy, ``GOOGLE_API_KEY`` is read from
                the environment.
            **kwargs: Forwarded to ``LLMService.__init__``.

        Raises:
            KeyError: If no ``api_key`` is given and ``GOOGLE_API_KEY`` is unset.
        """
        super().__init__(**kwargs)
        self.model = model
        gai.configure(api_key=api_key or os.environ["GOOGLE_API_KEY"])
        self.create_client()

    def create_client(self):
        """(Re)create the ``GenerativeModel`` client for ``self.model``."""
        self._client = gai.GenerativeModel(self.model)

    def _get_messages_from_openai_context(
            self, context: OpenAILLMContext) -> List[glm.Content]:
        """Translate OpenAI-style messages into Google role/parts messages.

        Role mapping: ``system`` -> ``user`` and ``assistant`` -> ``model``;
        any other role is passed through unchanged. A message carrying a
        ``mime_type`` key also gets an inline-data part built from its
        ``data`` bytes.

        NOTE(review): the entries returned are plain ``{"role", "parts"}``
        dicts rather than ``glm.Content`` instances, despite the annotation.
        """
        openai_messages = context.get_messages()
        google_messages = []

        for message in openai_messages:
            role = message["role"]
            content = message["content"]
            if role == "system":
                role = "user"
            elif role == "assistant":
                role = "model"

            parts = [glm.Part(text=content)]
            if "mime_type" in message:
                # Image bytes are attached raw; no base64 step is applied here.
                parts.append(
                    glm.Part(inline_data=glm.Blob(
                        mime_type=message["mime_type"],
                        data=message["data"]
                    )))
            google_messages.append({"role": role, "parts": parts})

        return google_messages

    async def _async_generator_wrapper(self, sync_generator):
        """Expose a synchronous generator as an async one.

        Control is yielded to the event loop after every item, but fetching
        each item still blocks. Kept for backward compatibility;
        ``_process_context`` no longer relies on it.
        """
        for item in sync_generator:
            yield item
            await asyncio.sleep(0)

    async def _process_context(self, context: OpenAILLMContext):
        """Stream a model response for ``context`` as ``TextFrame``s.

        Pushes ``LLMResponseStartFrame``, one ``TextFrame`` per streamed chunk,
        then ``LLMResponseEndFrame``. Errors are logged, not raised.
        """
        try:
            messages = self._get_messages_from_openai_context(context)

            await self.push_frame(LLMResponseStartFrame())
            try:
                # Use the async API so waiting on the network does not block
                # the event loop (and with it the rest of the pipeline).
                response = await self._client.generate_content_async(messages, stream=True)

                async for chunk in response:
                    text = chunk.text
                    logger.debug(f"Pushing inference text: {text}")
                    await self.push_frame(TextFrame(text))
            except Exception as e:
                logger.error(f"Exception: {e}")

            # Always close a response that was opened, even if streaming
            # failed part-way, so downstream never sees a dangling start frame.
            await self.push_frame(LLMResponseEndFrame())
        except Exception as e:
            logger.error(f"Exception: {e}")

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Run inference for context-bearing frames; forward everything else.

        ``OpenAILLMContextFrame``, ``LLMMessagesFrame`` and
        ``VisionImageRawFrame`` are each turned into an ``OpenAILLMContext``
        and processed; they are not forwarded. Any other frame is pushed on
        unchanged in its original direction.
        """
        context = None

        if isinstance(frame, OpenAILLMContextFrame):
            context: OpenAILLMContext = frame.context
        elif isinstance(frame, LLMMessagesFrame):
            context = OpenAILLMContext.from_messages(frame.messages)
        elif isinstance(frame, VisionImageRawFrame):
            context = OpenAILLMContext.from_image_frame(frame)
        else:
            await self.push_frame(frame, direction)

        if context:
            await self._process_context(context)

View File

@@ -8,6 +8,7 @@ import io
import json
import time
import aiohttp
import base64
from PIL import Image
@@ -22,7 +23,8 @@ from pipecat.frames.frames import (
LLMResponseEndFrame,
LLMResponseStartFrame,
TextFrame,
URLImageRawFrame
URLImageRawFrame,
VisionImageRawFrame
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection
@@ -67,8 +69,21 @@ class BaseOpenAILLMService(LLMService):
self, context: OpenAILLMContext
) -> AsyncStream[ChatCompletionChunk]:
messages: List[ChatCompletionMessageParam] = context.get_messages()
messages_for_log = json.dumps(messages)
logger.debug(f"Generating chat: {messages_for_log}")
# base64 encode any images
for message in messages:
if message.get("mime_type") == "image/jpeg":
encoded_image = base64.b64encode(message["data"]).decode("utf-8")
text = message["content"]
message["content"] = [
{"type": "text", "text": text},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
]
del message["data"]
del message["mime_type"]
# messages_for_log = json.dumps(messages)
# logger.debug(f"Generating chat: {messages_for_log}")
start_time = time.time()
chunks: AsyncStream[ChatCompletionChunk] = (
@@ -151,6 +166,8 @@ class BaseOpenAILLMService(LLMService):
context: OpenAILLMContext = frame.context
elif isinstance(frame, LLMMessagesFrame):
context = OpenAILLMContext.from_messages(frame.messages)
elif isinstance(frame, VisionImageRawFrame):
context = OpenAILLMContext.from_image_frame(frame)
else:
await self.push_frame(frame, direction)