Compare commits
2 Commits
v0.0.44
...
khk-gemini
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05f44bf4c3 | ||
|
|
6d8dc732e1 |
110
examples/foundational/12a-describe-video-gemini-flash.py
Normal file
110
examples/foundational/12a-describe-video-gemini-flash.py
Normal file
@@ -0,0 +1,110 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.frames.frames import Frame, TextFrame, UserImageRequestFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.user_response import UserResponseAggregator
|
||||
from pipecat.processors.aggregators.vision_image_frame import VisionImageFrameAggregator
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
||||
from pipecat.services.google import GoogleLLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class UserImageRequester(FrameProcessor):
    """Frame processor that asks for a participant image whenever text flows by.

    For every ``TextFrame`` seen while a participant id is set, a
    ``UserImageRequestFrame`` for that participant is pushed upstream. Every
    incoming frame, text or not, is then forwarded in its original direction.
    """

    def __init__(self, participant_id: str | None = None):
        """Create the requester, optionally bound to a participant already."""
        super().__init__()
        self._participant_id = participant_id

    def set_participant_id(self, participant_id: str):
        """Set (or replace) the participant whose image will be requested."""
        self._participant_id = participant_id

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Request an image for text frames, then pass the frame along."""
        target = self._participant_id
        should_request = bool(target) and isinstance(frame, TextFrame)

        if should_request:
            # The request travels upstream, against the normal frame flow.
            image_request = UserImageRequestFrame(target)
            await self.push_frame(image_request, FrameDirection.UPSTREAM)

        # The original frame is always forwarded unchanged.
        await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
async def main(room_url: str, token):
    """Run the "describe participant video" bot using Gemini as the LLM.

    Builds a Daily transport, an ElevenLabs TTS service, the user-response and
    vision aggregators and a Google LLM service, wires them into a single
    pipeline and runs it until the pipeline task finishes.

    Args:
        room_url: URL of the Daily room to join.
        token: Token passed to the Daily transport together with the room URL.
    """
    async with aiohttp.ClientSession() as session:
        transport = DailyTransport(
            room_url,
            token,
            "Describe participant video",
            DailyParams(
                audio_in_enabled=True,  # This is so Silero VAD can get audio data
                audio_out_enabled=True,
                transcription_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer()
            )
        )

        user_response = UserResponseAggregator()

        image_requester = UserImageRequester()

        vision_aggregator = VisionImageFrameAggregator()

        google = GoogleLLMService(model="gemini-1.5-flash-latest")

        # The TTS service is built exactly once. The original code constructed
        # an identical instance earlier and then overwrote it here.
        tts = ElevenLabsTTSService(
            aiohttp_session=session,
            api_key=os.getenv("ELEVENLABS_API_KEY"),
            voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
        )

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            # Greet the participant, then start capturing their video and
            # transcription and point the image requester at them.
            await tts.say("Hi there! Feel free to ask me what I see.")
            # NOTE(review): framerate=0 presumably means frames are delivered
            # only on request - confirm against the transport.
            transport.capture_participant_video(participant["id"], framerate=0)
            transport.capture_participant_transcription(participant["id"])
            image_requester.set_participant_id(participant["id"])

        pipeline = Pipeline([
            transport.input(),
            user_response,
            image_requester,
            vision_aggregator,
            google,
            tts,
            transport.output()
        ])

        task = PipelineTask(pipeline, allow_interruptions=True)

        runner = PipelineRunner()

        await runner.run(task)
|
||||
|
||||
if __name__ == "__main__":
    # configure() comes from the local runner module and yields the room URL
    # and token that main() expects.
    room_url, room_token = configure()
    asyncio.run(main(room_url, room_token))
|
||||
113
examples/foundational/12b-describe-video-gpt-4o.py
Normal file
113
examples/foundational/12b-describe-video-gpt-4o.py
Normal file
@@ -0,0 +1,113 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.frames.frames import Frame, TextFrame, UserImageRequestFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.user_response import UserResponseAggregator
|
||||
from pipecat.processors.aggregators.vision_image_frame import VisionImageFrameAggregator
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class UserImageRequester(FrameProcessor):
    """Frame processor that asks for a participant image whenever text flows by.

    For every ``TextFrame`` seen while a participant id is set, a
    ``UserImageRequestFrame`` for that participant is pushed upstream. Every
    incoming frame, text or not, is then forwarded in its original direction.
    """

    def __init__(self, participant_id: str | None = None):
        """Create the requester, optionally bound to a participant already."""
        super().__init__()
        self._participant_id = participant_id

    def set_participant_id(self, participant_id: str):
        """Set (or replace) the participant whose image will be requested."""
        self._participant_id = participant_id

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Request an image for text frames, then pass the frame along."""
        target = self._participant_id
        should_request = bool(target) and isinstance(frame, TextFrame)

        if should_request:
            # The request travels upstream, against the normal frame flow.
            image_request = UserImageRequestFrame(target)
            await self.push_frame(image_request, FrameDirection.UPSTREAM)

        # The original frame is always forwarded unchanged.
        await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
