Update GeminiMultimodalLiveLLMService to use the google-genai library, which is the new recommended approach for interfacing with Gemini Live.
This commit is contained in:
@@ -6,525 +6,22 @@
|
||||
|
||||
"""Event models and utilities for Google Gemini Multimodal Live API."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import ImageRawFrame
|
||||
|
||||
#
|
||||
# Client events
|
||||
#
|
||||
|
||||
|
||||
class MediaChunk(BaseModel):
    """Represents a chunk of media data for transmission.

    Field names are camelCase, presumably to match the Gemini Live wire
    format (TODO confirm against the API reference).

    Parameters:
        mimeType: MIME type of the media content (e.g. "image/jpeg").
        data: Base64-encoded media data, stored as text.
    """

    mimeType: str
    data: str
|
||||
|
||||
|
||||
class ContentPart(BaseModel):
    """Represents a part of content that can contain text, inline media, or a file reference.

    Parameters:
        text: Text content. Defaults to None.
        inlineData: Inline media data. Defaults to None.
        fileData: Reference to a file by URI. Defaults to None.
    """

    text: Optional[str] = Field(default=None, validate_default=False)
    inlineData: Optional[MediaChunk] = Field(default=None, validate_default=False)
    # "FileData" is a forward reference: the class is declared later in this
    # module, so the model has to be rebuilt once it exists.
    fileData: Optional["FileData"] = Field(default=None, validate_default=False)
|
||||
|
||||
|
||||
class FileData(BaseModel):
    """Represents a file reference in the Gemini File API.

    Parameters:
        mimeType: MIME type of the referenced file.
        fileUri: URI identifying the referenced file.
    """

    mimeType: str
    fileUri: str


# ContentPart annotates its fileData field with the string "FileData" because
# this class is declared after it; rebuilding resolves that annotation now
# that the name exists.
ContentPart.model_rebuild()  # Rebuild model to resolve forward reference
|
||||
|
||||
|
||||
class Turn(BaseModel):
    """Represents a conversational turn in the dialogue.

    Parameters:
        role: The role of the speaker, either "user" or "model". Defaults to "user".
        parts: List of content parts that make up the turn. Required.
    """

    # Literal restricts validation to exactly these two role strings.
    role: Literal["user", "model"] = "user"
    parts: List[ContentPart]
|
||||
|
||||
|
||||
class StartSensitivity(str, Enum):
    """Determines how start of speech is detected.

    Members are also ``str`` instances, so they compare equal to their
    string values.
    """

    UNSPECIFIED = "START_SENSITIVITY_UNSPECIFIED"  # Default is HIGH
    HIGH = "START_SENSITIVITY_HIGH"  # Detect start of speech more often
    LOW = "START_SENSITIVITY_LOW"  # Detect start of speech less often
|
||||
|
||||
|
||||
class EndSensitivity(str, Enum):
    """Determines how end of speech is detected.

    Members are also ``str`` instances, so they compare equal to their
    string values.
    """

    UNSPECIFIED = "END_SENSITIVITY_UNSPECIFIED"  # Default is HIGH
    HIGH = "END_SENSITIVITY_HIGH"  # End speech more often
    LOW = "END_SENSITIVITY_LOW"  # End speech less often
|
||||
|
||||
|
||||
class AutomaticActivityDetection(BaseModel):
    """Configures automatic detection of voice activity.

    Every field is optional and defaults to None.

    Parameters:
        disabled: Whether automatic activity detection is disabled. Defaults to None.
        start_of_speech_sensitivity: Sensitivity for detecting speech start. Defaults to None.
        prefix_padding_ms: Padding before speech start in milliseconds. Defaults to None.
        end_of_speech_sensitivity: Sensitivity for detecting speech end. Defaults to None.
        silence_duration_ms: Duration of silence to detect speech end, in milliseconds
            (per the field name). Defaults to None.
    """

    disabled: Optional[bool] = None
    start_of_speech_sensitivity: Optional[StartSensitivity] = None
    prefix_padding_ms: Optional[int] = None
    end_of_speech_sensitivity: Optional[EndSensitivity] = None
    silence_duration_ms: Optional[int] = None
|
||||
|
||||
|
||||
class RealtimeInputConfig(BaseModel):
    """Configures the realtime input behavior.

    Currently a thin wrapper around the voice activity detection settings.

    Parameters:
        automatic_activity_detection: Voice activity detection configuration. Defaults to None.
    """

    automatic_activity_detection: Optional[AutomaticActivityDetection] = None
|
||||
|
||||
|
||||
class RealtimeInput(BaseModel):
    """Contains realtime input media chunks and text.

    Both fields are optional; the message classes in this module populate
    one or the other.

    Parameters:
        mediaChunks: List of media chunks for realtime processing. Defaults to None.
        text: Text for realtime processing. Defaults to None.
    """

    mediaChunks: Optional[List[MediaChunk]] = None
    text: Optional[str] = None
|
||||
|
||||
|
||||
class ClientContent(BaseModel):
    """Content sent from client to the Gemini Live API.

    Parameters:
        turns: List of conversation turns. Defaults to None.
        turnComplete: Whether the client's turn is complete. Defaults to False.
    """

    turns: Optional[List[Turn]] = None
    turnComplete: bool = False
|
||||
|
||||
|
||||
class AudioInputMessage(BaseModel):
    """Realtime-input message that carries audio.

    Parameters:
        realtimeInput: Realtime input containing audio chunks.
    """

    realtimeInput: RealtimeInput

    @classmethod
    def from_raw_audio(cls, raw_audio: bytes, sample_rate: int) -> "AudioInputMessage":
        """Build an audio input message from raw audio bytes.

        Args:
            raw_audio: Raw audio bytes.
            sample_rate: Audio sample rate in Hz; embedded in the chunk's MIME type.

        Returns:
            AudioInputMessage holding a single base64-encoded media chunk.
        """
        encoded = base64.b64encode(raw_audio).decode("utf-8")
        chunk = MediaChunk(mimeType=f"audio/pcm;rate={sample_rate}", data=encoded)
        return cls(realtimeInput=RealtimeInput(mediaChunks=[chunk]))
|
||||
|
||||
|
||||
class VideoInputMessage(BaseModel):
    """Message containing video/image input data.

    Parameters:
        realtimeInput: Realtime input containing video/image chunks.
    """

    realtimeInput: RealtimeInput

    @classmethod
    def from_image_frame(cls, frame: ImageRawFrame) -> "VideoInputMessage":
        """Create a video input message from an image frame.

        The frame is re-encoded as JPEG and base64-encoded into a single
        media chunk.

        Args:
            frame: Image frame to encode. ``frame.format`` is used as the
                Pillow mode, ``frame.size`` as the dimensions and
                ``frame.image`` as the raw pixel bytes.

        Returns:
            VideoInputMessage instance with encoded image data.
        """
        image = Image.frombytes(frame.format, frame.size, frame.image)
        # JPEG cannot store alpha or palette data; Pillow raises OSError when
        # asked to save such modes (e.g. "RGBA"), so convert them first.
        if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
            image = image.convert("RGB")

        buffer = io.BytesIO()
        image.save(buffer, format="JPEG")
        data = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return cls(
            realtimeInput=RealtimeInput(mediaChunks=[MediaChunk(mimeType="image/jpeg", data=data)])
        )
|
||||
|
||||
|
||||
class TextInputMessage(BaseModel):
    """Realtime-input message that carries text.

    Parameters:
        realtimeInput: Realtime input containing the text.
    """

    realtimeInput: RealtimeInput

    @classmethod
    def from_text(cls, text: str) -> "TextInputMessage":
        """Build a text input message from a string.

        Args:
            text: The text to send.

        Returns:
            A TextInputMessage whose realtime input carries only ``text``.
        """
        realtime_input = RealtimeInput(text=text)
        return cls(realtimeInput=realtime_input)
|
||||
|
||||
|
||||
class ClientContentMessage(BaseModel):
    """Message containing client content for the API.

    Wraps the content under a single ``clientContent`` key.

    Parameters:
        clientContent: The client content to send. Required.
    """

    clientContent: ClientContent
|
||||
|
||||
|
||||
class SystemInstruction(BaseModel):
    """System instruction for the model.

    Parameters:
        parts: List of content parts that make up the system instruction. Required.
    """

    parts: List[ContentPart]
|
||||
|
||||
|
||||
class AudioTranscriptionConfig(BaseModel):
    """Configuration for audio transcription.

    Declares no fields.
    """
|
||||
|
||||
|
||||
class Setup(BaseModel):
    """Setup configuration for the Gemini Live session.

    Only ``model`` is required; every other field defaults to None.

    Parameters:
        model: Model identifier to use.
        system_instruction: System instruction for the model. Defaults to None.
        tools: List of available tools/functions, as plain dicts. Defaults to None.
        generation_config: Generation configuration parameters, as a plain dict.
            Defaults to None.
        input_audio_transcription: Input audio transcription config. Defaults to None.
        output_audio_transcription: Output audio transcription config. Defaults to None.
        realtime_input_config: Realtime input configuration. Defaults to None.
    """

    model: str
    system_instruction: Optional[SystemInstruction] = None
    # tools and generation_config are untyped dicts: their contents are not
    # validated by this model.
    tools: Optional[List[dict]] = None
    generation_config: Optional[dict] = None
    input_audio_transcription: Optional[AudioTranscriptionConfig] = None
    output_audio_transcription: Optional[AudioTranscriptionConfig] = None
    realtime_input_config: Optional[RealtimeInputConfig] = None
|
||||
|
||||
|
||||
class Config(BaseModel):
    """Configuration message for session setup.

    Wraps the setup under a single ``setup`` key.

    Parameters:
        setup: Setup configuration for the session. Required.
    """

    setup: Setup
|
||||
|
||||
|
||||
#
|
||||
# Grounding metadata models
|
||||
#
|
||||
|
||||
|
||||
class SearchEntryPoint(BaseModel):
    """Represents the search entry point with rendered content for search suggestions.

    Parameters:
        renderedContent: Rendered content string. Defaults to None.
    """

    renderedContent: Optional[str] = None
|
||||
|
||||
|
||||
class WebSource(BaseModel):
    """Represents a web source from grounding chunks.

    Parameters:
        uri: URI of the source. Defaults to None.
        title: Title of the source. Defaults to None.
    """

    uri: Optional[str] = None
    title: Optional[str] = None
|
||||
|
||||
|
||||
class GroundingChunk(BaseModel):
    """Represents a grounding chunk containing web source information.

    Parameters:
        web: The web source for this chunk. Defaults to None.
    """

    web: Optional[WebSource] = None
|
||||
|
||||
|
||||
class GroundingSegment(BaseModel):
    """Represents a segment of text that is grounded.

    Parameters:
        startIndex: Start index of the segment. Defaults to None.
        endIndex: End index of the segment. Defaults to None.
        text: Text of the segment. Defaults to None.
    """

    # NOTE(review): whether the indices are character or byte offsets, and
    # whether endIndex is inclusive, is not visible here - confirm against
    # the API reference.
    startIndex: Optional[int] = None
    endIndex: Optional[int] = None
    text: Optional[str] = None
|
||||
|
||||
|
||||
class GroundingSupport(BaseModel):
    """Represents support information for grounded text segments.

    Parameters:
        segment: The grounded text segment. Defaults to None.
        groundingChunkIndices: Integer indices, presumably into the grounding
            chunk list of the enclosing metadata (TODO confirm). Defaults to None.
        confidenceScores: Confidence scores as floats. Defaults to None.
    """

    segment: Optional[GroundingSegment] = None
    groundingChunkIndices: Optional[List[int]] = None
    confidenceScores: Optional[List[float]] = None
|
||||
|
||||
|
||||
class GroundingMetadata(BaseModel):
    """Represents grounding metadata from Google Search.

    Parameters:
        searchEntryPoint: Search entry point with rendered content. Defaults to None.
        groundingChunks: Grounding chunks (web sources). Defaults to None.
        groundingSupports: Support entries for grounded segments. Defaults to None.
        webSearchQueries: Web search query strings. Defaults to None.
    """

    searchEntryPoint: Optional[SearchEntryPoint] = None
    groundingChunks: Optional[List[GroundingChunk]] = None
    groundingSupports: Optional[List[GroundingSupport]] = None
    webSearchQueries: Optional[List[str]] = None
|
||||
|
||||
|
||||
#
|
||||
# Server events
|
||||
#
|
||||
|
||||
|
||||
class SetupComplete(BaseModel):
    """Indicates that session setup is complete.

    Declares no fields.
    """
|
||||
|
||||
|
||||
class InlineData(BaseModel):
    """Inline data embedded in server responses.

    Parameters:
        mimeType: MIME type of the data. Required.
        data: Base64-encoded data content, stored as text. Required.
    """

    mimeType: str
    data: str
|
||||
|
||||
|
||||
class Part(BaseModel):
    """Part of a server response containing data or text.

    Parameters:
        inlineData: Inline binary data. Defaults to None.
        text: Text content. Defaults to None.
    """

    inlineData: Optional[InlineData] = None
    text: Optional[str] = None
|
||||
|
||||
|
||||
class ModelTurn(BaseModel):
    """Represents a turn from the model in the conversation.

    Parameters:
        parts: List of content parts in the model's response. Required.
    """

    parts: List[Part]
|
||||
|
||||
|
||||
class ServerContentInterrupted(BaseModel):
    """Indicates server content was interrupted.

    Parameters:
        interrupted: Whether the content was interrupted. Required.
    """

    interrupted: bool
|
||||
|
||||
|
||||
class ServerContentTurnComplete(BaseModel):
    """Indicates the server's turn is complete.

    Parameters:
        turnComplete: Whether the turn is complete. Required.
    """

    turnComplete: bool
|
||||
|
||||
|
||||
class BidiGenerateContentTranscription(BaseModel):
    """Transcription data from bidirectional content generation.

    Parameters:
        text: The transcribed text content. Required.
    """

    text: str
|
||||
|
||||
|
||||
class ServerContent(BaseModel):
    """Content sent from server to client.

    Every field is optional and defaults to None.

    Parameters:
        modelTurn: Model's conversational turn. Defaults to None.
        interrupted: Whether content was interrupted. Defaults to None.
        turnComplete: Whether the turn is complete. Defaults to None.
        inputTranscription: Transcription of input audio. Defaults to None.
        outputTranscription: Transcription of output audio. Defaults to None.
        groundingMetadata: Grounding metadata from Google Search. Defaults to None.
    """

    modelTurn: Optional[ModelTurn] = None
    interrupted: Optional[bool] = None
    turnComplete: Optional[bool] = None
    inputTranscription: Optional[BidiGenerateContentTranscription] = None
    outputTranscription: Optional[BidiGenerateContentTranscription] = None
    groundingMetadata: Optional[GroundingMetadata] = None
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
    """Represents a function call from the model.

    All three fields are required.

    Parameters:
        id: Unique identifier for the function call.
        name: Name of the function to call.
        args: Arguments to pass to the function, as a dict.
    """

    id: str
    name: str
    args: dict
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
    """Contains one or more function calls.

    Parameters:
        functionCalls: List of function calls to execute. Required.
    """

    functionCalls: List[FunctionCall]
|
||||
|
||||
|
||||
class Modality(str, Enum):
    """Modality types in token counts.

    Members are also ``str`` instances, so they compare equal to their
    string values.
    """

    UNSPECIFIED = "MODALITY_UNSPECIFIED"
    TEXT = "TEXT"
    IMAGE = "IMAGE"
    AUDIO = "AUDIO"
    VIDEO = "VIDEO"
|
||||
|
||||
|
||||
class ModalityTokenCount(BaseModel):
    """Token count for a specific modality.

    Parameters:
        modality: The modality type. Required.
        tokenCount: Number of tokens for this modality. Required.
    """

    modality: Modality
    tokenCount: int
|
||||
|
||||
|
||||
class UsageMetadata(BaseModel):
    """Usage metadata about the API response.

    Every field is optional and defaults to None. The ``*Count`` fields are
    integers; the ``*Details`` fields are per-modality breakdowns.

    Parameters:
        promptTokenCount: Number of tokens in the prompt. Defaults to None.
        cachedContentTokenCount: Number of cached content tokens. Defaults to None.
        responseTokenCount: Number of tokens in the response. Defaults to None.
        toolUsePromptTokenCount: Number of tokens for tool use prompts. Defaults to None.
        thoughtsTokenCount: Number of tokens for model thoughts. Defaults to None.
        totalTokenCount: Total number of tokens used. Defaults to None.
        promptTokensDetails: Detailed breakdown of prompt tokens by modality. Defaults to None.
        cacheTokensDetails: Detailed breakdown of cache tokens by modality. Defaults to None.
        responseTokensDetails: Detailed breakdown of response tokens by modality. Defaults to None.
        toolUsePromptTokensDetails: Detailed breakdown of tool use tokens by modality.
            Defaults to None.
    """

    promptTokenCount: Optional[int] = None
    cachedContentTokenCount: Optional[int] = None
    responseTokenCount: Optional[int] = None
    toolUsePromptTokenCount: Optional[int] = None
    thoughtsTokenCount: Optional[int] = None
    totalTokenCount: Optional[int] = None
    promptTokensDetails: Optional[List[ModalityTokenCount]] = None
    cacheTokensDetails: Optional[List[ModalityTokenCount]] = None
    responseTokensDetails: Optional[List[ModalityTokenCount]] = None
    toolUsePromptTokensDetails: Optional[List[ModalityTokenCount]] = None
|
||||
|
||||
|
||||
class ServerEvent(BaseModel):
    """Server event received from the Gemini Live API.

    Every field is optional and defaults to None, so consumers must check
    which ones are populated.

    Parameters:
        setupComplete: Setup completion notification. Defaults to None.
        serverContent: Content from the server. Defaults to None.
        toolCall: Tool/function call request. Defaults to None.
        usageMetadata: Token usage metadata. Defaults to None.
    """

    setupComplete: Optional[SetupComplete] = None
    serverContent: Optional[ServerContent] = None
    toolCall: Optional[ToolCall] = None
    usageMetadata: Optional[UsageMetadata] = None
|
||||
|
||||
|
||||
def parse_server_event(str):
    """Parse a server event from a JSON string.

    Parsing is best-effort: any failure while decoding the JSON or
    validating it against ``ServerEvent`` is logged and reported as None
    rather than raised.

    Args:
        str: JSON string containing the server event. The name shadows the
            ``str`` builtin inside this function; it is kept unchanged so
            callers passing it by keyword keep working.

    Returns:
        ServerEvent instance if parsing succeeds, None otherwise.
    """
    # Imported locally so this block does not depend on a module-level
    # loguru import being present.
    from loguru import logger

    try:
        return ServerEvent.model_validate(json.loads(str))
    except Exception as e:
        # Report through the logger instead of print() so the failure goes
        # to the application's log sinks rather than raw stdout.
        logger.error(f"Error parsing server event: {e}")
        return None
|
||||
|
||||
|
||||
class ContextWindowCompressionConfig(BaseModel):
    """Configuration for context window compression.

    Parameters:
        sliding_window: Whether to use sliding window compression. Defaults to True.
        trigger_tokens: Token count threshold to trigger compression. Defaults to None.
    """

    sliding_window: Optional[bool] = Field(default=True)
    trigger_tokens: Optional[int] = Field(default=None)
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from google.genai.types import (
|
||||
EndSensitivity as _EndSensitivity,
|
||||
)
|
||||
from google.genai.types import (
|
||||
StartSensitivity as _StartSensitivity,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
# These aliases are just here for backward compatibility, since we used to
|
||||
# define public-facing StartSensitivity and EndSensitivity enums in this
|
||||
# module.
|
||||
StartSensitivity = _StartSensitivity
|
||||
EndSensitivity = _EndSensitivity
|
||||
|
||||
@@ -12,6 +12,7 @@ voice transcription, streaming responses, and tool usage.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
@@ -19,6 +20,7 @@ from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
@@ -28,7 +30,6 @@ from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputImageRawFrame,
|
||||
@@ -72,11 +73,33 @@ from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_gemini_live, traced_stt
|
||||
|
||||
from . import events
|
||||
from .file_api import GeminiFileAPI
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from google.genai import Client
|
||||
from google.genai.live import AsyncSession
|
||||
from google.genai.types import (
|
||||
AudioTranscriptionConfig,
|
||||
AutomaticActivityDetection,
|
||||
Blob,
|
||||
Content,
|
||||
ContextWindowCompressionConfig,
|
||||
EndSensitivity,
|
||||
FileData,
|
||||
FunctionResponse,
|
||||
GenerationConfig,
|
||||
GroundingMetadata,
|
||||
HttpOptions,
|
||||
LiveConnectConfig,
|
||||
LiveServerMessage,
|
||||
Modality,
|
||||
Part,
|
||||
RealtimeInputConfig,
|
||||
SlidingWindow,
|
||||
SpeechConfig,
|
||||
StartSensitivity,
|
||||
VoiceConfig,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
|
||||
@@ -246,13 +269,13 @@ class GeminiMultimodalLiveContext(OpenAILLMContext):
|
||||
self.messages.append(message)
|
||||
logger.info(f"Added file reference to context: {file_uri}")
|
||||
|
||||
def get_messages_for_initializing_history(self):
|
||||
def get_messages_for_initializing_history(self) -> List[Content]:
|
||||
"""Get messages formatted for Gemini history initialization.
|
||||
|
||||
Returns:
|
||||
List of messages in Gemini format for conversation history.
|
||||
"""
|
||||
messages = []
|
||||
messages: List[Content] = []
|
||||
for item in self.messages:
|
||||
role = item.get("role")
|
||||
|
||||
@@ -263,29 +286,28 @@ class GeminiMultimodalLiveContext(OpenAILLMContext):
|
||||
role = "model"
|
||||
|
||||
content = item.get("content")
|
||||
parts = []
|
||||
parts: List[Part] = []
|
||||
if isinstance(content, str):
|
||||
parts = [{"text": content}]
|
||||
parts = [Part(text=content)]
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if part.get("type") == "text":
|
||||
parts.append({"text": part.get("text")})
|
||||
parts.append(Part(text=part.get("text")))
|
||||
elif part.get("type") == "file_data":
|
||||
file_data = part.get("file_data", {})
|
||||
|
||||
parts.append(
|
||||
{
|
||||
"fileData": {
|
||||
"mimeType": file_data.get("mime_type"),
|
||||
"fileUri": file_data.get("file_uri"),
|
||||
}
|
||||
}
|
||||
Part(
|
||||
file_data=FileData(
|
||||
mime_type=file_data.get("mime_type"),
|
||||
file_uri=file_data.get("file_uri"),
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Unsupported content type: {str(part)[:80]}")
|
||||
else:
|
||||
logger.warning(f"Unsupported content type: {str(content)[:80]}")
|
||||
messages.append({"role": role, "parts": parts})
|
||||
messages.append(Content(role=role, parts=parts))
|
||||
return messages
|
||||
|
||||
|
||||
@@ -411,8 +433,8 @@ class GeminiVADParams(BaseModel):
|
||||
"""
|
||||
|
||||
disabled: Optional[bool] = Field(default=None)
|
||||
start_sensitivity: Optional[events.StartSensitivity] = Field(default=None)
|
||||
end_sensitivity: Optional[events.EndSensitivity] = Field(default=None)
|
||||
start_sensitivity: Optional[StartSensitivity] = Field(default=None)
|
||||
end_sensitivity: Optional[EndSensitivity] = Field(default=None)
|
||||
prefix_padding_ms: Optional[int] = Field(default=None)
|
||||
silence_duration_ms: Optional[int] = Field(default=None)
|
||||
|
||||
@@ -441,7 +463,7 @@ class InputParams(BaseModel):
|
||||
temperature: Sampling temperature (0.0-2.0). Defaults to None.
|
||||
top_k: Top-k sampling parameter. Must be >= 0. Defaults to None.
|
||||
top_p: Top-p sampling parameter (0.0-1.0). Defaults to None.
|
||||
modalities: Response modalities. Defaults to AUDIO.
|
||||
modalities: Response modalities. Defaults to "AUDIO".
|
||||
language: Language for generation. Defaults to EN_US.
|
||||
media_resolution: Media resolution setting. Defaults to UNSPECIFIED.
|
||||
vad: Voice activity detection parameters. Defaults to None.
|
||||
@@ -455,8 +477,8 @@ class InputParams(BaseModel):
|
||||
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
|
||||
top_k: Optional[int] = Field(default=None, ge=0)
|
||||
top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
||||
modalities: Optional[GeminiMultimodalModalities] = Field(
|
||||
default=GeminiMultimodalModalities.AUDIO
|
||||
modalities: Optional[Union[List[GeminiMultimodalModalities], GeminiMultimodalModalities]] = (
|
||||
Field(default=GeminiMultimodalModalities.AUDIO)
|
||||
)
|
||||
language: Optional[Language] = Field(default=Language.EN_US)
|
||||
media_resolution: Optional[GeminiMediaResolution] = Field(
|
||||
@@ -492,6 +514,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
params: Optional[InputParams] = None,
|
||||
inference_on_context_initialization: bool = True,
|
||||
file_api_base_url: str = "https://generativelanguage.googleapis.com/v1beta/files",
|
||||
http_options: Optional[HttpOptions] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Gemini Multimodal Live LLM service.
|
||||
@@ -509,6 +532,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
inference_on_context_initialization: Whether to generate a response when context
|
||||
is first set. Defaults to True.
|
||||
file_api_base_url: Base URL for the Gemini File API. Defaults to the official endpoint.
|
||||
http_options: HTTP options for the client.
|
||||
**kwargs: Additional arguments passed to parent LLMService.
|
||||
"""
|
||||
super().__init__(base_url=base_url, **kwargs)
|
||||
@@ -516,7 +540,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
params = params or InputParams()
|
||||
|
||||
self._last_sent_time = 0
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url
|
||||
self.set_model_name(model)
|
||||
self._voice_id = voice_id
|
||||
@@ -530,12 +553,12 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
self._audio_input_paused = start_audio_paused
|
||||
self._video_input_paused = start_video_paused
|
||||
self._context = None
|
||||
self._websocket = None
|
||||
self._receive_task = None
|
||||
self._create_client(api_key, http_options)
|
||||
self._session: AsyncSession = None
|
||||
self._connection_task = None
|
||||
|
||||
self._disconnecting = False
|
||||
self._api_session_ready = False
|
||||
self._run_llm_when_api_session_ready = False
|
||||
self._run_llm_when_session_ready = False
|
||||
|
||||
self._user_is_speaking = False
|
||||
self._bot_is_speaking = False
|
||||
@@ -578,6 +601,9 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
self._search_result_buffer = ""
|
||||
self._accumulated_grounding_metadata = None
|
||||
|
||||
def _create_client(self, api_key: str, http_options: Optional[HttpOptions] = None):
    """Build the google-genai client and store it on ``self._client``.

    Args:
        api_key: API key passed to the client.
        http_options: Optional HTTP options passed to the client.
    """
    client = Client(api_key=api_key, http_options=http_options)
    self._client = client
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate usage metrics.
|
||||
|
||||
@@ -613,7 +639,9 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
"""
|
||||
self._video_input_paused = paused
|
||||
|
||||
def set_model_modalities(self, modalities: GeminiMultimodalModalities):
|
||||
def set_model_modalities(
|
||||
self, modalities: Union[List[GeminiMultimodalModalities], GeminiMultimodalModalities]
|
||||
):
|
||||
"""Set the model response modalities.
|
||||
|
||||
Args:
|
||||
@@ -656,7 +684,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
#
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the service and establish websocket connection.
|
||||
"""Start the service and establish connection.
|
||||
|
||||
Args:
|
||||
frame: The start frame.
|
||||
@@ -701,10 +729,9 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
await self.start_ttfb_metrics()
|
||||
if self._needs_turn_complete_message:
|
||||
self._needs_turn_complete_message = False
|
||||
evt = events.ClientContentMessage.model_validate(
|
||||
{"clientContent": {"turnComplete": True}}
|
||||
)
|
||||
await self.send_client_event(evt)
|
||||
# NOTE: without this, the model ignores the context it's been
|
||||
# seeded with before the user started speaking
|
||||
await self._session.send_client_content(turn_complete=True)
|
||||
|
||||
#
|
||||
# frame processing
|
||||
@@ -768,6 +795,11 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
# ignore this frame. Use the serverContent.turnComplete API message
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
# NOTE: handling LLMMessagesAppendFrame here in the LLMService is
|
||||
# unusual - typically this would be handled in the user context
|
||||
# aggregator. Leaving this handling here so that user code that
|
||||
# uses this frame *without* a user context aggregator still works
|
||||
# (we have an example that does just that, actually).
|
||||
await self._create_single_response(frame.messages)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame.settings)
|
||||
@@ -776,199 +808,171 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
#
|
||||
# websocket communication
|
||||
#
|
||||
|
||||
async def send_client_event(self, event):
    """Serialize a client event and send it to the Gemini Live API.

    Args:
        event: The event to send; fields that are None are left out of the
            serialized payload.
    """
    payload = event.model_dump(exclude_none=True)
    await self._ws_send(payload)
|
||||
|
||||
async def _connect(self):
|
||||
"""Establish WebSocket connection to Gemini Live API."""
|
||||
if self._websocket:
|
||||
# Here we assume that if we have a websocket, we are connected. We
|
||||
"""Establish client connection to Gemini Live API."""
|
||||
if self._session:
|
||||
# Here we assume that if we have a client, we are connected. We
|
||||
# handle disconnections in the send/recv code paths.
|
||||
return
|
||||
|
||||
logger.info("Connecting to Gemini service")
|
||||
try:
|
||||
logger.info(f"Connecting to wss://{self._base_url}")
|
||||
uri = f"wss://{self._base_url}?key={self._api_key}"
|
||||
self._websocket = await websocket_connect(uri=uri)
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
# Handle modalities being specified as either a list or a single value
|
||||
modalities = self._settings["modalities"]
|
||||
if isinstance(modalities, GeminiMultimodalModalities):
|
||||
modalities = [modalities]
|
||||
|
||||
# Create the basic configuration
|
||||
config_data = {
|
||||
"setup": {
|
||||
"model": self._model_name,
|
||||
"generation_config": {
|
||||
"frequency_penalty": self._settings["frequency_penalty"],
|
||||
"max_output_tokens": self._settings["max_tokens"],
|
||||
"presence_penalty": self._settings["presence_penalty"],
|
||||
"temperature": self._settings["temperature"],
|
||||
"top_k": self._settings["top_k"],
|
||||
"top_p": self._settings["top_p"],
|
||||
"response_modalities": self._settings["modalities"].value,
|
||||
"speech_config": {
|
||||
"voice_config": {
|
||||
"prebuilt_voice_config": {"voice_name": self._voice_id}
|
||||
},
|
||||
"language_code": self._settings["language"],
|
||||
},
|
||||
"media_resolution": self._settings["media_resolution"].value,
|
||||
},
|
||||
"input_audio_transcription": {},
|
||||
"output_audio_transcription": {},
|
||||
}
|
||||
}
|
||||
# Assemble basic configuration
|
||||
config = LiveConnectConfig(
|
||||
generation_config=GenerationConfig(
|
||||
frequency_penalty=self._settings["frequency_penalty"],
|
||||
max_output_tokens=self._settings["max_tokens"],
|
||||
presence_penalty=self._settings["presence_penalty"],
|
||||
temperature=self._settings["temperature"],
|
||||
top_k=self._settings["top_k"],
|
||||
top_p=self._settings["top_p"],
|
||||
response_modalities=[Modality(modality.value) for modality in modalities],
|
||||
speech_config=SpeechConfig(
|
||||
voice_config=VoiceConfig(
|
||||
prebuilt_voice_config={"voice_name": self._voice_id}
|
||||
),
|
||||
language_code=self._settings["language"],
|
||||
),
|
||||
media_resolution=self._settings["media_resolution"].value,
|
||||
),
|
||||
input_audio_transcription=AudioTranscriptionConfig(),
|
||||
output_audio_transcription=AudioTranscriptionConfig(),
|
||||
)
|
||||
|
||||
# Add context window compression if enabled
|
||||
# Add context window compression to configuration, if enabled
|
||||
if self._settings.get("context_window_compression", {}).get("enabled", False):
|
||||
compression_config = {}
|
||||
compression_config = ContextWindowCompressionConfig()
|
||||
|
||||
# Add sliding window (always true if compression is enabled)
|
||||
compression_config["sliding_window"] = {}
|
||||
compression_config.sliding_window = SlidingWindow()
|
||||
|
||||
# Add trigger_tokens if specified
|
||||
trigger_tokens = self._settings.get("context_window_compression", {}).get(
|
||||
"trigger_tokens"
|
||||
)
|
||||
if trigger_tokens is not None:
|
||||
compression_config["trigger_tokens"] = trigger_tokens
|
||||
compression_config.trigger_tokens = trigger_tokens
|
||||
|
||||
config_data["setup"]["context_window_compression"] = compression_config
|
||||
config.context_window_compression = compression_config
|
||||
|
||||
# Add VAD configuration if provided
|
||||
# Add VAD configuration to configuration, if provided
|
||||
if self._settings.get("vad"):
|
||||
vad_config = {}
|
||||
vad_config = AutomaticActivityDetection()
|
||||
vad_params = self._settings["vad"]
|
||||
has_vad_settings = False
|
||||
|
||||
# Only add parameters that are explicitly set
|
||||
if vad_params.disabled is not None:
|
||||
vad_config["disabled"] = vad_params.disabled
|
||||
vad_config.disabled = vad_params.disabled
|
||||
has_vad_settings = True
|
||||
|
||||
if vad_params.start_sensitivity:
|
||||
vad_config["start_of_speech_sensitivity"] = vad_params.start_sensitivity.value
|
||||
vad_config.start_of_speech_sensitivity = vad_params.start_sensitivity.value
|
||||
has_vad_settings = True
|
||||
|
||||
if vad_params.end_sensitivity:
|
||||
vad_config["end_of_speech_sensitivity"] = vad_params.end_sensitivity.value
|
||||
vad_config.end_of_speech_sensitivity = vad_params.end_sensitivity.value
|
||||
has_vad_settings = True
|
||||
|
||||
if vad_params.prefix_padding_ms is not None:
|
||||
vad_config["prefix_padding_ms"] = vad_params.prefix_padding_ms
|
||||
vad_config.prefix_padding_ms = vad_params.prefix_padding_ms
|
||||
has_vad_settings = True
|
||||
|
||||
if vad_params.silence_duration_ms is not None:
|
||||
vad_config["silence_duration_ms"] = vad_params.silence_duration_ms
|
||||
vad_config.silence_duration_ms = vad_params.silence_duration_ms
|
||||
has_vad_settings = True
|
||||
|
||||
# Only add automatic_activity_detection if we have VAD settings
|
||||
if vad_config:
|
||||
realtime_config = {"automatic_activity_detection": vad_config}
|
||||
if has_vad_settings:
|
||||
config.realtime_input_config = RealtimeInputConfig(
|
||||
automatic_activity_detection=vad_config
|
||||
)
|
||||
|
||||
config_data["setup"]["realtime_input_config"] = realtime_config
|
||||
|
||||
config = events.Config.model_validate(config_data)
|
||||
|
||||
# Add system instruction if available
|
||||
# Add system instruction to configuration, if provided
|
||||
system_instruction = self._system_instruction or ""
|
||||
if self._context and hasattr(self._context, "extract_system_instructions"):
|
||||
system_instruction += "\n" + self._context.extract_system_instructions()
|
||||
if system_instruction:
|
||||
logger.debug(f"Setting system instruction: {system_instruction}")
|
||||
config.setup.system_instruction = events.SystemInstruction(
|
||||
parts=[events.ContentPart(text=system_instruction)]
|
||||
)
|
||||
config.system_instruction = system_instruction
|
||||
|
||||
# Add tools if available
|
||||
# Add tools to configuration, if provided
|
||||
if self._tools:
|
||||
logger.debug(f"Gemini is configuring to use tools{self._tools}")
|
||||
config.setup.tools = self.get_llm_adapter().from_standard_tools(self._tools)
|
||||
logger.debug(f"Setting tools: {self._tools}")
|
||||
config.tools = self.get_llm_adapter().from_standard_tools(self._tools)
|
||||
|
||||
# Send the configuration
|
||||
await self.send_client_event(config)
|
||||
# Start the connection
|
||||
self._connection_task = self.create_task(self._connection_task_handler(config=config))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
self._session = None
|
||||
|
||||
async def _connection_task_handler(self, config: LiveConnectConfig):
|
||||
async with self._client.aio.live.connect(model=self._model_name, config=config) as session:
|
||||
logger.info("Connected to Gemini service")
|
||||
|
||||
await self._handle_session_ready(session)
|
||||
|
||||
while True:
|
||||
try:
|
||||
turn = self._session.receive()
|
||||
async for message in turn:
|
||||
if message.server_content and message.server_content.model_turn:
|
||||
await self._handle_msg_model_turn(message)
|
||||
elif (
|
||||
message.server_content
|
||||
and message.server_content.turn_complete
|
||||
and message.usage_metadata
|
||||
):
|
||||
await self._handle_msg_turn_complete(message)
|
||||
await self._handle_msg_usage_metadata(message)
|
||||
elif message.server_content and message.server_content.input_transcription:
|
||||
await self._handle_msg_input_transcription(message)
|
||||
elif message.server_content and message.server_content.output_transcription:
|
||||
await self._handle_msg_output_transcription(message)
|
||||
elif message.server_content and message.server_content.grounding_metadata:
|
||||
await self._handle_msg_grounding_metadata(message)
|
||||
elif message.tool_call:
|
||||
await self._handle_msg_tool_call(message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in receive loop: {type(e)}: {e}")
|
||||
break
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from Gemini Live API and clean up resources."""
|
||||
logger.info("Disconnecting from Gemini service")
|
||||
try:
|
||||
self._disconnecting = True
|
||||
self._api_session_ready = False
|
||||
await self.stop_all_metrics()
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task, timeout=1.0)
|
||||
self._receive_task = None
|
||||
if self._connection_task:
|
||||
await self.cancel_task(self._connection_task, timeout=1.0)
|
||||
self._connection_task = None
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
self._disconnecting = False
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error disconnecting: {e}")
|
||||
|
||||
async def _ws_send(self, message):
|
||||
"""Send a message to the WebSocket connection."""
|
||||
# logger.debug(f"Sending message to websocket: {message}")
|
||||
try:
|
||||
if self._websocket:
|
||||
await self._websocket.send(json.dumps(message))
|
||||
except Exception as e:
|
||||
if self._disconnecting:
|
||||
return
|
||||
logger.error(f"Error sending message to websocket: {e}")
|
||||
# In server-to-server contexts, a WebSocket error should be quite rare. Given how hard
|
||||
# it is to recover from a send-side error with proper state management, and that exponential
|
||||
# backoff for retries can have cost/stability implications for a service cluster, let's just
|
||||
# treat a send-side error as fatal.
|
||||
await self.push_error(ErrorFrame(error=f"Error sending client event: {e}", fatal=True))
|
||||
|
||||
#
|
||||
# inbound server event handling
|
||||
# todo: docs link here
|
||||
#
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
"""Handle incoming messages from the WebSocket connection."""
|
||||
async for message in self._websocket:
|
||||
evt = events.parse_server_event(message)
|
||||
# logger.debug(f"Received event: {message[:500]}")
|
||||
# logger.debug(f"Received event: {evt}")
|
||||
|
||||
if evt.setupComplete:
|
||||
await self._handle_evt_setup_complete(evt)
|
||||
elif evt.serverContent and evt.serverContent.modelTurn:
|
||||
await self._handle_evt_model_turn(evt)
|
||||
elif evt.serverContent and evt.serverContent.turnComplete and evt.usageMetadata:
|
||||
await self._handle_evt_turn_complete(evt)
|
||||
await self._handle_evt_usage_metadata(evt)
|
||||
elif evt.serverContent and evt.serverContent.inputTranscription:
|
||||
await self._handle_evt_input_transcription(evt)
|
||||
elif evt.serverContent and evt.serverContent.outputTranscription:
|
||||
await self._handle_evt_output_transcription(evt)
|
||||
elif evt.serverContent and evt.serverContent.groundingMetadata:
|
||||
await self._handle_evt_grounding_metadata(evt)
|
||||
elif evt.toolCall:
|
||||
await self._handle_evt_tool_call(evt)
|
||||
elif False: # !!! todo: error events?
|
||||
await self._handle_evt_error(evt)
|
||||
# errors are fatal, so exit the receive loop
|
||||
return
|
||||
|
||||
#
|
||||
#
|
||||
#
|
||||
|
||||
async def _send_user_audio(self, frame):
|
||||
"""Send user audio frame to Gemini Live API."""
|
||||
if self._audio_input_paused:
|
||||
return
|
||||
|
||||
if not self._session:
|
||||
return
|
||||
|
||||
# Send all audio to Gemini
|
||||
evt = events.AudioInputMessage.from_raw_audio(frame.audio, frame.sample_rate)
|
||||
await self.send_client_event(evt)
|
||||
await self._session.send_realtime_input(
|
||||
audio=Blob(data=frame.audio, mime_type=f"audio/pcm;rate={frame.sample_rate}")
|
||||
)
|
||||
|
||||
# Manage a buffer of audio to use for transcription
|
||||
audio = frame.audio
|
||||
if self._user_is_speaking:
|
||||
@@ -993,27 +997,34 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
Args:
|
||||
text: The text to send as user input.
|
||||
"""
|
||||
evt = events.TextInputMessage.from_text(text)
|
||||
await self.send_client_event(evt)
|
||||
if not self._session:
|
||||
return
|
||||
|
||||
await self._session.send_realtime_input(text=text)
|
||||
|
||||
async def _send_user_video(self, frame):
|
||||
"""Send user video frame to Gemini Live API."""
|
||||
if self._video_input_paused:
|
||||
return
|
||||
|
||||
if not self._session:
|
||||
return
|
||||
|
||||
now = time.time()
|
||||
if now - self._last_sent_time < 1:
|
||||
return # Ignore if less than 1 second has passed
|
||||
|
||||
self._last_sent_time = now # Update last sent time
|
||||
logger.debug(f"Sending video frame to Gemini: {frame}")
|
||||
evt = events.VideoInputMessage.from_image_frame(frame)
|
||||
await self.send_client_event(evt)
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(frame.format, frame.size, frame.image).save(buffer, format="JPEG")
|
||||
data = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
await self._session.send_realtime_input(video=Blob(data=data, mime_type="image/jpeg"))
|
||||
|
||||
async def _create_initial_response(self):
|
||||
"""Create initial response based on context history."""
|
||||
if not self._api_session_ready:
|
||||
self._run_llm_when_api_session_ready = True
|
||||
if not self._session:
|
||||
self._run_llm_when_session_ready = True
|
||||
return
|
||||
|
||||
messages = self._context.get_messages_for_initializing_history()
|
||||
@@ -1024,70 +1035,31 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
evt = events.ClientContentMessage.model_validate(
|
||||
{
|
||||
"clientContent": {
|
||||
"turns": messages,
|
||||
"turnComplete": self._inference_on_context_initialization,
|
||||
}
|
||||
}
|
||||
await self._session.send_client_content(
|
||||
turns=messages, turn_complete=self._inference_on_context_initialization
|
||||
)
|
||||
await self.send_client_event(evt)
|
||||
|
||||
# If we're generating a response right away upon initializing
|
||||
# conversation history, set a flag saying that we need a turn complete
|
||||
# message when the user stops speaking.
|
||||
if not self._inference_on_context_initialization:
|
||||
self._needs_turn_complete_message = True
|
||||
|
||||
async def _create_single_response(self, messages_list):
|
||||
"""Create a single response from a list of messages."""
|
||||
# Refactor to combine this logic with same logic in GeminiMultimodalLiveContext
|
||||
messages = []
|
||||
for item in messages_list:
|
||||
role = item.get("role")
|
||||
# Create a throwaway context just for the purpose of getting messages
|
||||
# in the right format
|
||||
context = GeminiMultimodalLiveContext.upgrade(OpenAILLMContext(messages=messages_list))
|
||||
messages = context.get_messages_for_initializing_history()
|
||||
|
||||
if role == "system":
|
||||
continue
|
||||
|
||||
elif role == "assistant":
|
||||
role = "model"
|
||||
|
||||
content = item.get("content")
|
||||
parts = []
|
||||
if isinstance(content, str):
|
||||
parts = [{"text": content}]
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if part.get("type") == "text":
|
||||
parts.append({"text": part.get("text")})
|
||||
elif part.get("type") == "file_data":
|
||||
file_data = part.get("file_data", {})
|
||||
|
||||
parts.append(
|
||||
{
|
||||
"fileData": {
|
||||
"mimeType": file_data.get("mime_type"),
|
||||
"fileUri": file_data.get("file_uri"),
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Unsupported content type: {str(part)[:80]}")
|
||||
else:
|
||||
logger.warning(f"Unsupported content type: {str(content)[:80]}")
|
||||
messages.append({"role": role, "parts": parts})
|
||||
if not messages:
|
||||
return
|
||||
|
||||
logger.debug(f"Creating response: {messages}")
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
evt = events.ClientContentMessage.model_validate(
|
||||
{
|
||||
"clientContent": {
|
||||
"turns": messages,
|
||||
"turnComplete": True,
|
||||
}
|
||||
}
|
||||
)
|
||||
await self.send_client_event(evt)
|
||||
await self._session.send_client_content(turns=messages, turn_complete=True)
|
||||
|
||||
@traced_gemini_live(operation="llm_tool_result")
|
||||
async def _tool_result(self, tool_result_message):
|
||||
@@ -1097,37 +1069,22 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
id = tool_result_message.get("tool_call_id")
|
||||
name = tool_result_message.get("tool_call_name")
|
||||
result = json.loads(tool_result_message.get("content") or "")
|
||||
response_message = json.dumps(
|
||||
{
|
||||
"toolResponse": {
|
||||
"functionResponses": [
|
||||
{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"response": {
|
||||
"result": result,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
await self._websocket.send(response_message)
|
||||
# await self._websocket.send(json.dumps({"clientContent": {"turnComplete": True}}))
|
||||
response = FunctionResponse(name=name, id=id, response=result)
|
||||
await self._session.send_tool_response(function_responses=response)
|
||||
|
||||
@traced_gemini_live(operation="llm_setup")
|
||||
async def _handle_evt_setup_complete(self, evt):
|
||||
"""Handle the setup complete event."""
|
||||
# If this is our first context frame, run the LLM
|
||||
self._api_session_ready = True
|
||||
# Now that we've configured the session, we can run the LLM if we need to.
|
||||
if self._run_llm_when_api_session_ready:
|
||||
self._run_llm_when_api_session_ready = False
|
||||
async def _handle_session_ready(self, session: AsyncSession):
|
||||
"""Handle the session being ready."""
|
||||
self._session = session
|
||||
# If we were just waiting for the session to be ready to run the LLM,
|
||||
# do that now.
|
||||
if self._run_llm_when_session_ready:
|
||||
self._run_llm_when_session_ready = False
|
||||
await self._create_initial_response()
|
||||
|
||||
async def _handle_evt_model_turn(self, evt):
|
||||
"""Handle the model turn event."""
|
||||
part = evt.serverContent.modelTurn.parts[0]
|
||||
async def _handle_msg_model_turn(self, msg: LiveServerMessage):
|
||||
"""Handle the model turn message."""
|
||||
part = msg.server_content.model_turn.parts[0]
|
||||
if not part:
|
||||
return
|
||||
|
||||
@@ -1144,17 +1101,17 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
await self.push_frame(LLMTextFrame(text=text))
|
||||
|
||||
# Check for grounding metadata in server content
|
||||
if evt.serverContent and evt.serverContent.groundingMetadata:
|
||||
self._accumulated_grounding_metadata = evt.serverContent.groundingMetadata
|
||||
if msg.server_content and msg.server_content.grounding_metadata:
|
||||
self._accumulated_grounding_metadata = msg.server_content.grounding_metadata
|
||||
|
||||
inline_data = part.inlineData
|
||||
inline_data = part.inline_data
|
||||
if not inline_data:
|
||||
return
|
||||
if inline_data.mimeType != f"audio/pcm;rate={self._sample_rate}":
|
||||
logger.warning(f"Unrecognized server_content format {inline_data.mimeType}")
|
||||
if inline_data.mime_type != f"audio/pcm;rate={self._sample_rate}":
|
||||
logger.warning(f"Unrecognized server_content format {inline_data.mime_type}")
|
||||
return
|
||||
|
||||
audio = base64.b64decode(inline_data.data)
|
||||
audio = inline_data.data
|
||||
if not audio:
|
||||
return
|
||||
|
||||
@@ -1172,9 +1129,9 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
await self.push_frame(frame)
|
||||
|
||||
@traced_gemini_live(operation="llm_tool_call")
|
||||
async def _handle_evt_tool_call(self, evt):
|
||||
"""Handle tool call events."""
|
||||
function_calls = evt.toolCall.functionCalls
|
||||
async def _handle_msg_tool_call(self, message: LiveServerMessage):
|
||||
"""Handle tool call messages."""
|
||||
function_calls = message.tool_call.function_calls
|
||||
if not function_calls:
|
||||
return
|
||||
if not self._context:
|
||||
@@ -1193,12 +1150,13 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
await self.run_function_calls(function_calls_llm)
|
||||
|
||||
@traced_gemini_live(operation="llm_response")
|
||||
async def _handle_evt_turn_complete(self, evt):
|
||||
"""Handle the turn complete event."""
|
||||
async def _handle_msg_turn_complete(self, message: LiveServerMessage):
|
||||
"""Handle the turn complete message."""
|
||||
self._bot_is_speaking = False
|
||||
text = self._bot_text_buffer
|
||||
|
||||
# Determine output and modality for tracing
|
||||
# TODO: looks like there's a bug here - output_text and output_modality are unused
|
||||
if text:
|
||||
# TEXT modality
|
||||
output_text = text
|
||||
@@ -1239,17 +1197,17 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def _handle_evt_input_transcription(self, evt):
|
||||
"""Handle the input transcription event.
|
||||
async def _handle_msg_input_transcription(self, message: LiveServerMessage):
|
||||
"""Handle the input transcription message.
|
||||
|
||||
Gemini Live sends user transcriptions in either single words or multi-word
|
||||
phrases. As a result, we have to aggregate the input transcription. This handler
|
||||
aggregates into sentences, splitting on the end of sentence markers.
|
||||
"""
|
||||
if not evt.serverContent.inputTranscription:
|
||||
if not message.server_content.input_transcription:
|
||||
return
|
||||
|
||||
text = evt.serverContent.inputTranscription.text
|
||||
text = message.server_content.input_transcription.text
|
||||
|
||||
if not text:
|
||||
return
|
||||
@@ -1282,20 +1240,20 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
text=complete_sentence,
|
||||
user_id="",
|
||||
timestamp=time_now_iso8601(),
|
||||
result=evt,
|
||||
result=message,
|
||||
),
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
|
||||
async def _handle_evt_output_transcription(self, evt):
|
||||
"""Handle the output transcription event."""
|
||||
if not evt.serverContent.outputTranscription:
|
||||
async def _handle_msg_output_transcription(self, message: LiveServerMessage):
|
||||
"""Handle the output transcription message."""
|
||||
if not message.server_content.output_transcription:
|
||||
return
|
||||
|
||||
# This is the output transcription text when modalities is set to AUDIO.
|
||||
# In this case, we push LLMTextFrame and TTSTextFrame to be handled by the
|
||||
# downstream assistant context aggregator.
|
||||
text = evt.serverContent.outputTranscription.text
|
||||
text = message.server_content.output_transcription.text
|
||||
|
||||
if not text:
|
||||
return
|
||||
@@ -1304,23 +1262,23 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
self._search_result_buffer += text
|
||||
|
||||
# Check for grounding metadata in server content
|
||||
if evt.serverContent and evt.serverContent.groundingMetadata:
|
||||
self._accumulated_grounding_metadata = evt.serverContent.groundingMetadata
|
||||
if message.server_content and message.server_content.grounding_metadata:
|
||||
self._accumulated_grounding_metadata = message.server_content.grounding_metadata
|
||||
# Collect text for tracing
|
||||
self._llm_output_buffer += text
|
||||
|
||||
await self.push_frame(LLMTextFrame(text=text))
|
||||
await self.push_frame(TTSTextFrame(text=text))
|
||||
|
||||
async def _handle_evt_grounding_metadata(self, evt):
|
||||
"""Handle dedicated grounding metadata events."""
|
||||
if evt.serverContent and evt.serverContent.groundingMetadata:
|
||||
grounding_metadata = evt.serverContent.groundingMetadata
|
||||
async def _handle_msg_grounding_metadata(self, message: LiveServerMessage):
|
||||
"""Handle dedicated grounding metadata messages."""
|
||||
if message.server_content and message.server_content.grounding_metadata:
|
||||
grounding_metadata = message.server_content.grounding_metadata
|
||||
# Process the grounding metadata immediately
|
||||
await self._process_grounding_metadata(grounding_metadata, self._search_result_buffer)
|
||||
|
||||
async def _process_grounding_metadata(
|
||||
self, grounding_metadata: events.GroundingMetadata, search_result: str = ""
|
||||
self, grounding_metadata: GroundingMetadata, search_result: str = ""
|
||||
):
|
||||
"""Process grounding metadata and emit LLMSearchResponseFrame."""
|
||||
if not grounding_metadata:
|
||||
@@ -1329,19 +1287,19 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
# Extract rendered content for search suggestions
|
||||
rendered_content = None
|
||||
if (
|
||||
grounding_metadata.searchEntryPoint
|
||||
and grounding_metadata.searchEntryPoint.renderedContent
|
||||
grounding_metadata.search_entry_point
|
||||
and grounding_metadata.search_entry_point.rendered_content
|
||||
):
|
||||
rendered_content = grounding_metadata.searchEntryPoint.renderedContent
|
||||
rendered_content = grounding_metadata.search_entry_point.rendered_content
|
||||
|
||||
# Convert grounding chunks and supports to LLMSearchOrigin format
|
||||
origins = []
|
||||
|
||||
if grounding_metadata.groundingChunks and grounding_metadata.groundingSupports:
|
||||
if grounding_metadata.grounding_chunks and grounding_metadata.grounding_supports:
|
||||
# Create a mapping of chunk indices to origins
|
||||
chunk_to_origin = {}
|
||||
chunk_to_origin: Dict[int, LLMSearchOrigin] = {}
|
||||
|
||||
for index, chunk in enumerate(grounding_metadata.groundingChunks):
|
||||
for index, chunk in enumerate(grounding_metadata.grounding_chunks):
|
||||
if chunk.web:
|
||||
origin = LLMSearchOrigin(
|
||||
site_uri=chunk.web.uri, site_title=chunk.web.title, results=[]
|
||||
@@ -1350,13 +1308,13 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
origins.append(origin)
|
||||
|
||||
# Add grounding support results to the appropriate origins
|
||||
for support in grounding_metadata.groundingSupports:
|
||||
if support.segment and support.groundingChunkIndices:
|
||||
for support in grounding_metadata.grounding_supports:
|
||||
if support.segment and support.grounding_chunk_indices:
|
||||
text = support.segment.text or ""
|
||||
confidence_scores = support.confidenceScores or []
|
||||
confidence_scores = support.confidence_scores or []
|
||||
|
||||
# Add this result to all origins referenced by this support
|
||||
for chunk_index in support.groundingChunkIndices:
|
||||
for chunk_index in support.grounding_chunk_indices:
|
||||
if chunk_index in chunk_to_origin:
|
||||
result = LLMSearchResult(text=text, confidence=confidence_scores)
|
||||
chunk_to_origin[chunk_index].results.append(result)
|
||||
@@ -1368,17 +1326,17 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
|
||||
await self.push_frame(search_frame)
|
||||
|
||||
async def _handle_evt_usage_metadata(self, evt):
|
||||
"""Handle the usage metadata event."""
|
||||
if not evt.usageMetadata:
|
||||
async def _handle_msg_usage_metadata(self, message: LiveServerMessage):
|
||||
"""Handle the usage metadata message."""
|
||||
if not message.usage_metadata:
|
||||
return
|
||||
|
||||
usage = evt.usageMetadata
|
||||
usage = message.usage_metadata
|
||||
|
||||
# Ensure we have valid integers for all token counts
|
||||
prompt_tokens = usage.promptTokenCount or 0
|
||||
completion_tokens = usage.responseTokenCount or 0
|
||||
total_tokens = usage.totalTokenCount or (prompt_tokens + completion_tokens)
|
||||
prompt_tokens = usage.prompt_token_count or 0
|
||||
completion_tokens = usage.response_token_count or 0
|
||||
total_tokens = usage.total_token_count or (prompt_tokens + completion_tokens)
|
||||
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
|
||||
Reference in New Issue
Block a user