diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b85c7f87..e2ca4dce1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to **Pipecat** will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Changed + +- `GoogleLLMService` has been updated to use `google-genai` instead of the + deprecated `google-generativeai`. + ## [0.0.67] - 2025-05-07 ### Added diff --git a/examples/foundational/32-gemini-grounding-metadata.py b/examples/foundational/32-gemini-grounding-metadata.py index 0c9f4cf38..8c53f7367 100644 --- a/examples/foundational/32-gemini-grounding-metadata.py +++ b/examples/foundational/32-gemini-grounding-metadata.py @@ -11,18 +11,17 @@ from pathlib import Path from dotenv import load_dotenv from loguru import logger -from openai import audio from pipecat.audio.vad.silero import SileroVADAnalyzer -from pipecat.frames.frames import Frame +from pipecat.observers.base_observer import BaseObserver, FramePushed from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.services.cartesia.tts import CartesiaTTSService from pipecat.services.deepgram.stt import DeepgramSTTService from pipecat.services.google.llm import GoogleLLMService, LLMSearchResponseFrame +from pipecat.services.llm_service import LLMService from pipecat.transports.base_transport import TransportParams from pipecat.transports.network.small_webrtc import SmallWebRTCTransport from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection @@ -33,7 +32,7 @@ load_dotenv(override=True) # Function handlers for the LLM -search_tool = {"google_search_retrieval": {}} +search_tool = {"google_search": {}} tools = [search_tool] system_instruction = """ @@ -50,14 +49,22 @@ Start each interaction by asking the user about which place they would like to k """ -class LLMSearchLoggerProcessor(FrameProcessor): - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) +class LLMSearchLoggerObserver(BaseObserver): + async def on_push_frame(self, data: FramePushed): + src = data.source + dst = data.destination + frame = data.frame + timestamp = data.timestamp + + if not isinstance(src, LLMService) and not isinstance(dst, LLMService): + return + + time_sec = timestamp / 1_000_000_000 + + arrow = "→" if isinstance(frame, LLMSearchResponseFrame): - print(f"LLMSearchLoggerProcessor: {frame}") - - await self.push_frame(frame) + logger.debug(f"🧠 {arrow} {dst} LLM SEARCH RESPONSE FRAME: {frame} at {time_sec:.2f}s") async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace): @@ -84,7 +91,6 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac api_key=os.getenv("GOOGLE_API_KEY"), system_instruction=system_instruction, tools=tools, - model="gemini-1.5-flash-002", ) context = OpenAILLMContext( @@ -97,22 +103,23 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac ) context_aggregator = llm.create_context_aggregator(context) - llm_search_logger = LLMSearchLoggerProcessor() - pipeline = Pipeline( [ transport.input(), stt, context_aggregator.user(), llm, - llm_search_logger, tts, transport.output(), context_aggregator.assistant(), ] ) - task = PipelineTask(pipeline, params=PipelineParams(allow_interruptions=True)) + task = PipelineTask( + pipeline, + params=PipelineParams(allow_interruptions=True), + observers=[LLMSearchLoggerObserver()], + ) @transport.event_handler("on_client_connected") async def on_client_connected(transport, client): diff --git a/examples/news-chatbot/server/news_bot.py b/examples/news-chatbot/server/news_bot.py index 80355b43c..c78752dfb 100644 --- a/examples/news-chatbot/server/news_bot.py +++ b/examples/news-chatbot/server/news_bot.py @@ -102,9 +102,9 @@ async def main(): llm = GoogleLLMService( api_key=os.getenv("GOOGLE_API_KEY"), - model="gemini-1.5-flash-002", system_instruction=system_instruction, tools=tools, + model="gemini-1.5-flash", ) context = OpenAILLMContext( @@ -153,7 +153,6 @@ async def main(): @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): logger.debug("First participant joined: {}", participant["id"]) - await transport.capture_participant_transcription(participant["id"]) @transport.event_handler("on_participant_left") async def on_participant_left(transport, participant, reason): diff --git a/examples/simple-chatbot/server/bot-gemini.py b/examples/simple-chatbot/server/bot-gemini.py index 70dfccf2d..a38dc11d3 100644 --- a/examples/simple-chatbot/server/bot-gemini.py +++ b/examples/simple-chatbot/server/bot-gemini.py @@ -187,7 +187,7 @@ async def main(): @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): - await transport.capture_participant_transcription(participant["id"]) + print(f"Participant joined: {participant}") @transport.event_handler("on_participant_left") async def on_participant_left(transport, participant, reason): diff --git a/examples/simple-chatbot/server/bot-openai.py b/examples/simple-chatbot/server/bot-openai.py index 07c56aa28..63226396e 100644 --- a/examples/simple-chatbot/server/bot-openai.py +++ b/examples/simple-chatbot/server/bot-openai.py @@ -215,6 +215,7 @@ async def main(): @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): + print(f"Participant joined: {participant}") await transport.capture_participant_transcription(participant["id"]) @transport.event_handler("on_participant_left") diff --git a/examples/word-wrangler-gemini-live/server/bot_phone_local.py b/examples/word-wrangler-gemini-live/server/bot_phone_local.py index 7c86a6895..69a23623c 100644 --- a/examples/word-wrangler-gemini-live/server/bot_phone_local.py +++ b/examples/word-wrangler-gemini-live/server/bot_phone_local.py @@ -30,7 +30,7 @@ from loguru import logger from pipecatcloud.agent import DailySessionArguments from word_list import generate_game_words -from pipecat.audio.resamplers.soxr_resampler import SOXRAudioResampler +from pipecat.audio.utils import create_default_resampler from pipecat.audio.vad.silero import SileroVADAnalyzer from pipecat.frames.frames import ( BotStoppedSpeakingFrame, @@ -524,7 +524,7 @@ async def tts_audio_raw_frame_filter(frame: Frame): # Create a resampler instance once -resampler = SOXRAudioResampler() +resampler = create_default_resampler() async def tts_to_input_audio_transformer(frame: Frame): @@ -689,8 +689,6 @@ Important guidelines: @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): logger.info("First participant joined: {}", participant["id"]) - # Capture the participant's transcription - await transport.capture_participant_transcription(participant["id"]) # Kick off the conversation await task.queue_frames([context_aggregator.user().get_context_frame()]) # Start the game timer diff --git a/pyproject.toml b/pyproject.toml index 3b34569c0..b37314496 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ fal = [ "fal-client~=0.5.9" ] fireworks = [] fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ] gladia = [ "websockets~=13.1" ] -google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4", "websockets~=13.1" ] +google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.14.0", "websockets~=13.1" ] grok = [] groq = [ "groq~=0.23.0" ] gstreamer = [ "pygobject~=3.50.0" ] diff --git a/src/pipecat/services/google/llm.py b/src/pipecat/services/google/llm.py index bf9714817..50f1ad680 100644 --- a/src/pipecat/services/google/llm.py +++ b/src/pipecat/services/google/llm.py @@ -52,10 +52,16 @@ from pipecat.services.openai.llm import ( os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false" try: - import google.ai.generativelanguage as glm - import google.generativeai as gai + from google import genai from google.api_core.exceptions import DeadlineExceeded - from google.generativeai.types import GenerationConfig + from google.genai.types import ( + Blob, + Content, + FunctionCall, + FunctionResponse, + GenerateContentConfig, + Part, + ) except ModuleNotFoundError as e: logger.error(f"Exception: {e}") logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.") @@ -65,9 +71,7 @@ except ModuleNotFoundError as e: class GoogleUserContextAggregator(OpenAIUserContextAggregator): async def push_aggregation(self): if len(self._aggregation) > 0: - self._context.add_message( - glm.Content(role="user", parts=[glm.Part(text=self._aggregation)]) - ) + self._context.add_message(Content(role="user", parts=[Part(text=self._aggregation)])) # Reset the aggregation. Reset it before pushing it down, otherwise # if the tasks gets cancelled we won't be able to clear things up. @@ -83,15 +87,15 @@ class GoogleUserContextAggregator(OpenAIUserContextAggregator): class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): async def handle_aggregation(self, aggregation: str): - self._context.add_message(glm.Content(role="model", parts=[glm.Part(text=aggregation)])) + self._context.add_message(Content(role="model", parts=[Part(text=aggregation)])) async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame): self._context.add_message( - glm.Content( + Content( role="model", parts=[ - glm.Part( - function_call=glm.FunctionCall( + Part( + function_call=FunctionCall( id=frame.tool_call_id, name=frame.function_name, args=frame.arguments ) ) @@ -99,11 +103,11 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): ) ) self._context.add_message( - glm.Content( + Content( role="user", parts=[ - glm.Part( - function_response=glm.FunctionResponse( + Part( + function_response=FunctionResponse( id=frame.tool_call_id, name=frame.function_name, response={"response": "IN_PROGRESS"}, @@ -187,7 +191,7 @@ class GoogleLLMContext(OpenAILLMContext): # Convert each message individually converted_messages = [] for msg in messages: - if isinstance(msg, glm.Content): + if isinstance(msg, Content): # Already in Gemini format converted_messages.append(msg) else: @@ -202,7 +206,7 @@ class GoogleLLMContext(OpenAILLMContext): def get_messages_for_logging(self): msgs = [] for message in self.messages: - obj = glm.Content.to_dict(message) + obj = message.to_json_dict() try: if "parts" in obj: for part in obj["parts"]: @@ -221,10 +225,10 @@ class GoogleLLMContext(OpenAILLMContext): parts = [] if text: - parts.append(glm.Part(text=text)) - parts.append(glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue()))) + parts.append(Part(text=text)) + parts.append(Part(inline_data=Blob(mime_type="image/jpeg", data=buffer.getvalue()))) - self.add_message(glm.Content(role="user", parts=parts)) + self.add_message(Content(role="user", parts=parts)) def add_audio_frames_message( self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows" @@ -239,10 +243,10 @@ class GoogleLLMContext(OpenAILLMContext): data = b"".join(frame.audio for frame in audio_frames) # NOTE(aleix): According to the docs only text or inline_data should be needed. # (see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference) - parts.append(glm.Part(text=text)) + parts.append(Part(text=text)) parts.append( - glm.Part( - inline_data=glm.Blob( + Part( + inline_data=Blob( mime_type="audio/wav", data=( bytes( @@ -252,7 +256,7 @@ class GoogleLLMContext(OpenAILLMContext): ) ), ) - self.add_message(glm.Content(role="user", parts=parts)) + self.add_message(Content(role="user", parts=parts)) # message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))} # self.add_message(message) @@ -271,7 +275,7 @@ class GoogleLLMContext(OpenAILLMContext): } Returns: - glm.Content object with: + Content object with: - role: "user" or "model" (converted from "assistant") - parts: List[Part] containing text, inline_data, or function calls Returns None for system messages. @@ -288,8 +292,8 @@ class GoogleLLMContext(OpenAILLMContext): if message.get("tool_calls"): for tc in message["tool_calls"]: parts.append( - glm.Part( - function_call=glm.FunctionCall( + Part( + function_call=FunctionCall( name=tc["function"]["name"], args=json.loads(tc["function"]["arguments"]), ) @@ -298,30 +302,30 @@ class GoogleLLMContext(OpenAILLMContext): elif role == "tool": role = "model" parts.append( - glm.Part( - function_response=glm.FunctionResponse( + Part( + function_response=FunctionResponse( name="tool_call_result", # seems to work to hard-code the same name every time response=json.loads(message["content"]), ) ) ) elif isinstance(content, str): - parts.append(glm.Part(text=content)) + parts.append(Part(text=content)) elif isinstance(content, list): for c in content: if c["type"] == "text": - parts.append(glm.Part(text=c["text"])) + parts.append(Part(text=c["text"])) elif c["type"] == "image_url": parts.append( - glm.Part( - inline_data=glm.Blob( + Part( + inline_data=Blob( mime_type="image/jpeg", data=base64.b64decode(c["image_url"]["url"].split(",")[1]), ) ) ) - message = glm.Content(role=role, parts=parts) + message = Content(role=role, parts=parts) return message def to_standard_messages(self, obj) -> list: @@ -409,7 +413,7 @@ class GoogleLLMContext(OpenAILLMContext): # Process each message, preserving Google-formatted messages and converting others for message in self._messages: - if isinstance(message, glm.Content): + if isinstance(message, Content): # Keep existing Google-formatted messages (e.g., function calls/responses) converted_messages.append(message) continue @@ -433,9 +437,7 @@ class GoogleLLMContext(OpenAILLMContext): # Add system message back as a user message if we only have function messages if self.system_message and not has_regular_messages: - self._messages.append( - glm.Content(role="user", parts=[glm.Part(text=self.system_message)]) - ) + self._messages.append(Content(role="user", parts=[Part(text=self.system_message)])) # Remove any empty messages self._messages = [m for m in self._messages if m.parts] @@ -463,7 +465,7 @@ class GoogleLLMService(LLMService): self, *, api_key: str, - model: str = "gemini-2.0-flash-001", + model: str = "gemini-2.0-flash", params: InputParams = InputParams(), system_instruction: Optional[str] = None, tools: Optional[List[Dict[str, Any]]] = None, @@ -471,10 +473,10 @@ class GoogleLLMService(LLMService): **kwargs, ): super().__init__(**kwargs) - gai.configure(api_key=api_key) self.set_model_name(model) + self._api_key = api_key self._system_instruction = system_instruction - self._create_client() + self._create_client(api_key) self._settings = { "max_tokens": params.max_tokens, "temperature": params.temperature, @@ -488,10 +490,8 @@ class GoogleLLMService(LLMService): def can_generate_metrics(self) -> bool: return True - def _create_client(self): - self._client = gai.GenerativeModel( - self._model_name, system_instruction=self._system_instruction - ) + def _create_client(self, api_key: str): + self._client = genai.Client(api_key=api_key) async def _process_context(self, context: OpenAILLMContext): await self.push_frame(LLMFullResponseStartFrame()) @@ -513,23 +513,7 @@ class GoogleLLMService(LLMService): if context.system_message and self._system_instruction != context.system_message: logger.debug(f"System instruction changed: {context.system_message}") self._system_instruction = context.system_message - self._create_client() - # Filter out None values and create GenerationConfig - generation_params = { - k: v - for k, v in { - "temperature": self._settings["temperature"], - "top_p": self._settings["top_p"], - "top_k": self._settings["top_k"], - "max_output_tokens": self._settings["max_tokens"], - }.items() - if v is not None - } - - generation_config = GenerationConfig(**generation_params) if generation_params else None - - await self.start_ttfb_metrics() tools = [] if context.tools: tools = context.tools @@ -538,112 +522,104 @@ class GoogleLLMService(LLMService): tool_config = None if self._tool_config: tool_config = self._tool_config - response = await self._client.generate_content_async( + + # Filter out None values and create GenerationContentConfig + generation_params = { + k: v + for k, v in { + "system_instruction": self._system_instruction, + "temperature": self._settings["temperature"], + "top_p": self._settings["top_p"], + "top_k": self._settings["top_k"], + "max_output_tokens": self._settings["max_tokens"], + "tools": tools, + "tool_config": tool_config, + }.items() + if v is not None + } + + generation_config = ( + GenerateContentConfig(**generation_params) if generation_params else None + ) + + await self.start_ttfb_metrics() + response = await self._client.aio.models.generate_content_stream( + model=self._model_name, contents=messages, - tools=tools, - stream=True, - generation_config=generation_config, - tool_config=tool_config, + config=generation_config, ) await self.stop_ttfb_metrics() - if response.usage_metadata: - # Use only the prompt token count from the response object - prompt_tokens = response.usage_metadata.prompt_token_count - total_tokens = prompt_tokens - async for chunk in response: if chunk.usage_metadata: - # Use only the completion_tokens from the chunks. Prompt tokens are already counted and - # are repeated here. - completion_tokens += chunk.usage_metadata.candidates_token_count - total_tokens += chunk.usage_metadata.candidates_token_count - try: - for c in chunk.parts: - if c.text: - search_result += c.text - await self.push_frame(LLMTextFrame(c.text)) - elif c.function_call: - logger.debug(f"Function call: {c.function_call}") - args = type(c.function_call).to_dict(c.function_call).get("args", {}) - await self.call_function( - context=context, - tool_call_id=str(uuid.uuid4()), - function_name=c.function_call.name, - arguments=args, - ) - # Handle grounding metadata - # It seems only the last chunk that we receive may contain this information - # If the response doesn't include groundingMetadata, this means the response wasn't grounded. - if chunk.candidates: - for candidate in chunk.candidates: - # logger.debug(f"candidate received: {candidate}") - # Extract grounding metadata - grounding_metadata = ( - { - "rendered_content": getattr( - getattr(candidate, "grounding_metadata", None), - "search_entry_point", - None, - ).rendered_content - if hasattr( - getattr(candidate, "grounding_metadata", None), - "search_entry_point", - ) - else None, - "origins": [ - { - "site_uri": getattr(grounding_chunk.web, "uri", None), - "site_title": getattr( - grounding_chunk.web, "title", None - ), - "results": [ - { - "text": getattr( - grounding_support.segment, "text", "" - ), - "confidence": getattr( - grounding_support, "confidence_scores", None - ), - } - for grounding_support in getattr( - getattr(candidate, "grounding_metadata", None), - "grounding_supports", - [], - ) - if index - in getattr( - grounding_support, "grounding_chunk_indices", [] - ) - ], - } - for index, grounding_chunk in enumerate( - getattr( - getattr(candidate, "grounding_metadata", None), - "grounding_chunks", - [], - ) - ) - ], - } - if getattr(candidate, "grounding_metadata", None) - else None - ) - except Exception as e: - # Google LLMs seem to flag safety issues a lot! - if chunk.candidates[0].finish_reason == 3: - logger.debug( - f"LLM refused to generate content for safety reasons - {messages}." - ) - else: - logger.exception(f"{self} error: {e}") + prompt_tokens += chunk.usage_metadata.prompt_token_count or 0 + completion_tokens += chunk.usage_metadata.candidates_token_count or 0 + total_tokens += chunk.usage_metadata.total_token_count or 0 + if not chunk.candidates: + continue + + for candidate in chunk.candidates: + if candidate.content and candidate.content.parts: + for part in candidate.content.parts: + if not part.thought and part.text: + search_result += part.text + await self.push_frame(LLMTextFrame(part.text)) + elif part.function_call: + function_call = part.function_call + id = function_call.id or str(uuid.uuid4()) + logger.debug(f"Function call: {function_call.name}:{id}") + await self.call_function( + context=context, + tool_call_id=id, + function_name=function_call.name, + arguments=function_call.args or {}, + ) + + if ( + candidate.grounding_metadata + and candidate.grounding_metadata.grounding_chunks + ): + m = candidate.grounding_metadata + rendered_content = ( + m.search_entry_point.rendered_content if m.search_entry_point else None + ) + origins = [ + { + "site_uri": grounding_chunk.web.uri + if grounding_chunk.web + else None, + "site_title": grounding_chunk.web.title + if grounding_chunk.web + else None, + "results": [ + { + "text": grounding_support.segment.text + if grounding_support.segment + else "", + "confidence": grounding_support.confidence_scores, + } + for grounding_support in ( + m.grounding_supports if m.grounding_supports else [] + ) + if grounding_support.grounding_chunk_indices + and index in grounding_support.grounding_chunk_indices + ], + } + for index, grounding_chunk in enumerate( + m.grounding_chunks if m.grounding_chunks else [] + ) + ] + grounding_metadata = { + "rendered_content": rendered_content, + "origins": origins, + } except DeadlineExceeded: await self._call_event_handler("on_completion_timeout") except Exception as e: logger.exception(f"{self} exception: {e}") finally: - if grounding_metadata is not None and isinstance(grounding_metadata, dict): + if grounding_metadata and isinstance(grounding_metadata, dict): llm_search_frame = LLMSearchResponseFrame( search_result=search_result, origins=grounding_metadata["origins"], diff --git a/tests/test_context_aggregators.py b/tests/test_context_aggregators.py index 0f68110ce..75a81aaac 100644 --- a/tests/test_context_aggregators.py +++ b/tests/test_context_aggregators.py @@ -8,8 +8,6 @@ import json import unittest from typing import Any -import google.ai.generativelanguage as glm - from pipecat.frames.frames import ( EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame, @@ -758,13 +756,13 @@ class TestGoogleUserContextAggregator( AGGREGATOR_CLASS = GoogleUserContextAggregator def check_message_content(self, context: OpenAILLMContext, index: int, content: str): - obj = glm.Content.to_dict(context.messages[index]) + obj = context.messages[index].to_json_dict() assert obj["parts"][0]["text"] == content def check_message_multi_content( self, context: OpenAILLMContext, content_index: int, index: int, content: str ): - obj = glm.Content.to_dict(context.messages[index]) + obj = context.messages[index].to_json_dict() assert obj["parts"][0]["text"] == content @@ -776,17 +774,17 @@ class TestGoogleAssistantContextAggregator( EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] def check_message_content(self, context: OpenAILLMContext, index: int, content: str): - obj = glm.Content.to_dict(context.messages[index]) + obj = context.messages[index].to_json_dict() assert obj["parts"][0]["text"] == content def check_message_multi_content( self, context: OpenAILLMContext, content_index: int, index: int, content: str ): - obj = glm.Content.to_dict(context.messages[index]) + obj = context.messages[index].to_json_dict() assert obj["parts"][0]["text"] == content def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any): - obj = glm.Content.to_dict(context.messages[index]) + obj = context.messages[index].to_json_dict() assert obj["parts"][0]["function_response"]["response"]["value"] == json.dumps(content)