async def main(room_url: str, token):
    """Run the "describe participant video" bot using GPT-4o as the LLM.

    Builds a Daily transport, an ElevenLabs TTS service, the user-response and
    vision aggregators and an OpenAI LLM service, wires them into a single
    pipeline and runs it until the pipeline task finishes.

    Args:
        room_url: URL of the Daily room to join.
        token: Token passed to the Daily transport together with the room URL.
    """
    async with aiohttp.ClientSession() as session:
        transport = DailyTransport(
            room_url,
            token,
            "Describe participant video",
            DailyParams(
                audio_in_enabled=True,  # This is so Silero VAD can get audio data
                audio_out_enabled=True,
                transcription_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer()
            )
        )

        user_response = UserResponseAggregator()

        image_requester = UserImageRequester()

        vision_aggregator = VisionImageFrameAggregator()

        gpt4o = OpenAILLMService(
            api_key=os.getenv("OPENAI_API_KEY"),
            model="gpt-4o"
        )

        # The TTS service is built exactly once. The original code constructed
        # an identical instance earlier and then overwrote it here.
        tts = ElevenLabsTTSService(
            aiohttp_session=session,
            api_key=os.getenv("ELEVENLABS_API_KEY"),
            voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
        )

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            # Greet the participant, then start capturing their video and
            # transcription and point the image requester at them.
            await tts.say("Hi there! Feel free to ask me what I see.")
            # NOTE(review): framerate=0 presumably means frames are delivered
            # only on request - confirm against the transport.
            transport.capture_participant_video(participant["id"], framerate=0)
            transport.capture_participant_transcription(participant["id"])
            image_requester.set_participant_id(participant["id"])

        pipeline = Pipeline([
            transport.input(),
            user_response,
            image_requester,
            vision_aggregator,
            gpt4o,
            tts,
            transport.output()
        ])

        task = PipelineTask(pipeline, allow_interruptions=True)

        runner = PipelineRunner()

        await runner.run(task)
|
||||
|
||||
if __name__ == "__main__":
    # configure() comes from the local runner module and yields the room URL
    # and token that main() expects.
    room_url, room_token = configure()
    asyncio.run(main(room_url, room_token))
