Update GeminiMultimodalLiveLLMService to use the google-genai library, which is the new recommended approach for interfacing with Gemini Live.

This commit is contained in:
Paul Kompfner
2025-10-01 15:31:50 -04:00
parent 8293347b77
commit d87b6189ba
2 changed files with 269 additions and 814 deletions

View File

@@ -6,525 +6,22 @@
"""Event models and utilities for Google Gemini Multimodal Live API."""
import base64
import io
import json
from enum import Enum
from typing import List, Literal, Optional
from PIL import Image
from pydantic import BaseModel, Field
from pipecat.frames.frames import ImageRawFrame
#
# Client events
#
class MediaChunk(BaseModel):
    """A single piece of media payload sent to the API.

    Parameters:
        mimeType: MIME type describing the payload.
        data: The payload itself, base64-encoded as text.
    """

    mimeType: str
    data: str
class ContentPart(BaseModel):
    """One part of a turn: text, inline media, or a file reference.

    Parameters:
        text: Text content. Defaults to None.
        inlineData: Inline media data. Defaults to None.
        fileData: File reference, declared as a forward reference to
            ``FileData``. Defaults to None.
    """

    text: Optional[str] = Field(default=None, validate_default=False)
    inlineData: Optional[MediaChunk] = Field(default=None, validate_default=False)
    fileData: Optional["FileData"] = Field(default=None, validate_default=False)
class FileData(BaseModel):
    """A reference to a file held by the Gemini File API.

    Parameters:
        mimeType: MIME type of the referenced file.
        fileUri: URI of the referenced file.
    """

    mimeType: str
    fileUri: str
# ContentPart annotates a field with the string "FileData" before that class is
# declared, so the model is rebuilt here, once FileData exists.
ContentPart.model_rebuild() # Rebuild model to resolve forward reference
class Turn(BaseModel):
    """A single turn of the conversation.

    Parameters:
        role: Who produced the turn, either "user" or "model". Defaults to "user".
        parts: The content parts the turn consists of.
    """

    role: Literal["user", "model"] = "user"
    parts: List[ContentPart]
class StartSensitivity(str, Enum):
    """Sensitivity levels for detecting the start of speech."""

    # Default is HIGH
    UNSPECIFIED = "START_SENSITIVITY_UNSPECIFIED"
    # Detect start of speech more often
    HIGH = "START_SENSITIVITY_HIGH"
    # Detect start of speech less often
    LOW = "START_SENSITIVITY_LOW"
class EndSensitivity(str, Enum):
    """Sensitivity levels for detecting the end of speech."""

    # Default is HIGH
    UNSPECIFIED = "END_SENSITIVITY_UNSPECIFIED"
    # End speech more often
    HIGH = "END_SENSITIVITY_HIGH"
    # End speech less often
    LOW = "END_SENSITIVITY_LOW"
class AutomaticActivityDetection(BaseModel):
    """Settings for automatic voice activity detection.

    Every field is optional and defaults to None.

    Parameters:
        disabled: Whether automatic activity detection is disabled.
        start_of_speech_sensitivity: Sensitivity for detecting speech start.
        prefix_padding_ms: Padding before speech start, in milliseconds.
        end_of_speech_sensitivity: Sensitivity for detecting speech end.
        silence_duration_ms: Duration of silence used to detect speech end.
    """

    disabled: Optional[bool] = None
    start_of_speech_sensitivity: Optional[StartSensitivity] = None
    prefix_padding_ms: Optional[int] = None
    end_of_speech_sensitivity: Optional[EndSensitivity] = None
    silence_duration_ms: Optional[int] = None
class RealtimeInputConfig(BaseModel):
    """Settings that govern realtime input handling.

    Parameters:
        automatic_activity_detection: Voice activity detection settings.
            Defaults to None.
    """

    automatic_activity_detection: Optional[AutomaticActivityDetection] = None
class RealtimeInput(BaseModel):
    """Realtime input payload: media chunks and/or text.

    Parameters:
        mediaChunks: Media chunks for realtime processing. Defaults to None.
        text: Text for realtime processing. Defaults to None.
    """

    mediaChunks: Optional[List[MediaChunk]] = None
    text: Optional[str] = None
class ClientContent(BaseModel):
    """Content the client sends to the Gemini Live API.

    Parameters:
        turns: Conversation turns to send. Defaults to None.
        turnComplete: Whether the client's turn is complete. Defaults to False.
    """

    turns: Optional[List[Turn]] = None
    turnComplete: bool = False
class AudioInputMessage(BaseModel):
    """Client message that carries audio input.

    Parameters:
        realtimeInput: Realtime input holding the audio chunks.
    """

    realtimeInput: RealtimeInput

    @classmethod
    def from_raw_audio(cls, raw_audio: bytes, sample_rate: int) -> "AudioInputMessage":
        """Build an audio input message from raw audio bytes.

        Args:
            raw_audio: Raw audio bytes.
            sample_rate: Audio sample rate in Hz; embedded in the chunk's MIME type.

        Returns:
            AudioInputMessage whose single media chunk holds the base64-encoded audio.
        """
        encoded_audio = base64.b64encode(raw_audio).decode("utf-8")
        chunk = MediaChunk(mimeType=f"audio/pcm;rate={sample_rate}", data=encoded_audio)
        return cls(realtimeInput=RealtimeInput(mediaChunks=[chunk]))
class VideoInputMessage(BaseModel):
    """Message containing video/image input data.

    Parameters:
        realtimeInput: Realtime input containing video/image chunks.
    """

    realtimeInput: RealtimeInput

    @classmethod
    def from_image_frame(cls, frame: ImageRawFrame) -> "VideoInputMessage":
        """Create a video input message from an image frame.

        The frame is re-encoded as JPEG and base64-encoded into a single media chunk.

        Args:
            frame: Image frame to encode. Its ``format``, ``size`` and ``image``
                attributes are passed to ``PIL.Image.frombytes``.

        Returns:
            VideoInputMessage instance with encoded image data.
        """
        image = Image.frombytes(frame.format, frame.size, frame.image)
        # Pillow cannot write alpha or palette modes as JPEG, so convert those to
        # RGB first instead of letting save() raise.
        if image.mode in ("RGBA", "LA", "P"):
            image = image.convert("RGB")

        buffer = io.BytesIO()
        image.save(buffer, format="JPEG")
        data = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return cls(
            realtimeInput=RealtimeInput(mediaChunks=[MediaChunk(mimeType="image/jpeg", data=data)])
        )
