GoogleLLMService: deprecate google-generativeai

This commit is contained in:
Aleix Conchillo Flaqué
2025-05-08 14:07:46 -07:00
parent 9643296e29
commit f31efa42c9
6 changed files with 165 additions and 177 deletions

View File

@@ -5,6 +5,13 @@ All notable changes to **Pipecat** will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Changed
- `GoogleLLMService` has been updated to use `google-genai` instead of the
deprecated `google-generativeai`.
## [0.0.67] - 2025-05-07
### Added

View File

@@ -11,18 +11,17 @@ from pathlib import Path
from dotenv import load_dotenv
from loguru import logger
from openai import audio
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import Frame
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.google.llm import GoogleLLMService, LLMSearchResponseFrame
from pipecat.services.llm_service import LLMService
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
@@ -33,7 +32,7 @@ load_dotenv(override=True)
# Function handlers for the LLM
search_tool = {"google_search_retrieval": {}}
search_tool = {"google_search": {}}
tools = [search_tool]
system_instruction = """
@@ -50,14 +49,22 @@ Start each interaction by asking the user about which place they would like to k
"""
class LLMSearchLoggerProcessor(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
class LLMSearchLoggerObserver(BaseObserver):
async def on_push_frame(self, data: FramePushed):
src = data.source
dst = data.destination
frame = data.frame
timestamp = data.timestamp
if not isinstance(src, LLMService) and not isinstance(dst, LLMService):
return
time_sec = timestamp / 1_000_000_000
arrow = ""
if isinstance(frame, LLMSearchResponseFrame):
print(f"LLMSearchLoggerProcessor: {frame}")
await self.push_frame(frame)
logger.debug(f"🧠 {arrow} {dst} LLM SEARCH RESPONSE FRAME: {frame} at {time_sec:.2f}s")
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
@@ -84,7 +91,6 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
api_key=os.getenv("GOOGLE_API_KEY"),
system_instruction=system_instruction,
tools=tools,
model="gemini-1.5-flash-002",
)
context = OpenAILLMContext(
@@ -97,22 +103,23 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
)
context_aggregator = llm.create_context_aggregator(context)
llm_search_logger = LLMSearchLoggerProcessor()
pipeline = Pipeline(
[
transport.input(),
stt,
context_aggregator.user(),
llm,
llm_search_logger,
tts,
transport.output(),
context_aggregator.assistant(),
]
)
task = PipelineTask(pipeline, params=PipelineParams(allow_interruptions=True))
task = PipelineTask(
pipeline,
params=PipelineParams(allow_interruptions=True),
observers=[LLMSearchLoggerObserver()],
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):

View File