|
||||
@@ -1,32 +1,33 @@
|
||||
WARNING: --strip-extras is becoming the default in version 8.0.0. To silence this warning, either use --strip-extras to opt into the new default or use --no-strip-extras to retain the existing behavior.
|
||||
#
|
||||
# This file is autogenerated by pip-compile with Python 3.10
|
||||
# This file is autogenerated by pip-compile with Python 3.11
|
||||
# by the following command:
|
||||
#
|
||||
# pip-compile --all-extras pyproject.toml
|
||||
#
|
||||
aiohttp==3.9.5
|
||||
# via pipecat (pyproject.toml)
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
annotated-types==0.6.0
|
||||
# via pydantic
|
||||
anthropic==0.25.8
|
||||
# via pipecat (pyproject.toml)
|
||||
anthropic==0.25.9
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
anyio==4.3.0
|
||||
# via
|
||||
# anthropic
|
||||
# httpx
|
||||
# openai
|
||||
async-timeout==4.0.3
|
||||
# via aiohttp
|
||||
attrs==23.2.0
|
||||
# via aiohttp
|
||||
av==12.0.0
|
||||
# via faster-whisper
|
||||
azure-cognitiveservices-speech==1.37.0
|
||||
# via pipecat (pyproject.toml)
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
blinker==1.8.2
|
||||
# via flask
|
||||
cachetools==5.3.3
|
||||
# via google-auth
|
||||
certifi==2024.2.2
|
||||
# via
|
||||
# httpcore
|
||||
@@ -41,19 +42,17 @@ coloredlogs==15.0.1
|
||||
ctranslate2==4.2.1
|
||||
# via faster-whisper
|
||||
daily-python==0.7.4
|
||||
# via pipecat (pyproject.toml)
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
distro==1.9.0
|
||||
# via
|
||||
# anthropic
|
||||
# openai
|
||||
einops==0.8.0
|
||||
# via pipecat (pyproject.toml)
|
||||
exceptiongroup==1.2.1
|
||||
# via anyio
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
fal-client==0.4.0
|
||||
# via pipecat (pyproject.toml)
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
faster-whisper==1.0.2
|
||||
# via pipecat (pyproject.toml)
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
filelock==3.14.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
@@ -63,25 +62,58 @@ filelock==3.14.0
|
||||
flask==3.0.3
|
||||
# via
|
||||
# flask-cors
|
||||
# pipecat (pyproject.toml)
|
||||
# pipecat-ai (pyproject.toml)
|
||||
flask-cors==4.0.1
|
||||
# via pipecat (pyproject.toml)
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
flatbuffers==24.3.25
|
||||
# via onnxruntime
|
||||
frozenlist==1.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec==2024.3.1
|
||||
fsspec==2024.5.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
google-ai-generativelanguage==0.6.3
|
||||
# via google-generativeai
|
||||
google-api-core[grpc]==2.19.0
|
||||
# via
|
||||
# google-ai-generativelanguage
|
||||
# google-api-python-client
|
||||
# google-generativeai
|
||||
google-api-python-client==2.129.0
|
||||
# via google-generativeai
|
||||
google-auth==2.29.0
|
||||
# via
|
||||
# google-ai-generativelanguage
|
||||
# google-api-core
|
||||
# google-api-python-client
|
||||
# google-auth-httplib2
|
||||
# google-generativeai
|
||||
google-auth-httplib2==0.2.0
|
||||
# via google-api-python-client
|
||||
google-generativeai==0.5.3
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
googleapis-common-protos==1.63.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpcio-status
|
||||
grpcio==1.63.0
|
||||
# via pyht
|
||||
# via
|
||||
# google-api-core
|
||||
# grpcio-status
|
||||
# pyht
|
||||
grpcio-status==1.62.2
|
||||
# via google-api-core
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
httpcore==1.0.5
|
||||
# via httpx
|
||||
httplib2==0.22.0
|
||||
# via
|
||||
# google-api-python-client
|
||||
# google-auth-httplib2
|
||||
httpx==0.27.0
|
||||
# via
|
||||
# anthropic
|
||||
@@ -110,7 +142,7 @@ jinja2==3.1.4
|
||||
# flask
|
||||
# torch
|
||||
loguru==0.7.2
|
||||
# via pipecat (pyproject.toml)
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
markupsafe==2.1.5
|
||||
# via
|
||||
# jinja2
|
||||
@@ -127,13 +159,13 @@ numpy==1.26.4
|
||||
# via
|
||||
# ctranslate2
|
||||
# onnxruntime
|
||||
# pipecat (pyproject.toml)
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# torchvision
|
||||
# transformers
|
||||
onnxruntime==1.17.3
|
||||
# via faster-whisper
|
||||
openai==1.26.0
|
||||
# via pipecat (pyproject.toml)
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
packaging==24.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
@@ -141,37 +173,59 @@ packaging==24.0
|
||||
# transformers
|
||||
pillow==10.3.0
|
||||
# via
|
||||
# pipecat (pyproject.toml)
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# torchvision
|
||||
proto-plus==1.23.0
|
||||
# via
|
||||
# google-ai-generativelanguage
|
||||
# google-api-core
|
||||
protobuf==4.25.3
|
||||
# via
|
||||
# google-ai-generativelanguage
|
||||
# google-api-core
|
||||
# google-generativeai
|
||||
# googleapis-common-protos
|
||||
# grpcio-status
|
||||
# onnxruntime
|
||||
# proto-plus
|
||||
# pyht
|
||||
pyasn1==0.6.0
|
||||
# via
|
||||
# pyasn1-modules
|
||||
# rsa
|
||||
pyasn1-modules==0.4.0
|
||||
# via google-auth
|
||||
pyaudio==0.2.14
|
||||
# via pipecat (pyproject.toml)
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
pydantic==2.7.1
|
||||
# via
|
||||
# anthropic
|
||||
# google-generativeai
|
||||
# openai
|
||||
pydantic-core==2.18.2
|
||||
# via pydantic
|
||||
pyht==0.0.28
|
||||
# via pipecat (pyproject.toml)
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
pyparsing==3.1.2
|
||||
# via httplib2
|
||||
python-dotenv==1.0.1
|
||||
# via pipecat (pyproject.toml)
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
pyyaml==6.0.1
|
||||
# via
|
||||
# ctranslate2
|
||||
# huggingface-hub
|
||||
# timm
|
||||
# transformers
|
||||
regex==2024.5.10
|
||||
regex==2024.5.15
|
||||
# via transformers
|
||||
requests==2.31.0
|
||||
# via
|
||||
# google-api-core
|
||||
# huggingface-hub
|
||||
# pyht
|
||||
# transformers
|
||||
rsa==4.9
|
||||
# via google-auth
|
||||
safetensors==0.4.3
|
||||
# via
|
||||
# timm
|
||||
@@ -187,7 +241,7 @@ sympy==1.12
|
||||
# onnxruntime
|
||||
# torch
|
||||
timm==0.9.16
|
||||
# via pipecat (pyproject.toml)
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
tokenizers==0.19.1
|
||||
# via
|
||||
# anthropic
|
||||
@@ -195,35 +249,38 @@ tokenizers==0.19.1
|
||||
# transformers
|
||||
torch==2.3.0
|
||||
# via
|
||||
# pipecat (pyproject.toml)
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# timm
|
||||
# torchaudio
|
||||
# torchvision
|
||||
torchaudio==2.3.0
|
||||
# via pipecat (pyproject.toml)
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
torchvision==0.18.0
|
||||
# via timm
|
||||
tqdm==4.66.4
|
||||
# via
|
||||
# google-generativeai
|
||||
# huggingface-hub
|
||||
# openai
|
||||
# transformers
|
||||
transformers==4.40.2
|
||||
# via pipecat (pyproject.toml)
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
typing-extensions==4.11.0
|
||||
# via
|
||||
# anthropic
|
||||
# anyio
|
||||
# google-generativeai
|
||||
# huggingface-hub
|
||||
# openai
|
||||
# pipecat (pyproject.toml)
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# torch
|
||||
uritemplate==4.1.1
|
||||
# via google-api-python-client
|
||||
urllib3==2.2.1
|
||||
# via requests
|
||||
websockets==12.0
|
||||
# via pipecat (pyproject.toml)
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
werkzeug==3.0.3
|
||||
# via flask
|
||||
yarl==1.9.4
|
||||
|
||||
@@ -37,6 +37,7 @@ azure = [ "azure-cognitiveservices-speech~=1.37.0" ]
|
||||
daily = [ "daily-python~=0.7.4" ]
|
||||
examples = [ "python-dotenv~=1.0.0", "flask~=3.0.3", "flask_cors~=4.0.1" ]
|
||||
fal = [ "fal-client~=0.4.0" ]
|
||||
google = [ "google-generativeai~=0.5.3" ]
|
||||
fireworks = [ "openai~=1.26.0" ]
|
||||
local = [ "pyaudio~=0.2.0" ]
|
||||
moondream = [ "einops~=0.8.0", "timm~=0.9.16", "transformers~=4.40.2" ]
|
||||
|
||||
@@ -5,10 +5,13 @@
|
||||
#
|
||||
|
||||
from dataclasses import dataclass
|
||||
import io
|
||||
|
||||
from typing import List
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.frames.frames import Frame, VisionImageRawFrame
|
||||
|
||||
from openai._types import NOT_GIVEN, NotGiven
|
||||
|
||||
@@ -43,6 +46,31 @@ class OpenAILLMContext:
|
||||
})
|
||||
return context
|
||||
|
||||
@staticmethod
def from_image_frame(frame: VisionImageRawFrame) -> "OpenAILLMContext":
    """Build a single-message context from a vision image frame.

    For images, we are deviating from the OpenAI messages shape. OpenAI
    expects images to be base64 encoded, but other vision models may not.
    So we'll store the image as bytes and do the base64 encoding as needed
    in the LLM service.

    The raw pixels in ``frame.image`` are interpreted using ``frame.format``
    as the Pillow mode and ``frame.size`` as the dimensions, then encoded as
    JPEG. The resulting user message carries ``frame.text`` as its content,
    the JPEG bytes under ``"data"`` and ``"image/jpeg"`` under ``"mime_type"``.

    Args:
        frame: The vision frame holding raw image bytes and the prompt text.

    Returns:
        A new OpenAILLMContext containing exactly one user message.
    """
    image = Image.frombytes(frame.format, frame.size, frame.image)

    # Pillow's JPEG writer rejects modes carrying alpha or palette data
    # (e.g. "RGBA", "LA", "P"). Convert those to RGB so encoding does not
    # fail; modes JPEG can already store are left untouched.
    jpeg_writable_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_writable_modes:
        image = image.convert("RGB")

    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")

    context = OpenAILLMContext()
    context.add_message({
        "content": frame.text,
        "role": "user",
        "data": buffer.getvalue(),
        "mime_type": "image/jpeg"
    })
    return context