class TextInputMessage(BaseModel):
    """Client message that carries text input.

    Parameters:
        realtimeInput: Realtime input holding the text.
    """

    realtimeInput: RealtimeInput

    @classmethod
    def from_text(cls, text: str) -> "TextInputMessage":
        """Build a text input message from a string.

        Args:
            text: The text to send.

        Returns:
            A TextInputMessage whose realtime input carries ``text``.
        """
        realtime_input = RealtimeInput(text=text)
        return cls(realtimeInput=realtime_input)
class ClientContentMessage(BaseModel):
    """Envelope message wrapping client content for the API.

    Parameters:
        clientContent: The client content to send.
    """

    clientContent: ClientContent
class SystemInstruction(BaseModel):
    """System instruction given to the model.

    Parameters:
        parts: Content parts that together form the system instruction.
    """

    parts: List[ContentPart]
class AudioTranscriptionConfig(BaseModel):
    """Audio transcription configuration; declares no fields."""
class Setup(BaseModel):
    """Session setup parameters for a Gemini Live session.

    Everything except ``model`` is optional and defaults to None.

    Parameters:
        model: Model identifier to use.
        system_instruction: System instruction for the model.
        tools: Available tools/functions, as a list of dicts.
        generation_config: Generation configuration parameters, as a dict.
        input_audio_transcription: Input audio transcription config.
        output_audio_transcription: Output audio transcription config.
        realtime_input_config: Realtime input configuration.
    """

    model: str
    system_instruction: Optional[SystemInstruction] = None
    tools: Optional[List[dict]] = None
    generation_config: Optional[dict] = None
    input_audio_transcription: Optional[AudioTranscriptionConfig] = None
    output_audio_transcription: Optional[AudioTranscriptionConfig] = None
    realtime_input_config: Optional[RealtimeInputConfig] = None
class Config(BaseModel):
    """Top-level configuration message used for session setup.

    Parameters:
        setup: The session's setup parameters.
    """

    setup: Setup
#
# Grounding metadata models
#
class SearchEntryPoint(BaseModel):
    """Search entry point carrying rendered content for search suggestions.

    Parameters:
        renderedContent: The rendered content. Defaults to None.
    """

    renderedContent: Optional[str] = None
class WebSource(BaseModel):
    """A web source referenced by a grounding chunk.

    Parameters:
        uri: URI of the source. Defaults to None.
        title: Title of the source. Defaults to None.
    """

    uri: Optional[str] = None
    title: Optional[str] = None
class GroundingChunk(BaseModel):
    """A grounding chunk holding web source information.

    Parameters:
        web: The web source for this chunk. Defaults to None.
    """

    web: Optional[WebSource] = None
class GroundingSegment(BaseModel):
    """A segment of text that is grounded.

    Parameters:
        startIndex: Start index of the segment. Defaults to None.
        endIndex: End index of the segment. Defaults to None.
        text: Text of the segment. Defaults to None.
    """

    startIndex: Optional[int] = None
    endIndex: Optional[int] = None
    text: Optional[str] = None
class GroundingSupport(BaseModel):
    """Support information for a grounded text segment.

    Parameters:
        segment: The grounded segment. Defaults to None.
        groundingChunkIndices: Indices of grounding chunks. Defaults to None.
        confidenceScores: Confidence scores. Defaults to None.
    """

    segment: Optional[GroundingSegment] = None
    groundingChunkIndices: Optional[List[int]] = None
    confidenceScores: Optional[List[float]] = None
class GroundingMetadata(BaseModel):
    """Grounding metadata from Google Search.

    Parameters:
        searchEntryPoint: Search entry point. Defaults to None.
        groundingChunks: Grounding chunks. Defaults to None.
        groundingSupports: Grounding supports. Defaults to None.
        webSearchQueries: Web search queries. Defaults to None.
    """

    searchEntryPoint: Optional[SearchEntryPoint] = None
    groundingChunks: Optional[List[GroundingChunk]] = None
    groundingSupports: Optional[List[GroundingSupport]] = None
    webSearchQueries: Optional[List[str]] = None
#
# Server events
#
class SetupComplete(BaseModel):
    """Marker event signalling that session setup is complete; declares no fields."""
class InlineData(BaseModel):
    """Data embedded inline in a server response.

    Parameters:
        mimeType: MIME type of the data.
        data: The data content, base64-encoded.
    """

    mimeType: str
    data: str
class Part(BaseModel):
    """One part of a server response, holding inline data or text.

    Parameters:
        inlineData: Inline binary data. Defaults to None.
        text: Text content. Defaults to None.
    """

    inlineData: Optional[InlineData] = None
    text: Optional[str] = None
class ModelTurn(BaseModel):
    """A conversational turn produced by the model.

    Parameters:
        parts: Content parts of the model's response.
    """

    parts: List[Part]
class ServerContentInterrupted(BaseModel):
    """Signals that server content was interrupted.

    Parameters:
        interrupted: Whether the content was interrupted.
    """

    interrupted: bool
class ServerContentTurnComplete(BaseModel):
    """Signals that the server's turn is complete.

    Parameters:
        turnComplete: Whether the turn is complete.
    """

    turnComplete: bool
class BidiGenerateContentTranscription(BaseModel):
    """Transcription produced during bidirectional content generation.

    Parameters:
        text: The transcribed text.
    """

    text: str
class ServerContent(BaseModel):
    """Content the server sends to the client.

    Every field is optional and defaults to None.

    Parameters:
        modelTurn: The model's conversational turn.
        interrupted: Whether content was interrupted.
        turnComplete: Whether the turn is complete.
        inputTranscription: Transcription of input audio.
        outputTranscription: Transcription of output audio.
        groundingMetadata: Grounding metadata attached to the content.
    """

    modelTurn: Optional[ModelTurn] = None
    interrupted: Optional[bool] = None
    turnComplete: Optional[bool] = None
    inputTranscription: Optional[BidiGenerateContentTranscription] = None
    outputTranscription: Optional[BidiGenerateContentTranscription] = None
    groundingMetadata: Optional[GroundingMetadata] = None