@@ -102,9 +102,9 @@ async def main():
llm = GoogleLLMService(
api_key=os.getenv("GOOGLE_API_KEY"),
model="gemini-1.5-flash-002",
system_instruction=system_instruction,
tools=tools,
model="gemini-1.5-flash",
)
context = OpenAILLMContext(

View File

@@ -54,7 +54,7 @@ fal = [ "fal-client~=0.5.9" ]
fireworks = []
fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
gladia = [ "websockets~=13.1" ]
google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4", "websockets~=13.1" ]
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.14.0", "websockets~=13.1" ]
grok = []
groq = [ "groq~=0.23.0" ]
gstreamer = [ "pygobject~=3.50.0" ]

View File

@@ -52,10 +52,16 @@ from pipecat.services.openai.llm import (
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
try:
import google.ai.generativelanguage as glm
import google.generativeai as gai
from google import genai
from google.api_core.exceptions import DeadlineExceeded
from google.generativeai.types import GenerationConfig
from google.genai.types import (
Blob,
Content,
FunctionCall,
FunctionResponse,
GenerateContentConfig,
Part,
)
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
@@ -65,9 +71,7 @@ except ModuleNotFoundError as e:
class GoogleUserContextAggregator(OpenAIUserContextAggregator):
async def push_aggregation(self):
if len(self._aggregation) > 0:
self._context.add_message(
glm.Content(role="user", parts=[glm.Part(text=self._aggregation)])
)
self._context.add_message(Content(role="user", parts=[Part(text=self._aggregation)]))
# Reset the aggregation. Reset it before pushing it down, otherwise
# if the tasks gets cancelled we won't be able to clear things up.
@@ -83,15 +87,15 @@ class GoogleUserContextAggregator(OpenAIUserContextAggregator):
class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
async def handle_aggregation(self, aggregation: str):
self._context.add_message(glm.Content(role="model", parts=[glm.Part(text=aggregation)]))
self._context.add_message(Content(role="model", parts=[Part(text=aggregation)]))
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
self._context.add_message(
glm.Content(
Content(
role="model",
parts=[
glm.Part(
function_call=glm.FunctionCall(
Part(
function_call=FunctionCall(
id=frame.tool_call_id, name=frame.function_name, args=frame.arguments
)
)
@@ -99,11 +103,11 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
)
)
self._context.add_message(
glm.Content(
Content(
role="user",
parts=[
glm.Part(
function_response=glm.FunctionResponse(
Part(
function_response=FunctionResponse(
id=frame.tool_call_id,
name=frame.function_name,
response={"response": "IN_PROGRESS"},
@@ -187,7 +191,7 @@ class GoogleLLMContext(OpenAILLMContext):
# Convert each message individually
converted_messages = []
for msg in messages:
if isinstance(msg, glm.Content):
if isinstance(msg, Content):
# Already in Gemini format
converted_messages.append(msg)
else:
@@ -202,7 +206,7 @@ class GoogleLLMContext(OpenAILLMContext):
def get_messages_for_logging(self):
msgs = []
for message in self.messages:
obj = glm.Content.to_dict(message)
obj = message.to_json_dict()
try:
if "parts" in obj:
for part in obj["parts"]:
@@ -221,10 +225,10 @@ class GoogleLLMContext(OpenAILLMContext):
parts = []
if text:
parts.append(glm.Part(text=text))
parts.append(glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())))
parts.append(Part(text=text))
parts.append(Part(inline_data=Blob(mime_type="image/jpeg", data=buffer.getvalue())))
self.add_message(glm.Content(role="user", parts=parts))
self.add_message(Content(role="user", parts=parts))
def add_audio_frames_message(
self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
@@ -239,10 +243,10 @@ class GoogleLLMContext(OpenAILLMContext):
data = b"".join(frame.audio for frame in audio_frames)
# NOTE(aleix): According to the docs only text or inline_data should be needed.
# (see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference)
parts.append(glm.Part(text=text))
parts.append(Part(text=text))
parts.append(
glm.Part(
inline_data=glm.Blob(
Part(
inline_data=Blob(
mime_type="audio/wav",
data=(
bytes(
@@ -252,7 +256,7 @@ class GoogleLLMContext(OpenAILLMContext):
)
),
)
self.add_message(glm.Content(role="user", parts=parts))
self.add_message(Content(role="user", parts=parts))
# message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))}
# self.add_message(message)
@@ -271,7 +275,7 @@ class GoogleLLMContext(OpenAILLMContext):
}
Returns:
glm.Content object with:
Content object with:
- role: "user" or "model" (converted from "assistant")
- parts: List[Part] containing text, inline_data, or function calls
Returns None for system messages.
@@ -288,8 +292,8 @@ class GoogleLLMContext(OpenAILLMContext):
if message.get("tool_calls"):
for tc in message["tool_calls"]:
parts.append(
glm.Part(
function_call=glm.FunctionCall(
Part(
function_call=FunctionCall(
name=tc["function"]["name"],
args=json.loads(tc["function"]["arguments"]),
)
@@ -298,30 +302,30 @@ class GoogleLLMContext(OpenAILLMContext):
elif role == "tool":
role = "model"
parts.append(
glm.Part(
function_response=glm.FunctionResponse(
Part(
function_response=FunctionResponse(
name="tool_call_result", # seems to work to hard-code the same name every time
response=json.loads(message["content"]),
)
)
)
elif isinstance(content, str):
parts.append(glm.Part(text=content))
parts.append(Part(text=content))
elif isinstance(content, list):
for c in content:
if c["type"] == "text":
parts.append(glm.Part(text=c["text"]))
parts.append(Part(text=c["text"]))
elif c["type"] == "image_url":
parts.append(
glm.Part(
inline_data=glm.Blob(
Part(
inline_data=Blob(
mime_type="image/jpeg",
data=base64.b64decode(c["image_url"]["url"].split(",")[1]),
)
)
)
message = glm.Content(role=role, parts=parts)
message = Content(role=role, parts=parts)
return message
def to_standard_messages(self, obj) -> list:
@@ -409,7 +413,7 @@ class GoogleLLMContext(OpenAILLMContext):
# Process each message, preserving Google-formatted messages and converting others
for message in self._messages:
if isinstance(message, glm.Content):
if isinstance(message, Content):
# Keep existing Google-formatted messages (e.g., function calls/responses)
converted_messages.append(message)
continue
@@ -433,9 +437,7 @@ class GoogleLLMContext(OpenAILLMContext):
# Add system message back as a user message if we only have function messages
if self.system_message and not has_regular_messages:
self._messages.append(
glm.Content(role="user", parts=[glm.Part(text=self.system_message)])
)
self._messages.append(Content(role="user", parts=[Part(text=self.system_message)]))
# Remove any empty messages
self._messages = [m for m in self._messages if m.parts]
@@ -463,7 +465,7 @@ class GoogleLLMService(LLMService):
self,
*,
api_key: str,
model: str = "gemini-2.0-flash-001",
model: str = "gemini-2.0-flash",
params: InputParams = InputParams(),
system_instruction: Optional[str] = None,
tools: Optional[List[Dict[str, Any]]] = None,
@@ -471,10 +473,10 @@ class GoogleLLMService(LLMService):
**kwargs,
):
super().__init__(**kwargs)
gai.configure(api_key=api_key)
self.set_model_name(model)
self._api_key = api_key
self._system_instruction = system_instruction
self._create_client()
self._create_client(api_key)
self._settings = {
"max_tokens": params.max_tokens,
"temperature": params.temperature,
@@ -488,10 +490,8 @@ class GoogleLLMService(LLMService):
def can_generate_metrics(self) -> bool:
return True
def _create_client(self):
self._client = gai.GenerativeModel(
self._model_name, system_instruction=self._system_instruction
)
def _create_client(self, api_key: str):
self._client = genai.Client(api_key=api_key)
async def _process_context(self, context: OpenAILLMContext):
await self.push_frame(LLMFullResponseStartFrame())
@@ -513,23 +513,7 @@ class GoogleLLMService(LLMService):
if context.system_message and self._system_instruction != context.system_message:
logger.debug(f"System instruction changed: {context.system_message}")
self._system_instruction = context.system_message
self._create_client()
# Filter out None values and create GenerationConfig
generation_params = {
k: v
for k, v in {
"temperature": self._settings["temperature"],
"top_p": self._settings["top_p"],
"top_k": self._settings["top_k"],
"max_output_tokens": self._settings["max_tokens"],
}.items()
if v is not None
}
generation_config = GenerationConfig(**generation_params) if generation_params else None
await self.start_ttfb_metrics()
tools = []
if context.tools:
tools = context.tools
@@ -538,112 +522,104 @@ class GoogleLLMService(LLMService):
tool_config = None
if self._tool_config:
tool_config = self._tool_config
response = await self._client.generate_content_async(
# Filter out None values and create GenerationContentConfig
generation_params = {
k: v
for k, v in {
"system_instruction": self._system_instruction,
"temperature": self._settings["temperature"],
"top_p": self._settings["top_p"],
"top_k": self._settings["top_k"],
"max_output_tokens": self._settings["max_tokens"],
"tools": tools,
"tool_config": tool_config,
}.items()
if v is not None
}
generation_config = (
GenerateContentConfig(**generation_params) if generation_params else None
)
await self.start_ttfb_metrics()
response = await self._client.aio.models.generate_content_stream(
model=self._model_name,
contents=messages,
tools=tools,
stream=True,
generation_config=generation_config,
tool_config=tool_config,
config=generation_config,
)
await self.stop_ttfb_metrics()
if response.usage_metadata:
# Use only the prompt token count from the response object
prompt_tokens = response.usage_metadata.prompt_token_count
total_tokens = prompt_tokens
async for chunk in response:
if chunk.usage_metadata:
# Use only the completion_tokens from the chunks. Prompt tokens are already counted and
# are repeated here.
completion_tokens += chunk.usage_metadata.candidates_token_count
total_tokens += chunk.usage_metadata.candidates_token_count
try:
for c in chunk.parts:
if c.text:
search_result += c.text
await self.push_frame(LLMTextFrame(c.text))
elif c.function_call:
logger.debug(f"Function call: {c.function_call}")
args = type(c.function_call).to_dict(c.function_call).get("args", {})
await self.call_function(
context=context,
tool_call_id=str(uuid.uuid4()),
function_name=c.function_call.name,
arguments=args,
)
# Handle grounding metadata
# It seems only the last chunk that we receive may contain this information
# If the response doesn't include groundingMetadata, this means the response wasn't grounded.
if chunk.candidates:
for candidate in chunk.candidates:
# logger.debug(f"candidate received: {candidate}")
# Extract grounding metadata
grounding_metadata = (
{
"rendered_content": getattr(
getattr(candidate, "grounding_metadata", None),
"search_entry_point",
None,
).rendered_content
if hasattr(
getattr(candidate, "grounding_metadata", None),
"search_entry_point",
)
else None,
"origins": [
{
"site_uri": getattr(grounding_chunk.web, "uri", None),
"site_title": getattr(
grounding_chunk.web, "title", None
),
"results": [
{
"text": getattr(
grounding_support.segment, "text", ""
),
"confidence": getattr(
grounding_support, "confidence_scores", None
),
}
for grounding_support in getattr(
getattr(candidate, "grounding_metadata", None),
"grounding_supports",
[],
)
if index
in getattr(
grounding_support, "grounding_chunk_indices", []
)
],
}
for index, grounding_chunk in enumerate(
getattr(
getattr(candidate, "grounding_metadata", None),
"grounding_chunks",
[],
)
)
],
}
if getattr(candidate, "grounding_metadata", None)
else None
)
except Exception as e:
# Google LLMs seem to flag safety issues a lot!
if chunk.candidates[0].finish_reason == 3:
logger.debug(
f"LLM refused to generate content for safety reasons - {messages}."
)
else:
logger.exception(f"{self} error: {e}")
prompt_tokens += chunk.usage_metadata.prompt_token_count or 0
completion_tokens += chunk.usage_metadata.candidates_token_count or 0
total_tokens += chunk.usage_metadata.total_token_count or 0
if not chunk.candidates:
continue
for candidate in chunk.candidates:
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if not part.thought and part.text:
search_result += part.text
await self.push_frame(LLMTextFrame(part.text))
elif part.function_call:
function_call = part.function_call
id = function_call.id or str(uuid.uuid4())
logger.debug(f"Function call: {function_call.name}:{id}")
await self.call_function(
context=context,
tool_call_id=id,
function_name=function_call.name,
arguments=function_call.args or {},
)
if (
candidate.grounding_metadata
and candidate.grounding_metadata.grounding_chunks
):
m = candidate.grounding_metadata
rendered_content = (
m.search_entry_point.rendered_content if m.search_entry_point else None
)
origins = [
{
"site_uri": grounding_chunk.web.uri
if grounding_chunk.web
else None,
"site_title": grounding_chunk.web.title
if grounding_chunk.web
else None,
"results": [
{
"text": grounding_support.segment.text
if grounding_support.segment
else "",
"confidence": grounding_support.confidence_scores,
}
for grounding_support in (
m.grounding_supports if m.grounding_supports else []
)
if grounding_support.grounding_chunk_indices
and index in grounding_support.grounding_chunk_indices
],
}
for index, grounding_chunk in enumerate(
m.grounding_chunks if m.grounding_chunks else []
)
]
grounding_metadata = {
"rendered_content": rendered_content,
"origins": origins,
}
except DeadlineExceeded:
await self._call_event_handler("on_completion_timeout")
except Exception as e:
logger.exception(f"{self} exception: {e}")
finally:
if grounding_metadata is not None and isinstance(grounding_metadata, dict):
if grounding_metadata and isinstance(grounding_metadata, dict):
llm_search_frame = LLMSearchResponseFrame(
search_result=search_result,
origins=grounding_metadata["origins"],

View File

@@ -8,8 +8,6 @@ import json
import unittest
from typing import Any
import google.ai.generativelanguage as glm
from pipecat.frames.frames import (
EmulateUserStartedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
@@ -758,13 +756,13 @@ class TestGoogleUserContextAggregator(
AGGREGATOR_CLASS = GoogleUserContextAggregator
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
obj = glm.Content.to_dict(context.messages[index])
obj = context.messages[index].to_json_dict()
assert obj["parts"][0]["text"] == content
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
obj = glm.Content.to_dict(context.messages[index])
obj = context.messages[index].to_json_dict()
assert obj["parts"][0]["text"] == content
@@ -776,17 +774,17 @@ class TestGoogleAssistantContextAggregator(
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
obj = glm.Content.to_dict(context.messages[index])
obj = context.messages[index].to_json_dict()
assert obj["parts"][0]["text"] == content
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
obj = glm.Content.to_dict(context.messages[index])
obj = context.messages[index].to_json_dict()
assert obj["parts"][0]["text"] == content
def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
obj = glm.Content.to_dict(context.messages[index])
obj = context.messages[index].to_json_dict()
assert obj["parts"][0]["function_response"]["response"]["value"] == json.dumps(content)