|
||||
|
||||
def add_message(self, message: ChatCompletionMessageParam):
|
||||
self.messages.append(message)
|
||||
|
||||
|
||||
96
src/pipecat/services/google.py
Normal file
96
src/pipecat/services/google.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import google.generativeai as gai
|
||||
import google.ai.generativelanguage as glm
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
from typing import List
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
TextFrame,
|
||||
VisionImageRawFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMResponseStartFrame,
|
||||
LLMResponseEndFrame)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import LLMService
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class GoogleLLMService(LLMService):
    """This class implements inference with Google's AI models

    This service translates internally from OpenAILLMContext to the messages format
    expected by the Google AI model. We are using the OpenAILLMContext as a lingua
    franca for all LLM services, so that it is easy to switch between different LLMs.
    """

    def __init__(self, model="gemini-1.5-flash-latest", api_key=None, **kwargs):
        """Configure the Google client library and create the model client.

        Args:
            model: Name of the generative model to use.
            api_key: API key for the Google client library. When falsy, the
                ``GOOGLE_API_KEY`` environment variable is read instead.
            **kwargs: Forwarded to the ``LLMService`` constructor.

        Raises:
            KeyError: If no ``api_key`` is given and ``GOOGLE_API_KEY`` is not
                set in the environment.
        """
        super().__init__(**kwargs)
        self.model = model
        gai.configure(api_key=api_key or os.environ["GOOGLE_API_KEY"])
        self.create_client()

    def create_client(self):
        """(Re)create the ``GenerativeModel`` client for ``self.model``."""
        self._client = gai.GenerativeModel(self.model)

    def _get_messages_from_openai_context(
            self, context: OpenAILLMContext) -> List[glm.Content]:
        """Translate the context's OpenAI-style messages for the Google client.

        Roles are remapped ("system" becomes "user", "assistant" becomes
        "model", anything else is kept). Each message yields a text part, plus
        an inline-data part when the message carries a "mime_type" entry.

        Returns:
            A list of dicts, each with a "role" and a "parts" key.
        """
        role_map = {"system": "user", "assistant": "model"}
        google_messages = []

        for message in context.get_messages():
            role = role_map.get(message["role"], message["role"])

            parts = [glm.Part(text=message["content"])]
            if "mime_type" in message:
                # Image (or other binary) payload attached to the message.
                parts.append(
                    glm.Part(inline_data=glm.Blob(
                        mime_type=message["mime_type"],
                        data=message["data"]
                    )))
            google_messages.append({"role": role, "parts": parts})

        return google_messages

    async def _async_generator_wrapper(self, sync_generator):
        """Expose a synchronous iterable as an async generator.

        Control is yielded to the event loop (``asyncio.sleep(0)``) after each
        item. Fetching the next item from ``sync_generator`` itself still runs
        synchronously.
        """
        for item in sync_generator:
            yield item
            await asyncio.sleep(0)

    async def _stream_response(self, messages):
        """Request a streamed completion and push each text chunk downstream."""
        response = self._client.generate_content(messages, stream=True)

        async for chunk in self._async_generator_wrapper(response):
            # Read the text once per chunk; it is used for both log and frame.
            text = chunk.text
            logger.debug(f"Pushing inference text: {text}")
            await self.push_frame(TextFrame(text))

    async def _process_context(self, context: OpenAILLMContext):
        """Run inference for ``context`` and push the response frames.

        Pushes an ``LLMResponseStartFrame``, one ``TextFrame`` per streamed
        chunk and finally an ``LLMResponseEndFrame``. Exceptions are logged
        rather than raised. Once the start frame has been pushed, the end
        frame is pushed even if generation or streaming fails, so that start
        and end frames stay paired.
        """
        try:
            messages = self._get_messages_from_openai_context(context)

            await self.push_frame(LLMResponseStartFrame())
            try:
                await self._stream_response(messages)
            except Exception as e:
                # Log and fall through so the response is still closed below.
                logger.error(f"Exception: {e}")

            await self.push_frame(LLMResponseEndFrame())
        except Exception as e:
            logger.error(f"Exception: {e}")

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Turn context-bearing frames into an inference; forward the rest.

        ``OpenAILLMContextFrame``, ``LLMMessagesFrame`` and
        ``VisionImageRawFrame`` are converted to an ``OpenAILLMContext`` and
        processed. Any other frame is pushed on in its original direction.
        """
        context = None

        if isinstance(frame, OpenAILLMContextFrame):
            context: OpenAILLMContext = frame.context
        elif isinstance(frame, LLMMessagesFrame):
            context = OpenAILLMContext.from_messages(frame.messages)
        elif isinstance(frame, VisionImageRawFrame):
            context = OpenAILLMContext.from_image_frame(frame)
        else:
            await self.push_frame(frame, direction)

        if context:
            await self._process_context(context)
|
||||
@@ -8,6 +8,7 @@ import io
|
||||
import json
|
||||
import time
|
||||
import aiohttp
|
||||
import base64
|
||||
|
||||
from PIL import Image
|
||||
|
||||
@@ -22,7 +23,8 @@ from pipecat.frames.frames import (
|
||||
LLMResponseEndFrame,
|
||||
LLMResponseStartFrame,
|
||||
TextFrame,
|
||||
URLImageRawFrame
|
||||
URLImageRawFrame,
|
||||
VisionImageRawFrame
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
@@ -67,8 +69,21 @@ class BaseOpenAILLMService(LLMService):
|
||||
self, context: OpenAILLMContext
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
messages: List[ChatCompletionMessageParam] = context.get_messages()
|
||||
messages_for_log = json.dumps(messages)
|
||||
logger.debug(f"Generating chat: {messages_for_log}")
|
||||
|
||||
# base64 encode any images
|
||||
for message in messages:
|
||||
if message.get("mime_type") == "image/jpeg":
|
||||
encoded_image = base64.b64encode(message["data"]).decode("utf-8")
|
||||
text = message["content"]
|
||||
message["content"] = [
|
||||
{"type": "text", "text": text},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
|
||||
]
|
||||
del message["data"]
|
||||
del message["mime_type"]
|
||||
|
||||
# messages_for_log = json.dumps(messages)
|
||||
# logger.debug(f"Generating chat: {messages_for_log}")
|
||||
|
||||
start_time = time.time()
|
||||
chunks: AsyncStream[ChatCompletionChunk] = (
|
||||
@@ -151,6 +166,8 @@ class BaseOpenAILLMService(LLMService):
|
||||
context: OpenAILLMContext = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
context = OpenAILLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
context = OpenAILLMContext.from_image_frame(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user