Merge pull request #1780 from pipecat-ai/aleix/deprecate-google-generativeai
GoogleLLMService: deprecate google-generativeai
This commit is contained in:
@@ -5,6 +5,13 @@ All notable changes to **Pipecat** will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Changed
|
||||
|
||||
- `GoogleLLMService` has been updated to use `google-genai` instead of the
|
||||
deprecated `google-generativeai`.
|
||||
|
||||
## [0.0.67] - 2025-05-07
|
||||
|
||||
### Added
|
||||
|
||||
@@ -11,18 +11,17 @@ from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from openai import audio
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.google.llm import GoogleLLMService, LLMSearchResponseFrame
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
@@ -33,7 +32,7 @@ load_dotenv(override=True)
|
||||
|
||||
|
||||
# Function handlers for the LLM
|
||||
search_tool = {"google_search_retrieval": {}}
|
||||
search_tool = {"google_search": {}}
|
||||
tools = [search_tool]
|
||||
|
||||
system_instruction = """
|
||||
@@ -50,14 +49,22 @@ Start each interaction by asking the user about which place they would like to k
|
||||
"""
|
||||
|
||||
|
||||
class LLMSearchLoggerProcessor(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
class LLMSearchLoggerObserver(BaseObserver):
|
||||
async def on_push_frame(self, data: FramePushed):
|
||||
src = data.source
|
||||
dst = data.destination
|
||||
frame = data.frame
|
||||
timestamp = data.timestamp
|
||||
|
||||
if not isinstance(src, LLMService) and not isinstance(dst, LLMService):
|
||||
return
|
||||
|
||||
time_sec = timestamp / 1_000_000_000
|
||||
|
||||
arrow = "→"
|
||||
|
||||
if isinstance(frame, LLMSearchResponseFrame):
|
||||
print(f"LLMSearchLoggerProcessor: {frame}")
|
||||
|
||||
await self.push_frame(frame)
|
||||
logger.debug(f"🧠 {arrow} {dst} LLM SEARCH RESPONSE FRAME: {frame} at {time_sec:.2f}s")
|
||||
|
||||
|
||||
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
|
||||
@@ -84,7 +91,6 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
system_instruction=system_instruction,
|
||||
tools=tools,
|
||||
model="gemini-1.5-flash-002",
|
||||
)
|
||||
|
||||
context = OpenAILLMContext(
|
||||
@@ -97,22 +103,23 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
|
||||
)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
llm_search_logger = LLMSearchLoggerProcessor()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
llm_search_logger,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, params=PipelineParams(allow_interruptions=True))
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(allow_interruptions=True),
|
||||
observers=[LLMSearchLoggerObserver()],
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
|
||||
@@ -102,9 +102,9 @@ async def main():
|
||||
|
||||
llm = GoogleLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
model="gemini-1.5-flash-002",
|
||||
system_instruction=system_instruction,
|
||||
tools=tools,
|
||||
model="gemini-1.5-flash",
|
||||
)
|
||||
|
||||
context = OpenAILLMContext(
|
||||
@@ -153,7 +153,6 @@ async def main():
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
logger.debug("First participant joined: {}", participant["id"])
|
||||
await transport.capture_participant_transcription(participant["id"])
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
|
||||
@@ -187,7 +187,7 @@ async def main():
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
await transport.capture_participant_transcription(participant["id"])
|
||||
print(f"Participant joined: {participant}")
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
|
||||
@@ -215,6 +215,7 @@ async def main():
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
print(f"Participant joined: {participant}")
|
||||
await transport.capture_participant_transcription(participant["id"])
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
|
||||
@@ -30,7 +30,7 @@ from loguru import logger
|
||||
from pipecatcloud.agent import DailySessionArguments
|
||||
from word_list import generate_game_words
|
||||
|
||||
from pipecat.audio.resamplers.soxr_resampler import SOXRAudioResampler
|
||||
from pipecat.audio.utils import create_default_resampler
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
@@ -524,7 +524,7 @@ async def tts_audio_raw_frame_filter(frame: Frame):
|
||||
|
||||
|
||||
# Create a resampler instance once
|
||||
resampler = SOXRAudioResampler()
|
||||
resampler = create_default_resampler()
|
||||
|
||||
|
||||
async def tts_to_input_audio_transformer(frame: Frame):
|
||||
@@ -689,8 +689,6 @@ Important guidelines:
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
logger.info("First participant joined: {}", participant["id"])
|
||||
# Capture the participant's transcription
|
||||
await transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
# Start the game timer
|
||||
|
||||
@@ -54,7 +54,7 @@ fal = [ "fal-client~=0.5.9" ]
|
||||
fireworks = []
|
||||
fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
|
||||
gladia = [ "websockets~=13.1" ]
|
||||
google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4", "websockets~=13.1" ]
|
||||
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.14.0", "websockets~=13.1" ]
|
||||
grok = []
|
||||
groq = [ "groq~=0.23.0" ]
|
||||
gstreamer = [ "pygobject~=3.50.0" ]
|
||||
|
||||
@@ -52,10 +52,16 @@ from pipecat.services.openai.llm import (
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
try:
|
||||
import google.ai.generativelanguage as glm
|
||||
import google.generativeai as gai
|
||||
from google import genai
|
||||
from google.api_core.exceptions import DeadlineExceeded
|
||||
from google.generativeai.types import GenerationConfig
|
||||
from google.genai.types import (
|
||||
Blob,
|
||||
Content,
|
||||
FunctionCall,
|
||||
FunctionResponse,
|
||||
GenerateContentConfig,
|
||||
Part,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
|
||||
@@ -65,9 +71,7 @@ except ModuleNotFoundError as e:
|
||||
class GoogleUserContextAggregator(OpenAIUserContextAggregator):
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
self._context.add_message(
|
||||
glm.Content(role="user", parts=[glm.Part(text=self._aggregation)])
|
||||
)
|
||||
self._context.add_message(Content(role="user", parts=[Part(text=self._aggregation)]))
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the task gets cancelled we won't be able to clear things up.
|
||||
@@ -83,15 +87,15 @@ class GoogleUserContextAggregator(OpenAIUserContextAggregator):
|
||||
|
||||
class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
async def handle_aggregation(self, aggregation: str):
|
||||
self._context.add_message(glm.Content(role="model", parts=[glm.Part(text=aggregation)]))
|
||||
self._context.add_message(Content(role="model", parts=[Part(text=aggregation)]))
|
||||
|
||||
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
|
||||
self._context.add_message(
|
||||
glm.Content(
|
||||
Content(
|
||||
role="model",
|
||||
parts=[
|
||||
glm.Part(
|
||||
function_call=glm.FunctionCall(
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
id=frame.tool_call_id, name=frame.function_name, args=frame.arguments
|
||||
)
|
||||
)
|
||||
@@ -99,11 +103,11 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
)
|
||||
)
|
||||
self._context.add_message(
|
||||
glm.Content(
|
||||
Content(
|
||||
role="user",
|
||||
parts=[
|
||||
glm.Part(
|
||||
function_response=glm.FunctionResponse(
|
||||
Part(
|
||||
function_response=FunctionResponse(
|
||||
id=frame.tool_call_id,
|
||||
name=frame.function_name,
|
||||
response={"response": "IN_PROGRESS"},
|
||||
@@ -187,7 +191,7 @@ class GoogleLLMContext(OpenAILLMContext):
|
||||
# Convert each message individually
|
||||
converted_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, glm.Content):
|
||||
if isinstance(msg, Content):
|
||||
# Already in Gemini format
|
||||
converted_messages.append(msg)
|
||||
else:
|
||||
@@ -202,7 +206,7 @@ class GoogleLLMContext(OpenAILLMContext):
|
||||
def get_messages_for_logging(self):
|
||||
msgs = []
|
||||
for message in self.messages:
|
||||
obj = glm.Content.to_dict(message)
|
||||
obj = message.to_json_dict()
|
||||
try:
|
||||
if "parts" in obj:
|
||||
for part in obj["parts"]:
|
||||
@@ -221,10 +225,10 @@ class GoogleLLMContext(OpenAILLMContext):
|
||||
|
||||
parts = []
|
||||
if text:
|
||||
parts.append(glm.Part(text=text))
|
||||
parts.append(glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())))
|
||||
parts.append(Part(text=text))
|
||||
parts.append(Part(inline_data=Blob(mime_type="image/jpeg", data=buffer.getvalue())))
|
||||
|
||||
self.add_message(glm.Content(role="user", parts=parts))
|
||||
self.add_message(Content(role="user", parts=parts))
|
||||
|
||||
def add_audio_frames_message(
|
||||
self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
|
||||
@@ -239,10 +243,10 @@ class GoogleLLMContext(OpenAILLMContext):
|
||||
data = b"".join(frame.audio for frame in audio_frames)
|
||||
# NOTE(aleix): According to the docs only text or inline_data should be needed.
|
||||
# (see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference)
|
||||
parts.append(glm.Part(text=text))
|
||||
parts.append(Part(text=text))
|
||||
parts.append(
|
||||
glm.Part(
|
||||
inline_data=glm.Blob(
|
||||
Part(
|
||||
inline_data=Blob(
|
||||
mime_type="audio/wav",
|
||||
data=(
|
||||
bytes(
|
||||
@@ -252,7 +256,7 @@ class GoogleLLMContext(OpenAILLMContext):
|
||||
)
|
||||
),
|
||||
)
|
||||
self.add_message(glm.Content(role="user", parts=parts))
|
||||
self.add_message(Content(role="user", parts=parts))
|
||||
# message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))}
|
||||
# self.add_message(message)
|
||||
|
||||
@@ -271,7 +275,7 @@ class GoogleLLMContext(OpenAILLMContext):
|
||||
}
|
||||
|
||||
Returns:
|
||||
glm.Content object with:
|
||||
Content object with:
|
||||
- role: "user" or "model" (converted from "assistant")
|
||||
- parts: List[Part] containing text, inline_data, or function calls
|
||||
Returns None for system messages.
|
||||
@@ -288,8 +292,8 @@ class GoogleLLMContext(OpenAILLMContext):
|
||||
if message.get("tool_calls"):
|
||||
for tc in message["tool_calls"]:
|
||||
parts.append(
|
||||
glm.Part(
|
||||
function_call=glm.FunctionCall(
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name=tc["function"]["name"],
|
||||
args=json.loads(tc["function"]["arguments"]),
|
||||
)
|
||||
@@ -298,30 +302,30 @@ class GoogleLLMContext(OpenAILLMContext):
|
||||
elif role == "tool":
|
||||
role = "model"
|
||||
parts.append(
|
||||
glm.Part(
|
||||
function_response=glm.FunctionResponse(
|
||||
Part(
|
||||
function_response=FunctionResponse(
|
||||
name="tool_call_result", # seems to work to hard-code the same name every time
|
||||
response=json.loads(message["content"]),
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(content, str):
|
||||
parts.append(glm.Part(text=content))
|
||||
parts.append(Part(text=content))
|
||||
elif isinstance(content, list):
|
||||
for c in content:
|
||||
if c["type"] == "text":
|
||||
parts.append(glm.Part(text=c["text"]))
|
||||
parts.append(Part(text=c["text"]))
|
||||
elif c["type"] == "image_url":
|
||||
parts.append(
|
||||
glm.Part(
|
||||
inline_data=glm.Blob(
|
||||
Part(
|
||||
inline_data=Blob(
|
||||
mime_type="image/jpeg",
|
||||
data=base64.b64decode(c["image_url"]["url"].split(",")[1]),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
message = glm.Content(role=role, parts=parts)
|
||||
message = Content(role=role, parts=parts)
|
||||
return message
|
||||
|
||||
def to_standard_messages(self, obj) -> list:
|
||||
@@ -409,7 +413,7 @@ class GoogleLLMContext(OpenAILLMContext):
|
||||
|
||||
# Process each message, preserving Google-formatted messages and converting others
|
||||
for message in self._messages:
|
||||
if isinstance(message, glm.Content):
|
||||
if isinstance(message, Content):
|
||||
# Keep existing Google-formatted messages (e.g., function calls/responses)
|
||||
converted_messages.append(message)
|
||||
continue
|
||||
@@ -433,9 +437,7 @@ class GoogleLLMContext(OpenAILLMContext):
|
||||
|
||||
# Add system message back as a user message if we only have function messages
|
||||
if self.system_message and not has_regular_messages:
|
||||
self._messages.append(
|
||||
glm.Content(role="user", parts=[glm.Part(text=self.system_message)])
|
||||
)
|
||||
self._messages.append(Content(role="user", parts=[Part(text=self.system_message)]))
|
||||
|
||||
# Remove any empty messages
|
||||
self._messages = [m for m in self._messages if m.parts]
|
||||
@@ -463,7 +465,7 @@ class GoogleLLMService(LLMService):
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str = "gemini-2.0-flash-001",
|
||||
model: str = "gemini-2.0-flash",
|
||||
params: InputParams = InputParams(),
|
||||
system_instruction: Optional[str] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
@@ -471,10 +473,10 @@ class GoogleLLMService(LLMService):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
gai.configure(api_key=api_key)
|
||||
self.set_model_name(model)
|
||||
self._api_key = api_key
|
||||
self._system_instruction = system_instruction
|
||||
self._create_client()
|
||||
self._create_client(api_key)
|
||||
self._settings = {
|
||||
"max_tokens": params.max_tokens,
|
||||
"temperature": params.temperature,
|
||||
@@ -488,10 +490,8 @@ class GoogleLLMService(LLMService):
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def _create_client(self):
|
||||
self._client = gai.GenerativeModel(
|
||||
self._model_name, system_instruction=self._system_instruction
|
||||
)
|
||||
def _create_client(self, api_key: str):
|
||||
self._client = genai.Client(api_key=api_key)
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
@@ -513,23 +513,7 @@ class GoogleLLMService(LLMService):
|
||||
if context.system_message and self._system_instruction != context.system_message:
|
||||
logger.debug(f"System instruction changed: {context.system_message}")
|
||||
self._system_instruction = context.system_message
|
||||
self._create_client()
|
||||
|
||||
# Filter out None values and create GenerationConfig
|
||||
generation_params = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"temperature": self._settings["temperature"],
|
||||
"top_p": self._settings["top_p"],
|
||||
"top_k": self._settings["top_k"],
|
||||
"max_output_tokens": self._settings["max_tokens"],
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
generation_config = GenerationConfig(**generation_params) if generation_params else None
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
tools = []
|
||||
if context.tools:
|
||||
tools = context.tools
|
||||
@@ -538,112 +522,104 @@ class GoogleLLMService(LLMService):
|
||||
tool_config = None
|
||||
if self._tool_config:
|
||||
tool_config = self._tool_config
|
||||
response = await self._client.generate_content_async(
|
||||
|
||||
# Filter out None values and create GenerateContentConfig
|
||||
generation_params = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"system_instruction": self._system_instruction,
|
||||
"temperature": self._settings["temperature"],
|
||||
"top_p": self._settings["top_p"],
|
||||
"top_k": self._settings["top_k"],
|
||||
"max_output_tokens": self._settings["max_tokens"],
|
||||
"tools": tools,
|
||||
"tool_config": tool_config,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
generation_config = (
|
||||
GenerateContentConfig(**generation_params) if generation_params else None
|
||||
)
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
response = await self._client.aio.models.generate_content_stream(
|
||||
model=self._model_name,
|
||||
contents=messages,
|
||||
tools=tools,
|
||||
stream=True,
|
||||
generation_config=generation_config,
|
||||
tool_config=tool_config,
|
||||
config=generation_config,
|
||||
)
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if response.usage_metadata:
|
||||
# Use only the prompt token count from the response object
|
||||
prompt_tokens = response.usage_metadata.prompt_token_count
|
||||
total_tokens = prompt_tokens
|
||||
|
||||
async for chunk in response:
|
||||
if chunk.usage_metadata:
|
||||
# Use only the completion_tokens from the chunks. Prompt tokens are already counted and
|
||||
# are repeated here.
|
||||
completion_tokens += chunk.usage_metadata.candidates_token_count
|
||||
total_tokens += chunk.usage_metadata.candidates_token_count
|
||||
try:
|
||||
for c in chunk.parts:
|
||||
if c.text:
|
||||
search_result += c.text
|
||||
await self.push_frame(LLMTextFrame(c.text))
|
||||
elif c.function_call:
|
||||
logger.debug(f"Function call: {c.function_call}")
|
||||
args = type(c.function_call).to_dict(c.function_call).get("args", {})
|
||||
await self.call_function(
|
||||
context=context,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
function_name=c.function_call.name,
|
||||
arguments=args,
|
||||
)
|
||||
# Handle grounding metadata
|
||||
# It seems only the last chunk that we receive may contain this information
|
||||
# If the response doesn't include groundingMetadata, this means the response wasn't grounded.
|
||||
if chunk.candidates:
|
||||
for candidate in chunk.candidates:
|
||||
# logger.debug(f"candidate received: {candidate}")
|
||||
# Extract grounding metadata
|
||||
grounding_metadata = (
|
||||
{
|
||||
"rendered_content": getattr(
|
||||
getattr(candidate, "grounding_metadata", None),
|
||||
"search_entry_point",
|
||||
None,
|
||||
).rendered_content
|
||||
if hasattr(
|
||||
getattr(candidate, "grounding_metadata", None),
|
||||
"search_entry_point",
|
||||
)
|
||||
else None,
|
||||
"origins": [
|
||||
{
|
||||
"site_uri": getattr(grounding_chunk.web, "uri", None),
|
||||
"site_title": getattr(
|
||||
grounding_chunk.web, "title", None
|
||||
),
|
||||
"results": [
|
||||
{
|
||||
"text": getattr(
|
||||
grounding_support.segment, "text", ""
|
||||
),
|
||||
"confidence": getattr(
|
||||
grounding_support, "confidence_scores", None
|
||||
),
|
||||
}
|
||||
for grounding_support in getattr(
|
||||
getattr(candidate, "grounding_metadata", None),
|
||||
"grounding_supports",
|
||||
[],
|
||||
)
|
||||
if index
|
||||
in getattr(
|
||||
grounding_support, "grounding_chunk_indices", []
|
||||
)
|
||||
],
|
||||
}
|
||||
for index, grounding_chunk in enumerate(
|
||||
getattr(
|
||||
getattr(candidate, "grounding_metadata", None),
|
||||
"grounding_chunks",
|
||||
[],
|
||||
)
|
||||
)
|
||||
],
|
||||
}
|
||||
if getattr(candidate, "grounding_metadata", None)
|
||||
else None
|
||||
)
|
||||
except Exception as e:
|
||||
# Google LLMs seem to flag safety issues a lot!
|
||||
if chunk.candidates[0].finish_reason == 3:
|
||||
logger.debug(
|
||||
f"LLM refused to generate content for safety reasons - {messages}."
|
||||
)
|
||||
else:
|
||||
logger.exception(f"{self} error: {e}")
|
||||
prompt_tokens += chunk.usage_metadata.prompt_token_count or 0
|
||||
completion_tokens += chunk.usage_metadata.candidates_token_count or 0
|
||||
total_tokens += chunk.usage_metadata.total_token_count or 0
|
||||
|
||||
if not chunk.candidates:
|
||||
continue
|
||||
|
||||
for candidate in chunk.candidates:
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if not part.thought and part.text:
|
||||
search_result += part.text
|
||||
await self.push_frame(LLMTextFrame(part.text))
|
||||
elif part.function_call:
|
||||
function_call = part.function_call
|
||||
id = function_call.id or str(uuid.uuid4())
|
||||
logger.debug(f"Function call: {function_call.name}:{id}")
|
||||
await self.call_function(
|
||||
context=context,
|
||||
tool_call_id=id,
|
||||
function_name=function_call.name,
|
||||
arguments=function_call.args or {},
|
||||
)
|
||||
|
||||
if (
|
||||
candidate.grounding_metadata
|
||||
and candidate.grounding_metadata.grounding_chunks
|
||||
):
|
||||
m = candidate.grounding_metadata
|
||||
rendered_content = (
|
||||
m.search_entry_point.rendered_content if m.search_entry_point else None
|
||||
)
|
||||
origins = [
|
||||
{
|
||||
"site_uri": grounding_chunk.web.uri
|
||||
if grounding_chunk.web
|
||||
else None,
|
||||
"site_title": grounding_chunk.web.title
|
||||
if grounding_chunk.web
|
||||
else None,
|
||||
"results": [
|
||||
{
|
||||
"text": grounding_support.segment.text
|
||||
if grounding_support.segment
|
||||
else "",
|
||||
"confidence": grounding_support.confidence_scores,
|
||||
}
|
||||
for grounding_support in (
|
||||
m.grounding_supports if m.grounding_supports else []
|
||||
)
|
||||
if grounding_support.grounding_chunk_indices
|
||||
and index in grounding_support.grounding_chunk_indices
|
||||
],
|
||||
}
|
||||
for index, grounding_chunk in enumerate(
|
||||
m.grounding_chunks if m.grounding_chunks else []
|
||||
)
|
||||
]
|
||||
grounding_metadata = {
|
||||
"rendered_content": rendered_content,
|
||||
"origins": origins,
|
||||
}
|
||||
except DeadlineExceeded:
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
if grounding_metadata is not None and isinstance(grounding_metadata, dict):
|
||||
if grounding_metadata and isinstance(grounding_metadata, dict):
|
||||
llm_search_frame = LLMSearchResponseFrame(
|
||||
search_result=search_result,
|
||||
origins=grounding_metadata["origins"],
|
||||
|
||||
@@ -8,8 +8,6 @@ import json
|
||||
import unittest
|
||||
from typing import Any
|
||||
|
||||
import google.ai.generativelanguage as glm
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
EmulateUserStartedSpeakingFrame,
|
||||
EmulateUserStoppedSpeakingFrame,
|
||||
@@ -758,13 +756,13 @@ class TestGoogleUserContextAggregator(
|
||||
AGGREGATOR_CLASS = GoogleUserContextAggregator
|
||||
|
||||
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
|
||||
obj = glm.Content.to_dict(context.messages[index])
|
||||
obj = context.messages[index].to_json_dict()
|
||||
assert obj["parts"][0]["text"] == content
|
||||
|
||||
def check_message_multi_content(
|
||||
self, context: OpenAILLMContext, content_index: int, index: int, content: str
|
||||
):
|
||||
obj = glm.Content.to_dict(context.messages[index])
|
||||
obj = context.messages[index].to_json_dict()
|
||||
assert obj["parts"][0]["text"] == content
|
||||
|
||||
|
||||
@@ -776,17 +774,17 @@ class TestGoogleAssistantContextAggregator(
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
|
||||
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
|
||||
obj = glm.Content.to_dict(context.messages[index])
|
||||
obj = context.messages[index].to_json_dict()
|
||||
assert obj["parts"][0]["text"] == content
|
||||
|
||||
def check_message_multi_content(
|
||||
self, context: OpenAILLMContext, content_index: int, index: int, content: str
|
||||
):
|
||||
obj = glm.Content.to_dict(context.messages[index])
|
||||
obj = context.messages[index].to_json_dict()
|
||||
assert obj["parts"][0]["text"] == content
|
||||
|
||||
def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
|
||||
obj = glm.Content.to_dict(context.messages[index])
|
||||
obj = context.messages[index].to_json_dict()
|
||||
assert obj["parts"][0]["function_response"]["response"]["value"] == json.dumps(content)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user