class FunctionCall(BaseModel):
    """A function call requested by the model.

    Parameters:
        id: Unique identifier of the call.
        name: Name of the function to invoke.
        args: Arguments to pass to the function.
    """

    id: str
    name: str
    args: dict
class ToolCall(BaseModel):
    """A tool call bundling one or more function calls.

    Parameters:
        functionCalls: Function calls to execute.
    """

    functionCalls: List[FunctionCall]
class Modality(str, Enum):
    """Modalities that token counts are reported for."""

    UNSPECIFIED = "MODALITY_UNSPECIFIED"
    TEXT = "TEXT"
    IMAGE = "IMAGE"
    AUDIO = "AUDIO"
    VIDEO = "VIDEO"
class ModalityTokenCount(BaseModel):
    """Token count reported for one modality.

    Parameters:
        modality: The modality the count applies to.
        tokenCount: Number of tokens for that modality.
    """

    modality: Modality
    tokenCount: int
class UsageMetadata(BaseModel):
    """Token usage metadata attached to an API response.

    Every field is optional and defaults to None.

    Parameters:
        promptTokenCount: Number of tokens in the prompt.
        cachedContentTokenCount: Number of cached content tokens.
        responseTokenCount: Number of tokens in the response.
        toolUsePromptTokenCount: Number of tokens for tool use prompts.
        thoughtsTokenCount: Number of tokens for model thoughts.
        totalTokenCount: Total number of tokens used.
        promptTokensDetails: Per-modality breakdown of prompt tokens.
        cacheTokensDetails: Per-modality breakdown of cache tokens.
        responseTokensDetails: Per-modality breakdown of response tokens.
        toolUsePromptTokensDetails: Per-modality breakdown of tool use tokens.
    """

    promptTokenCount: Optional[int] = None
    cachedContentTokenCount: Optional[int] = None
    responseTokenCount: Optional[int] = None
    toolUsePromptTokenCount: Optional[int] = None
    thoughtsTokenCount: Optional[int] = None
    totalTokenCount: Optional[int] = None
    promptTokensDetails: Optional[List[ModalityTokenCount]] = None
    cacheTokensDetails: Optional[List[ModalityTokenCount]] = None
    responseTokensDetails: Optional[List[ModalityTokenCount]] = None
    toolUsePromptTokensDetails: Optional[List[ModalityTokenCount]] = None
class ServerEvent(BaseModel):
    """An event received from the Gemini Live API.

    Every field is optional and defaults to None.

    Parameters:
        setupComplete: Setup completion notification.
        serverContent: Content from the server.
        toolCall: Tool/function call request.
        usageMetadata: Token usage metadata.
    """

    setupComplete: Optional[SetupComplete] = None
    serverContent: Optional[ServerContent] = None
    toolCall: Optional[ToolCall] = None
    usageMetadata: Optional[UsageMetadata] = None
def parse_server_event(str):
    """Parse a server event from a JSON string.

    Args:
        str: JSON string containing the server event. The name shadows the
            builtin ``str``; it is kept unchanged so that existing keyword
            callers keep working.

    Returns:
        ServerEvent instance if parsing succeeds, None otherwise.
    """
    # Imported at function scope so the block does not rely on a module-level
    # import; failures are reported through the logger rather than stdout.
    from loguru import logger

    # Alias the argument so the builtin-shadowing name is not used below.
    payload = str
    try:
        return ServerEvent.model_validate(json.loads(payload))
    except Exception as e:
        # Best-effort by design: any decoding or validation failure yields None.
        logger.error(f"Error parsing server event: {e}")
        return None
class ContextWindowCompressionConfig(BaseModel):
    """Settings for context window compression.

    Parameters:
        sliding_window: Whether to use sliding window compression. Defaults to True.
        trigger_tokens: Token count threshold that triggers compression.
            Defaults to None.
    """

    sliding_window: Optional[bool] = Field(default=True)
    trigger_tokens: Optional[int] = Field(default=None)
from loguru import logger
try:
from google.genai.types import (
EndSensitivity as _EndSensitivity,
)
from google.genai.types import (
StartSensitivity as _StartSensitivity,
)
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
raise Exception(f"Missing module: {e}")
# These aliases are just here for backward compatibility, since we used to
# define public-facing StartSensitivity and EndSensitivity enums in this
# module.
StartSensitivity = _StartSensitivity
EndSensitivity = _EndSensitivity

View File

@@ -12,6 +12,7 @@ voice transcription, streaming responses, and tool usage.
"""
import base64
import io
import json
import time
from dataclasses import dataclass
@@ -19,6 +20,7 @@ from enum import Enum
from typing import Any, Dict, List, Optional, Union
from loguru import logger
from PIL import Image
from pydantic import BaseModel, Field
from pipecat.adapters.schemas.tools_schema import ToolsSchema
@@ -28,7 +30,6 @@ from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InputAudioRawFrame,
InputImageRawFrame,
@@ -72,11 +73,33 @@ from pipecat.utils.string import match_endofsentence
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_gemini_live, traced_stt
from . import events
from .file_api import GeminiFileAPI
try:
from websockets.asyncio.client import connect as websocket_connect
from google.genai import Client
from google.genai.live import AsyncSession
from google.genai.types import (
AudioTranscriptionConfig,
AutomaticActivityDetection,
Blob,
Content,
ContextWindowCompressionConfig,
EndSensitivity,
FileData,
FunctionResponse,
GenerationConfig,
GroundingMetadata,
HttpOptions,
LiveConnectConfig,
LiveServerMessage,
Modality,
Part,
RealtimeInputConfig,
SlidingWindow,
SpeechConfig,
StartSensitivity,
VoiceConfig,
)
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
@@ -246,13 +269,13 @@ class GeminiMultimodalLiveContext(OpenAILLMContext):
self.messages.append(message)
logger.info(f"Added file reference to context: {file_uri}")
def get_messages_for_initializing_history(self):
def get_messages_for_initializing_history(self) -> List[Content]:
"""Get messages formatted for Gemini history initialization.
Returns:
List of messages in Gemini format for conversation history.
"""
messages = []
messages: List[Content] = []
for item in self.messages:
role = item.get("role")
@@ -263,29 +286,28 @@ class GeminiMultimodalLiveContext(OpenAILLMContext):
role = "model"
content = item.get("content")
parts = []
parts: List[Part] = []
if isinstance(content, str):
parts = [{"text": content}]
parts = [Part(text=content)]
elif isinstance(content, list):
for part in content:
if part.get("type") == "text":
parts.append({"text": part.get("text")})
parts.append(Part(text=part.get("text")))
elif part.get("type") == "file_data":
file_data = part.get("file_data", {})
parts.append(
{
"fileData": {
"mimeType": file_data.get("mime_type"),
"fileUri": file_data.get("file_uri"),
}
}
Part(
file_data=FileData(
mime_type=file_data.get("mime_type"),
file_uri=file_data.get("file_uri"),
)
)
)
else:
logger.warning(f"Unsupported content type: {str(part)[:80]}")
else:
logger.warning(f"Unsupported content type: {str(content)[:80]}")
messages.append({"role": role, "parts": parts})
messages.append(Content(role=role, parts=parts))
return messages
@@ -411,8 +433,8 @@ class GeminiVADParams(BaseModel):
"""
disabled: Optional[bool] = Field(default=None)
start_sensitivity: Optional[events.StartSensitivity] = Field(default=None)
end_sensitivity: Optional[events.EndSensitivity] = Field(default=None)
start_sensitivity: Optional[StartSensitivity] = Field(default=None)
end_sensitivity: Optional[EndSensitivity] = Field(default=None)
prefix_padding_ms: Optional[int] = Field(default=None)
silence_duration_ms: Optional[int] = Field(default=None)
@@ -441,7 +463,7 @@ class InputParams(BaseModel):
temperature: Sampling temperature (0.0-2.0). Defaults to None.
top_k: Top-k sampling parameter. Must be >= 0. Defaults to None.
top_p: Top-p sampling parameter (0.0-1.0). Defaults to None.
modalities: Response modalities. Defaults to AUDIO.
modalities: Response modalities. Defaults to "AUDIO".
language: Language for generation. Defaults to EN_US.
media_resolution: Media resolution setting. Defaults to UNSPECIFIED.
vad: Voice activity detection parameters. Defaults to None.
@@ -455,8 +477,8 @@ class InputParams(BaseModel):
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
top_k: Optional[int] = Field(default=None, ge=0)
top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
modalities: Optional[GeminiMultimodalModalities] = Field(
default=GeminiMultimodalModalities.AUDIO
modalities: Optional[Union[List[GeminiMultimodalModalities], GeminiMultimodalModalities]] = (
Field(default=GeminiMultimodalModalities.AUDIO)
)
language: Optional[Language] = Field(default=Language.EN_US)
media_resolution: Optional[GeminiMediaResolution] = Field(
@@ -492,6 +514,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
params: Optional[InputParams] = None,
inference_on_context_initialization: bool = True,
file_api_base_url: str = "https://generativelanguage.googleapis.com/v1beta/files",
http_options: Optional[HttpOptions] = None,
**kwargs,
):
"""Initialize the Gemini Multimodal Live LLM service.
@@ -509,6 +532,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
inference_on_context_initialization: Whether to generate a response when context
is first set. Defaults to True.
file_api_base_url: Base URL for the Gemini File API. Defaults to the official endpoint.
http_options: HTTP options for the client.
**kwargs: Additional arguments passed to parent LLMService.
"""
super().__init__(base_url=base_url, **kwargs)
@@ -516,7 +540,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
params = params or InputParams()
self._last_sent_time = 0
self._api_key = api_key
self._base_url = base_url
self.set_model_name(model)
self._voice_id = voice_id
@@ -530,12 +553,12 @@ class GeminiMultimodalLiveLLMService(LLMService):
self._audio_input_paused = start_audio_paused
self._video_input_paused = start_video_paused
self._context = None
self._websocket = None
self._receive_task = None
self._create_client(api_key, http_options)
self._session: AsyncSession = None
self._connection_task = None
self._disconnecting = False
self._api_session_ready = False
self._run_llm_when_api_session_ready = False
self._run_llm_when_session_ready = False
self._user_is_speaking = False
self._bot_is_speaking = False
@@ -578,6 +601,9 @@ class GeminiMultimodalLiveLLMService(LLMService):
self._search_result_buffer = ""
self._accumulated_grounding_metadata = None
def _create_client(self, api_key: str, http_options: Optional[HttpOptions] = None):
self._client = Client(api_key=api_key, http_options=http_options)
def can_generate_metrics(self) -> bool:
"""Check if the service can generate usage metrics.
@@ -613,7 +639,9 @@ class GeminiMultimodalLiveLLMService(LLMService):
"""
self._video_input_paused = paused
def set_model_modalities(self, modalities: GeminiMultimodalModalities):
def set_model_modalities(
self, modalities: Union[List[GeminiMultimodalModalities], GeminiMultimodalModalities]
):
"""Set the model response modalities.
Args:
@@ -656,7 +684,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
#
async def start(self, frame: StartFrame):
"""Start the service and establish websocket connection.
"""Start the service and establish connection.
Args:
frame: The start frame.
@@ -701,10 +729,9 @@ class GeminiMultimodalLiveLLMService(LLMService):
await self.start_ttfb_metrics()
if self._needs_turn_complete_message:
self._needs_turn_complete_message = False
evt = events.ClientContentMessage.model_validate(
{"clientContent": {"turnComplete": True}}
)
await self.send_client_event(evt)
# NOTE: without this, the model ignores the context it's been
# seeded with before the user started speaking
await self._session.send_client_content(turn_complete=True)
#
# frame processing
@@ -768,6 +795,11 @@ class GeminiMultimodalLiveLLMService(LLMService):
# ignore this frame. Use the serverContent.turnComplete API message
await self.push_frame(frame, direction)
elif isinstance(frame, LLMMessagesAppendFrame):
# NOTE: handling LLMMessagesAppendFrame here in the LLMService is
# unusual - typically this would be handled in the user context
# aggregator. Leaving this handling here so that user code that
# uses this frame *without* a user context aggregator still works
# (we have an example that does just that, actually).
await self._create_single_response(frame.messages)
elif isinstance(frame, LLMUpdateSettingsFrame):
await self._update_settings(frame.settings)
@@ -776,199 +808,171 @@ class GeminiMultimodalLiveLLMService(LLMService):
else:
await self.push_frame(frame, direction)
#
# websocket communication
#
async def send_client_event(self, event):
"""Send a client event to the Gemini Live API.
Args:
event: The event to send.
"""
await self._ws_send(event.model_dump(exclude_none=True))
async def _connect(self):
"""Establish WebSocket connection to Gemini Live API."""
if self._websocket:
# Here we assume that if we have a websocket, we are connected. We
"""Establish client connection to Gemini Live API."""
if self._session:
# Here we assume that if we have a client, we are connected. We
# handle disconnections in the send/recv code paths.
return
logger.info("Connecting to Gemini service")
try:
logger.info(f"Connecting to wss://{self._base_url}")
uri = f"wss://{self._base_url}?key={self._api_key}"
self._websocket = await websocket_connect(uri=uri)
self._receive_task = self.create_task(self._receive_task_handler())
# Handle modalities being specified as either a list or a single value
modalities = self._settings["modalities"]
if isinstance(modalities, GeminiMultimodalModalities):
modalities = [modalities]
# Create the basic configuration
config_data = {
"setup": {
"model": self._model_name,
"generation_config": {
"frequency_penalty": self._settings["frequency_penalty"],
"max_output_tokens": self._settings["max_tokens"],
"presence_penalty": self._settings["presence_penalty"],
"temperature": self._settings["temperature"],
"top_k": self._settings["top_k"],
"top_p": self._settings["top_p"],
"response_modalities": self._settings["modalities"].value,
"speech_config": {
"voice_config": {
"prebuilt_voice_config": {"voice_name": self._voice_id}
},
"language_code": self._settings["language"],
},
"media_resolution": self._settings["media_resolution"].value,
},
"input_audio_transcription": {},
"output_audio_transcription": {},
}
}
# Assemble basic configuration
config = LiveConnectConfig(
generation_config=GenerationConfig(
frequency_penalty=self._settings["frequency_penalty"],
max_output_tokens=self._settings["max_tokens"],
presence_penalty=self._settings["presence_penalty"],
temperature=self._settings["temperature"],
top_k=self._settings["top_k"],
top_p=self._settings["top_p"],
response_modalities=[Modality(modality.value) for modality in modalities],
speech_config=SpeechConfig(
voice_config=VoiceConfig(
prebuilt_voice_config={"voice_name": self._voice_id}
),
language_code=self._settings["language"],
),
media_resolution=self._settings["media_resolution"].value,
),
input_audio_transcription=AudioTranscriptionConfig(),
output_audio_transcription=AudioTranscriptionConfig(),
)
# Add context window compression if enabled
# Add context window compression to configuration, if enabled
if self._settings.get("context_window_compression", {}).get("enabled", False):
compression_config = {}
compression_config = ContextWindowCompressionConfig()
# Add sliding window (always true if compression is enabled)
compression_config["sliding_window"] = {}
compression_config.sliding_window = SlidingWindow()
# Add trigger_tokens if specified
trigger_tokens = self._settings.get("context_window_compression", {}).get(
"trigger_tokens"
)
if trigger_tokens is not None:
compression_config["trigger_tokens"] = trigger_tokens
compression_config.trigger_tokens = trigger_tokens
config_data["setup"]["context_window_compression"] = compression_config
config.context_window_compression = compression_config
# Add VAD configuration if provided
# Add VAD configuration to configuration, if provided
if self._settings.get("vad"):
vad_config = {}
vad_config = AutomaticActivityDetection()
vad_params = self._settings["vad"]
has_vad_settings = False
# Only add parameters that are explicitly set
if vad_params.disabled is not None:
vad_config["disabled"] = vad_params.disabled
vad_config.disabled = vad_params.disabled
has_vad_settings = True
if vad_params.start_sensitivity:
vad_config["start_of_speech_sensitivity"] = vad_params.start_sensitivity.value
vad_config.start_of_speech_sensitivity = vad_params.start_sensitivity.value
has_vad_settings = True
if vad_params.end_sensitivity:
vad_config["end_of_speech_sensitivity"] = vad_params.end_sensitivity.value
vad_config.end_of_speech_sensitivity = vad_params.end_sensitivity.value
has_vad_settings = True
if vad_params.prefix_padding_ms is not None:
vad_config["prefix_padding_ms"] = vad_params.prefix_padding_ms
vad_config.prefix_padding_ms = vad_params.prefix_padding_ms
has_vad_settings = True
if vad_params.silence_duration_ms is not None:
vad_config["silence_duration_ms"] = vad_params.silence_duration_ms
vad_config.silence_duration_ms = vad_params.silence_duration_ms
has_vad_settings = True
# Only add automatic_activity_detection if we have VAD settings
if vad_config:
realtime_config = {"automatic_activity_detection": vad_config}
if has_vad_settings:
config.realtime_input_config = RealtimeInputConfig(
automatic_activity_detection=vad_config
)
config_data["setup"]["realtime_input_config"] = realtime_config
config = events.Config.model_validate(config_data)
# Add system instruction if available
# Add system instruction to configuration, if provided
system_instruction = self._system_instruction or ""
if self._context and hasattr(self._context, "extract_system_instructions"):
system_instruction += "\n" + self._context.extract_system_instructions()
if system_instruction:
logger.debug(f"Setting system instruction: {system_instruction}")
config.setup.system_instruction = events.SystemInstruction(
parts=[events.ContentPart(text=system_instruction)]
)
config.system_instruction = system_instruction
# Add tools if available
# Add tools to configuration, if provided
if self._tools:
logger.debug(f"Gemini is configuring to use tools{self._tools}")
config.setup.tools = self.get_llm_adapter().from_standard_tools(self._tools)
logger.debug(f"Setting tools: {self._tools}")
config.tools = self.get_llm_adapter().from_standard_tools(self._tools)
# Send the configuration
await self.send_client_event(config)
# Start the connection
self._connection_task = self.create_task(self._connection_task_handler(config=config))
except Exception as e:
logger.error(f"{self} initialization error: {e}")
self._websocket = None
self._session = None
async def _connection_task_handler(self, config: LiveConnectConfig):
async with self._client.aio.live.connect(model=self._model_name, config=config) as session:
logger.info("Connected to Gemini service")
await self._handle_session_ready(session)
while True:
try:
turn = self._session.receive()
async for message in turn:
if message.server_content and message.server_content.model_turn:
await self._handle_msg_model_turn(message)
elif (
message.server_content
and message.server_content.turn_complete
and message.usage_metadata
):
await self._handle_msg_turn_complete(message)
await self._handle_msg_usage_metadata(message)
elif message.server_content and message.server_content.input_transcription:
await self._handle_msg_input_transcription(message)
elif message.server_content and message.server_content.output_transcription:
await self._handle_msg_output_transcription(message)
elif message.server_content and message.server_content.grounding_metadata:
await self._handle_msg_grounding_metadata(message)
elif message.tool_call:
await self._handle_msg_tool_call(message)
except Exception as e:
logger.error(f"Error in receive loop: {type(e)}: {e}")
break
async def _disconnect(self):
"""Disconnect from Gemini Live API and clean up resources."""
logger.info("Disconnecting from Gemini service")
try:
self._disconnecting = True
self._api_session_ready = False
await self.stop_all_metrics()
if self._websocket:
await self._websocket.close()
self._websocket = None
if self._receive_task:
await self.cancel_task(self._receive_task, timeout=1.0)
self._receive_task = None
if self._connection_task:
await self.cancel_task(self._connection_task, timeout=1.0)
self._connection_task = None
if self._session:
await self._session.close()
self._session = None
self._disconnecting = False
except Exception as e:
logger.error(f"{self} error disconnecting: {e}")
async def _ws_send(self, message):
"""Send a message to the WebSocket connection."""
# logger.debug(f"Sending message to websocket: {message}")
try:
if self._websocket:
await self._websocket.send(json.dumps(message))
except Exception as e:
if self._disconnecting:
return
logger.error(f"Error sending message to websocket: {e}")
# In server-to-server contexts, a WebSocket error should be quite rare. Given how hard
# it is to recover from a send-side error with proper state management, and that exponential
# backoff for retries can have cost/stability implications for a service cluster, let's just
# treat a send-side error as fatal.
await self.push_error(ErrorFrame(error=f"Error sending client event: {e}", fatal=True))
#
# inbound server event handling
# todo: docs link here
#
async def _receive_task_handler(self):
"""Handle incoming messages from the WebSocket connection."""
async for message in self._websocket:
evt = events.parse_server_event(message)
# logger.debug(f"Received event: {message[:500]}")
# logger.debug(f"Received event: {evt}")
if evt.setupComplete:
await self._handle_evt_setup_complete(evt)
elif evt.serverContent and evt.serverContent.modelTurn:
await self._handle_evt_model_turn(evt)
elif evt.serverContent and evt.serverContent.turnComplete and evt.usageMetadata:
await self._handle_evt_turn_complete(evt)
await self._handle_evt_usage_metadata(evt)
elif evt.serverContent and evt.serverContent.inputTranscription:
await self._handle_evt_input_transcription(evt)
elif evt.serverContent and evt.serverContent.outputTranscription:
await self._handle_evt_output_transcription(evt)
elif evt.serverContent and evt.serverContent.groundingMetadata:
await self._handle_evt_grounding_metadata(evt)
elif evt.toolCall:
await self._handle_evt_tool_call(evt)
elif False: # !!! todo: error events?
await self._handle_evt_error(evt)
# errors are fatal, so exit the receive loop
return
#
#
#
async def _send_user_audio(self, frame):
"""Send user audio frame to Gemini Live API."""
if self._audio_input_paused:
return
if not self._session:
return
# Send all audio to Gemini
evt = events.AudioInputMessage.from_raw_audio(frame.audio, frame.sample_rate)
await self.send_client_event(evt)
await self._session.send_realtime_input(
audio=Blob(data=frame.audio, mime_type=f"audio/pcm;rate={frame.sample_rate}")
)
# Manage a buffer of audio to use for transcription
audio = frame.audio
if self._user_is_speaking:
@@ -993,27 +997,34 @@ class GeminiMultimodalLiveLLMService(LLMService):
Args:
text: The text to send as user input.
"""
evt = events.TextInputMessage.from_text(text)
await self.send_client_event(evt)
if not self._session:
return
await self._session.send_realtime_input(text=text)
async def _send_user_video(self, frame):
"""Send user video frame to Gemini Live API."""
# Drop the frame while video input is paused.
if self._video_input_paused:
return
# Nothing can be sent until a session has been stored on the instance.
if not self._session:
return
# Throttle: forward at most one frame per second.
now = time.time()
if now - self._last_sent_time < 1:
return # Ignore if less than 1 second has passed
self._last_sent_time = now # Update last sent time
logger.debug(f"Sending video frame to Gemini: {frame}")
# NOTE(review): two send paths are visible here (events.VideoInputMessage via
# send_client_event, and Blob via session.send_realtime_input); presumably only
# one is intended to remain — confirm.
evt = events.VideoInputMessage.from_image_frame(frame)
await self.send_client_event(evt)
# Re-encode the raw frame as an in-memory JPEG, then base64-encode it to text.
buffer = io.BytesIO()
Image.frombytes(frame.format, frame.size, frame.image).save(buffer, format="JPEG")
data = base64.b64encode(buffer.getvalue()).decode("utf-8")
# NOTE(review): `data` is base64 text here, while _send_user_audio passes
# `frame.audio` to Blob without encoding — confirm Blob expects this form.
await self._session.send_realtime_input(video=Blob(data=data, mime_type="image/jpeg"))
async def _create_initial_response(self):
"""Create initial response based on context history."""
if not self._api_session_ready:
self._run_llm_when_api_session_ready = True
if not self._session:
self._run_llm_when_session_ready = True
return
messages = self._context.get_messages_for_initializing_history()
@@ -1024,70 +1035,31 @@ class GeminiMultimodalLiveLLMService(LLMService):
await self.start_ttfb_metrics()
evt = events.ClientContentMessage.model_validate(
{
"clientContent": {
"turns": messages,
"turnComplete": self._inference_on_context_initialization,
}
}
await self._session.send_client_content(
turns=messages, turn_complete=self._inference_on_context_initialization
)
await self.send_client_event(evt)
# If we're not generating a response right away upon initializing
# conversation history, set a flag saying that we need a turn complete
# message when the user stops speaking.
if not self._inference_on_context_initialization:
self._needs_turn_complete_message = True
async def _create_single_response(self, messages_list):
"""Create a single response from a list of messages."""
# NOTE(review): two ways of building `messages` are visible below (a manual
# per-item conversion and a throwaway GeminiMultimodalLiveContext), as are two
# ways of sending them (events.ClientContentMessage and
# session.send_client_content); presumably only one of each is intended — confirm.
# Refactor to combine this logic with same logic in GeminiMultimodalLiveContext
messages = []
for item in messages_list:
role = item.get("role")
# Create a throwaway context just for the purpose of getting messages
# in the right format
context = GeminiMultimodalLiveContext.upgrade(OpenAILLMContext(messages=messages_list))
messages = context.get_messages_for_initializing_history()
# System messages are skipped; the "assistant" role is renamed to "model".
if role == "system":
continue
elif role == "assistant":
role = "model"
content = item.get("content")
parts = []
# A plain string becomes a single text part.
if isinstance(content, str):
parts = [{"text": content}]
# A list is converted part by part: "text" and "file_data" entries are kept,
# anything else is logged and dropped.
elif isinstance(content, list):
for part in content:
if part.get("type") == "text":
parts.append({"text": part.get("text")})
elif part.get("type") == "file_data":
file_data = part.get("file_data", {})
parts.append(
{
"fileData": {
"mimeType": file_data.get("mime_type"),
"fileUri": file_data.get("file_uri"),
}
}
)
else:
logger.warning(f"Unsupported content type: {str(part)[:80]}")
else:
logger.warning(f"Unsupported content type: {str(content)[:80]}")
messages.append({"role": role, "parts": parts})
# Nothing to send if no messages were produced.
if not messages:
return
logger.debug(f"Creating response: {messages}")
# Time-to-first-byte metrics are started before the request is sent.
await self.start_ttfb_metrics()
# The turn is always marked complete here.
evt = events.ClientContentMessage.model_validate(
{
"clientContent": {
"turns": messages,
"turnComplete": True,
}
}
)
await self.send_client_event(evt)
await self._session.send_client_content(turns=messages, turn_complete=True)
@traced_gemini_live(operation="llm_tool_result")
async def _tool_result(self, tool_result_message):
@@ -1097,37 +1069,22 @@ class GeminiMultimodalLiveLLMService(LLMService):
id = tool_result_message.get("tool_call_id")
name = tool_result_message.get("tool_call_name")
result = json.loads(tool_result_message.get("content") or "")
response_message = json.dumps(
{
"toolResponse": {
"functionResponses": [
{
"id": id,
"name": name,
"response": {
"result": result,
},
}
],
}
}
)
await self._websocket.send(response_message)
# await self._websocket.send(json.dumps({"clientContent": {"turnComplete": True}}))
response = FunctionResponse(name=name, id=id, response=result)
await self._session.send_tool_response(function_responses=response)
@traced_gemini_live(operation="llm_setup")
async def _handle_evt_setup_complete(self, evt):
"""Handle the setup complete event."""
# If this is our first context frame, run the LLM
self._api_session_ready = True
# Now that we've configured the session, we can run the LLM if we need to.
if self._run_llm_when_api_session_ready:
self._run_llm_when_api_session_ready = False
async def _handle_session_ready(self, session: AsyncSession):
"""Handle the session being ready."""
# Keep the session on the instance; the send methods check `self._session`
# before sending anything.
self._session = session
# If we were just waiting for the session to be ready to run the LLM,
# do that now. (_create_initial_response sets this flag when it is called
# before a session exists.)
if self._run_llm_when_session_ready:
self._run_llm_when_session_ready = False
await self._create_initial_response()
async def _handle_evt_model_turn(self, evt):
"""Handle the model turn event."""
part = evt.serverContent.modelTurn.parts[0]
async def _handle_msg_model_turn(self, msg: LiveServerMessage):
"""Handle the model turn message."""
part = msg.server_content.model_turn.parts[0]
if not part:
return
@@ -1144,17 +1101,17 @@ class GeminiMultimodalLiveLLMService(LLMService):
await self.push_frame(LLMTextFrame(text=text))
# Check for grounding metadata in server content
if evt.serverContent and evt.serverContent.groundingMetadata:
self._accumulated_grounding_metadata = evt.serverContent.groundingMetadata
if msg.server_content and msg.server_content.grounding_metadata:
self._accumulated_grounding_metadata = msg.server_content.grounding_metadata
inline_data = part.inlineData
inline_data = part.inline_data
if not inline_data:
return
if inline_data.mimeType != f"audio/pcm;rate={self._sample_rate}":
logger.warning(f"Unrecognized server_content format {inline_data.mimeType}")
if inline_data.mime_type != f"audio/pcm;rate={self._sample_rate}":
logger.warning(f"Unrecognized server_content format {inline_data.mime_type}")
return
audio = base64.b64decode(inline_data.data)
audio = inline_data.data
if not audio:
return
@@ -1172,9 +1129,9 @@ class GeminiMultimodalLiveLLMService(LLMService):
await self.push_frame(frame)
@traced_gemini_live(operation="llm_tool_call")
async def _handle_evt_tool_call(self, evt):
"""Handle tool call events."""
function_calls = evt.toolCall.functionCalls
async def _handle_msg_tool_call(self, message: LiveServerMessage):
"""Handle tool call messages."""
function_calls = message.tool_call.function_calls
if not function_calls:
return
if not self._context:
@@ -1193,12 +1150,13 @@ class GeminiMultimodalLiveLLMService(LLMService):
await self.run_function_calls(function_calls_llm)
@traced_gemini_live(operation="llm_response")
async def _handle_evt_turn_complete(self, evt):
"""Handle the turn complete event."""
async def _handle_msg_turn_complete(self, message: LiveServerMessage):
"""Handle the turn complete message."""
self._bot_is_speaking = False
text = self._bot_text_buffer
# Determine output and modality for tracing
# TODO: looks like there's a bug here - output_text and output_modality are unused
if text:
# TEXT modality
output_text = text
@@ -1239,17 +1197,17 @@ class GeminiMultimodalLiveLLMService(LLMService):
"""Handle a transcription result with tracing."""
pass
async def _handle_evt_input_transcription(self, evt):
"""Handle the input transcription event.
async def _handle_msg_input_transcription(self, message: LiveServerMessage):
"""Handle the input transcription message.
Gemini Live sends user transcriptions in either single words or multi-word
phrases. As a result, we have to aggregate the input transcription. This handler
aggregates into sentences, splitting on the end of sentence markers.
"""
if not evt.serverContent.inputTranscription:
if not message.server_content.input_transcription:
return
text = evt.serverContent.inputTranscription.text
text = message.server_content.input_transcription.text
if not text:
return
@@ -1282,20 +1240,20 @@ class GeminiMultimodalLiveLLMService(LLMService):
text=complete_sentence,
user_id="",
timestamp=time_now_iso8601(),
result=evt,
result=message,
),
FrameDirection.UPSTREAM,
)
async def _handle_evt_output_transcription(self, evt):
"""Handle the output transcription event."""
if not evt.serverContent.outputTranscription:
async def _handle_msg_output_transcription(self, message: LiveServerMessage):
"""Handle the output transcription message."""
if not message.server_content.output_transcription:
return
# This is the output transcription text when modalities is set to AUDIO.
# In this case, we push LLMTextFrame and TTSTextFrame to be handled by the
# downstream assistant context aggregator.
text = evt.serverContent.outputTranscription.text
text = message.server_content.output_transcription.text
if not text:
return
@@ -1304,23 +1262,23 @@ class GeminiMultimodalLiveLLMService(LLMService):
self._search_result_buffer += text
# Check for grounding metadata in server content
if evt.serverContent and evt.serverContent.groundingMetadata:
self._accumulated_grounding_metadata = evt.serverContent.groundingMetadata
if message.server_content and message.server_content.grounding_metadata:
self._accumulated_grounding_metadata = message.server_content.grounding_metadata
# Collect text for tracing
self._llm_output_buffer += text
await self.push_frame(LLMTextFrame(text=text))
await self.push_frame(TTSTextFrame(text=text))
async def _handle_evt_grounding_metadata(self, evt):
"""Handle dedicated grounding metadata events."""
if evt.serverContent and evt.serverContent.groundingMetadata:
grounding_metadata = evt.serverContent.groundingMetadata
async def _handle_msg_grounding_metadata(self, message: LiveServerMessage):
"""Handle dedicated grounding metadata messages."""
# Only act when the message actually carries grounding metadata.
if message.server_content and message.server_content.grounding_metadata:
grounding_metadata = message.server_content.grounding_metadata
# Process the grounding metadata immediately
# The current search-result buffer is passed as the `search_result` argument.
await self._process_grounding_metadata(grounding_metadata, self._search_result_buffer)
async def _process_grounding_metadata(
self, grounding_metadata: events.GroundingMetadata, search_result: str = ""
self, grounding_metadata: GroundingMetadata, search_result: str = ""
):
"""Process grounding metadata and emit LLMSearchResponseFrame."""
if not grounding_metadata:
@@ -1329,19 +1287,19 @@ class GeminiMultimodalLiveLLMService(LLMService):
# Extract rendered content for search suggestions
rendered_content = None
if (
grounding_metadata.searchEntryPoint
and grounding_metadata.searchEntryPoint.renderedContent
grounding_metadata.search_entry_point
and grounding_metadata.search_entry_point.rendered_content
):
rendered_content = grounding_metadata.searchEntryPoint.renderedContent
rendered_content = grounding_metadata.search_entry_point.rendered_content
# Convert grounding chunks and supports to LLMSearchOrigin format
origins = []
if grounding_metadata.groundingChunks and grounding_metadata.groundingSupports:
if grounding_metadata.grounding_chunks and grounding_metadata.grounding_supports:
# Create a mapping of chunk indices to origins
chunk_to_origin = {}
chunk_to_origin: Dict[int, LLMSearchOrigin] = {}
for index, chunk in enumerate(grounding_metadata.groundingChunks):
for index, chunk in enumerate(grounding_metadata.grounding_chunks):
if chunk.web:
origin = LLMSearchOrigin(
site_uri=chunk.web.uri, site_title=chunk.web.title, results=[]
@@ -1350,13 +1308,13 @@ class GeminiMultimodalLiveLLMService(LLMService):
origins.append(origin)
# Add grounding support results to the appropriate origins
for support in grounding_metadata.groundingSupports:
if support.segment and support.groundingChunkIndices:
for support in grounding_metadata.grounding_supports:
if support.segment and support.grounding_chunk_indices:
text = support.segment.text or ""
confidence_scores = support.confidenceScores or []
confidence_scores = support.confidence_scores or []
# Add this result to all origins referenced by this support
for chunk_index in support.groundingChunkIndices:
for chunk_index in support.grounding_chunk_indices:
if chunk_index in chunk_to_origin:
result = LLMSearchResult(text=text, confidence=confidence_scores)
chunk_to_origin[chunk_index].results.append(result)
@@ -1368,17 +1326,17 @@ class GeminiMultimodalLiveLLMService(LLMService):
await self.push_frame(search_frame)
async def _handle_evt_usage_metadata(self, evt):
"""Handle the usage metadata event."""
if not evt.usageMetadata:
async def _handle_msg_usage_metadata(self, message: LiveServerMessage):
"""Handle the usage metadata message."""
if not message.usage_metadata:
return
usage = evt.usageMetadata
usage = message.usage_metadata
# Ensure we have valid integers for all token counts
prompt_tokens = usage.promptTokenCount or 0
completion_tokens = usage.responseTokenCount or 0
total_tokens = usage.totalTokenCount or (prompt_tokens + completion_tokens)
prompt_tokens = usage.prompt_token_count or 0
completion_tokens = usage.response_token_count or 0
total_tokens = usage.total_token_count or (prompt_tokens + completion_tokens)
tokens = LLMTokenUsage(
prompt_tokens=prompt_tokens,