diff --git a/src/pipecat/services/gemini_multimodal_live/events.py b/src/pipecat/services/gemini_multimodal_live/events.py index 1766cf806..0d5ce4de5 100644 --- a/src/pipecat/services/gemini_multimodal_live/events.py +++ b/src/pipecat/services/gemini_multimodal_live/events.py @@ -6,525 +6,22 @@ """Event models and utilities for Google Gemini Multimodal Live API.""" -import base64 -import io -import json -from enum import Enum -from typing import List, Literal, Optional - -from PIL import Image -from pydantic import BaseModel, Field - -from pipecat.frames.frames import ImageRawFrame - -# -# Client events -# - - -class MediaChunk(BaseModel): - """Represents a chunk of media data for transmission. - - Parameters: - mimeType: MIME type of the media content. - data: Base64-encoded media data. - """ - - mimeType: str - data: str - - -class ContentPart(BaseModel): - """Represents a part of content that can contain text or media. - - Parameters: - text: Text content. Defaults to None. - inlineData: Inline media data. Defaults to None. - """ - - text: Optional[str] = Field(default=None, validate_default=False) - inlineData: Optional[MediaChunk] = Field(default=None, validate_default=False) - fileData: Optional["FileData"] = Field(default=None, validate_default=False) - - -class FileData(BaseModel): - """Represents a file reference in the Gemini File API.""" - - mimeType: str - fileUri: str - - -ContentPart.model_rebuild() # Rebuild model to resolve forward reference - - -class Turn(BaseModel): - """Represents a conversational turn in the dialogue. - - Parameters: - role: The role of the speaker, either "user" or "model". Defaults to "user". - parts: List of content parts that make up the turn. 
- """ - - role: Literal["user", "model"] = "user" - parts: List[ContentPart] - - -class StartSensitivity(str, Enum): - """Determines how start of speech is detected.""" - - UNSPECIFIED = "START_SENSITIVITY_UNSPECIFIED" # Default is HIGH - HIGH = "START_SENSITIVITY_HIGH" # Detect start of speech more often - LOW = "START_SENSITIVITY_LOW" # Detect start of speech less often - - -class EndSensitivity(str, Enum): - """Determines how end of speech is detected.""" - - UNSPECIFIED = "END_SENSITIVITY_UNSPECIFIED" # Default is HIGH - HIGH = "END_SENSITIVITY_HIGH" # End speech more often - LOW = "END_SENSITIVITY_LOW" # End speech less often - - -class AutomaticActivityDetection(BaseModel): - """Configures automatic detection of voice activity. - - Parameters: - disabled: Whether automatic activity detection is disabled. Defaults to None. - start_of_speech_sensitivity: Sensitivity for detecting speech start. Defaults to None. - prefix_padding_ms: Padding before speech start in milliseconds. Defaults to None. - end_of_speech_sensitivity: Sensitivity for detecting speech end. Defaults to None. - silence_duration_ms: Duration of silence to detect speech end. Defaults to None. - """ - - disabled: Optional[bool] = None - start_of_speech_sensitivity: Optional[StartSensitivity] = None - prefix_padding_ms: Optional[int] = None - end_of_speech_sensitivity: Optional[EndSensitivity] = None - silence_duration_ms: Optional[int] = None - - -class RealtimeInputConfig(BaseModel): - """Configures the realtime input behavior. - - Parameters: - automatic_activity_detection: Voice activity detection configuration. Defaults to None. - """ - - automatic_activity_detection: Optional[AutomaticActivityDetection] = None - - -class RealtimeInput(BaseModel): - """Contains realtime input media chunks and text. - - Parameters: - mediaChunks: List of media chunks for realtime processing. - text: Text for realtime processing. 
- """ - - mediaChunks: Optional[List[MediaChunk]] = None - text: Optional[str] = None - - -class ClientContent(BaseModel): - """Content sent from client to the Gemini Live API. - - Parameters: - turns: List of conversation turns. Defaults to None. - turnComplete: Whether the client's turn is complete. Defaults to False. - """ - - turns: Optional[List[Turn]] = None - turnComplete: bool = False - - -class AudioInputMessage(BaseModel): - """Message containing audio input data. - - Parameters: - realtimeInput: Realtime input containing audio chunks. - """ - - realtimeInput: RealtimeInput - - @classmethod - def from_raw_audio(cls, raw_audio: bytes, sample_rate: int) -> "AudioInputMessage": - """Create an audio input message from raw audio data. - - Args: - raw_audio: Raw audio bytes. - sample_rate: Audio sample rate in Hz. - - Returns: - AudioInputMessage instance with encoded audio data. - """ - data = base64.b64encode(raw_audio).decode("utf-8") - return cls( - realtimeInput=RealtimeInput( - mediaChunks=[MediaChunk(mimeType=f"audio/pcm;rate={sample_rate}", data=data)] - ) - ) - - -class VideoInputMessage(BaseModel): - """Message containing video/image input data. - - Parameters: - realtimeInput: Realtime input containing video/image chunks. - """ - - realtimeInput: RealtimeInput - - @classmethod - def from_image_frame(cls, frame: ImageRawFrame) -> "VideoInputMessage": - """Create a video input message from an image frame. - - Args: - frame: Image frame to encode. - - Returns: - VideoInputMessage instance with encoded image data. 
- """ - buffer = io.BytesIO() - Image.frombytes(frame.format, frame.size, frame.image).save(buffer, format="JPEG") - data = base64.b64encode(buffer.getvalue()).decode("utf-8") - return cls( - realtimeInput=RealtimeInput(mediaChunks=[MediaChunk(mimeType=f"image/jpeg", data=data)]) - ) - - -class TextInputMessage(BaseModel): - """Message containing text input data.""" - - realtimeInput: RealtimeInput - - @classmethod - def from_text(cls, text: str) -> "TextInputMessage": - """Create a text input message from a string. - - Args: - text: The text to send. - - Returns: - A TextInputMessage instance. - """ - return cls(realtimeInput=RealtimeInput(text=text)) - - -class ClientContentMessage(BaseModel): - """Message containing client content for the API. - - Parameters: - clientContent: The client content to send. - """ - - clientContent: ClientContent - - -class SystemInstruction(BaseModel): - """System instruction for the model. - - Parameters: - parts: List of content parts that make up the system instruction. - """ - - parts: List[ContentPart] - - -class AudioTranscriptionConfig(BaseModel): - """Configuration for audio transcription.""" - - pass - - -class Setup(BaseModel): - """Setup configuration for the Gemini Live session. - - Parameters: - model: Model identifier to use. - system_instruction: System instruction for the model. Defaults to None. - tools: List of available tools/functions. Defaults to None. - generation_config: Generation configuration parameters. Defaults to None. - input_audio_transcription: Input audio transcription config. Defaults to None. - output_audio_transcription: Output audio transcription config. Defaults to None. - realtime_input_config: Realtime input configuration. Defaults to None. 
- """ - - model: str - system_instruction: Optional[SystemInstruction] = None - tools: Optional[List[dict]] = None - generation_config: Optional[dict] = None - input_audio_transcription: Optional[AudioTranscriptionConfig] = None - output_audio_transcription: Optional[AudioTranscriptionConfig] = None - realtime_input_config: Optional[RealtimeInputConfig] = None - - -class Config(BaseModel): - """Configuration message for session setup. - - Parameters: - setup: Setup configuration for the session. - """ - - setup: Setup - - -# -# Grounding metadata models -# - - -class SearchEntryPoint(BaseModel): - """Represents the search entry point with rendered content for search suggestions.""" - - renderedContent: Optional[str] = None - - -class WebSource(BaseModel): - """Represents a web source from grounding chunks.""" - - uri: Optional[str] = None - title: Optional[str] = None - - -class GroundingChunk(BaseModel): - """Represents a grounding chunk containing web source information.""" - - web: Optional[WebSource] = None - - -class GroundingSegment(BaseModel): - """Represents a segment of text that is grounded.""" - - startIndex: Optional[int] = None - endIndex: Optional[int] = None - text: Optional[str] = None - - -class GroundingSupport(BaseModel): - """Represents support information for grounded text segments.""" - - segment: Optional[GroundingSegment] = None - groundingChunkIndices: Optional[List[int]] = None - confidenceScores: Optional[List[float]] = None - - -class GroundingMetadata(BaseModel): - """Represents grounding metadata from Google Search.""" - - searchEntryPoint: Optional[SearchEntryPoint] = None - groundingChunks: Optional[List[GroundingChunk]] = None - groundingSupports: Optional[List[GroundingSupport]] = None - webSearchQueries: Optional[List[str]] = None - - -# -# Server events -# - - -class SetupComplete(BaseModel): - """Indicates that session setup is complete.""" - - pass - - -class InlineData(BaseModel): - """Inline data embedded in server responses. 
- - Parameters: - mimeType: MIME type of the data. - data: Base64-encoded data content. - """ - - mimeType: str - data: str - - -class Part(BaseModel): - """Part of a server response containing data or text. - - Parameters: - inlineData: Inline binary data. Defaults to None. - text: Text content. Defaults to None. - """ - - inlineData: Optional[InlineData] = None - text: Optional[str] = None - - -class ModelTurn(BaseModel): - """Represents a turn from the model in the conversation. - - Parameters: - parts: List of content parts in the model's response. - """ - - parts: List[Part] - - -class ServerContentInterrupted(BaseModel): - """Indicates server content was interrupted. - - Parameters: - interrupted: Whether the content was interrupted. - """ - - interrupted: bool - - -class ServerContentTurnComplete(BaseModel): - """Indicates the server's turn is complete. - - Parameters: - turnComplete: Whether the turn is complete. - """ - - turnComplete: bool - - -class BidiGenerateContentTranscription(BaseModel): - """Transcription data from bidirectional content generation. - - Parameters: - text: The transcribed text content. - """ - - text: str - - -class ServerContent(BaseModel): - """Content sent from server to client. - - Parameters: - modelTurn: Model's conversational turn. Defaults to None. - interrupted: Whether content was interrupted. Defaults to None. - turnComplete: Whether the turn is complete. Defaults to None. - inputTranscription: Transcription of input audio. Defaults to None. - outputTranscription: Transcription of output audio. Defaults to None. - """ - - modelTurn: Optional[ModelTurn] = None - interrupted: Optional[bool] = None - turnComplete: Optional[bool] = None - inputTranscription: Optional[BidiGenerateContentTranscription] = None - outputTranscription: Optional[BidiGenerateContentTranscription] = None - groundingMetadata: Optional[GroundingMetadata] = None - - -class FunctionCall(BaseModel): - """Represents a function call from the model. 
- - Parameters: - id: Unique identifier for the function call. - name: Name of the function to call. - args: Arguments to pass to the function. - """ - - id: str - name: str - args: dict - - -class ToolCall(BaseModel): - """Contains one or more function calls. - - Parameters: - functionCalls: List of function calls to execute. - """ - - functionCalls: List[FunctionCall] - - -class Modality(str, Enum): - """Modality types in token counts.""" - - UNSPECIFIED = "MODALITY_UNSPECIFIED" - TEXT = "TEXT" - IMAGE = "IMAGE" - AUDIO = "AUDIO" - VIDEO = "VIDEO" - - -class ModalityTokenCount(BaseModel): - """Token count for a specific modality. - - Parameters: - modality: The modality type. - tokenCount: Number of tokens for this modality. - """ - - modality: Modality - tokenCount: int - - -class UsageMetadata(BaseModel): - """Usage metadata about the API response. - - Parameters: - promptTokenCount: Number of tokens in the prompt. Defaults to None. - cachedContentTokenCount: Number of cached content tokens. Defaults to None. - responseTokenCount: Number of tokens in the response. Defaults to None. - toolUsePromptTokenCount: Number of tokens for tool use prompts. Defaults to None. - thoughtsTokenCount: Number of tokens for model thoughts. Defaults to None. - totalTokenCount: Total number of tokens used. Defaults to None. - promptTokensDetails: Detailed breakdown of prompt tokens by modality. Defaults to None. - cacheTokensDetails: Detailed breakdown of cache tokens by modality. Defaults to None. - responseTokensDetails: Detailed breakdown of response tokens by modality. Defaults to None. - toolUsePromptTokensDetails: Detailed breakdown of tool use tokens by modality. Defaults to None. 
- """ - - promptTokenCount: Optional[int] = None - cachedContentTokenCount: Optional[int] = None - responseTokenCount: Optional[int] = None - toolUsePromptTokenCount: Optional[int] = None - thoughtsTokenCount: Optional[int] = None - totalTokenCount: Optional[int] = None - promptTokensDetails: Optional[List[ModalityTokenCount]] = None - cacheTokensDetails: Optional[List[ModalityTokenCount]] = None - responseTokensDetails: Optional[List[ModalityTokenCount]] = None - toolUsePromptTokensDetails: Optional[List[ModalityTokenCount]] = None - - -class ServerEvent(BaseModel): - """Server event received from the Gemini Live API. - - Parameters: - setupComplete: Setup completion notification. Defaults to None. - serverContent: Content from the server. Defaults to None. - toolCall: Tool/function call request. Defaults to None. - usageMetadata: Token usage metadata. Defaults to None. - """ - - setupComplete: Optional[SetupComplete] = None - serverContent: Optional[ServerContent] = None - toolCall: Optional[ToolCall] = None - usageMetadata: Optional[UsageMetadata] = None - - -def parse_server_event(str): - """Parse a server event from JSON string. - - Args: - str: JSON string containing the server event. - - Returns: - ServerEvent instance if parsing succeeds, None otherwise. - """ - try: - evt = json.loads(str) - return ServerEvent.model_validate(evt) - except Exception as e: - print(f"Error parsing server event: {e}") - return None - - -class ContextWindowCompressionConfig(BaseModel): - """Configuration for context window compression. - - Parameters: - sliding_window: Whether to use sliding window compression. Defaults to True. - trigger_tokens: Token count threshold to trigger compression. Defaults to None. 
- """ - - sliding_window: Optional[bool] = Field(default=True) - trigger_tokens: Optional[int] = Field(default=None) +from loguru import logger + +try: + from google.genai.types import ( + EndSensitivity as _EndSensitivity, + ) + from google.genai.types import ( + StartSensitivity as _StartSensitivity, + ) +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.") + raise Exception(f"Missing module: {e}") + +# These aliases are just here for backward compatibility, since we used to +# define public-facing StartSensitivity and EndSensitivity enums in this +# module. +StartSensitivity = _StartSensitivity +EndSensitivity = _EndSensitivity diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py index df560358f..7980f76ba 100644 --- a/src/pipecat/services/gemini_multimodal_live/gemini.py +++ b/src/pipecat/services/gemini_multimodal_live/gemini.py @@ -12,6 +12,7 @@ voice transcription, streaming responses, and tool usage. """ import base64 +import io import json import time from dataclasses import dataclass @@ -19,6 +20,7 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union from loguru import logger +from PIL import Image from pydantic import BaseModel, Field from pipecat.adapters.schemas.tools_schema import ToolsSchema @@ -28,7 +30,6 @@ from pipecat.frames.frames import ( BotStoppedSpeakingFrame, CancelFrame, EndFrame, - ErrorFrame, Frame, InputAudioRawFrame, InputImageRawFrame, @@ -72,11 +73,33 @@ from pipecat.utils.string import match_endofsentence from pipecat.utils.time import time_now_iso8601 from pipecat.utils.tracing.service_decorators import traced_gemini_live, traced_stt -from . 
import events from .file_api import GeminiFileAPI try: - from websockets.asyncio.client import connect as websocket_connect + from google.genai import Client + from google.genai.live import AsyncSession + from google.genai.types import ( + AudioTranscriptionConfig, + AutomaticActivityDetection, + Blob, + Content, + ContextWindowCompressionConfig, + EndSensitivity, + FileData, + FunctionResponse, + GenerationConfig, + GroundingMetadata, + HttpOptions, + LiveConnectConfig, + LiveServerMessage, + Modality, + Part, + RealtimeInputConfig, + SlidingWindow, + SpeechConfig, + StartSensitivity, + VoiceConfig, + ) except ModuleNotFoundError as e: logger.error(f"Exception: {e}") logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.") @@ -246,13 +269,13 @@ class GeminiMultimodalLiveContext(OpenAILLMContext): self.messages.append(message) logger.info(f"Added file reference to context: {file_uri}") - def get_messages_for_initializing_history(self): + def get_messages_for_initializing_history(self) -> List[Content]: """Get messages formatted for Gemini history initialization. Returns: List of messages in Gemini format for conversation history. 
""" - messages = [] + messages: List[Content] = [] for item in self.messages: role = item.get("role") @@ -263,29 +286,28 @@ class GeminiMultimodalLiveContext(OpenAILLMContext): role = "model" content = item.get("content") - parts = [] + parts: List[Part] = [] if isinstance(content, str): - parts = [{"text": content}] + parts = [Part(text=content)] elif isinstance(content, list): for part in content: if part.get("type") == "text": - parts.append({"text": part.get("text")}) + parts.append(Part(text=part.get("text"))) elif part.get("type") == "file_data": file_data = part.get("file_data", {}) - parts.append( - { - "fileData": { - "mimeType": file_data.get("mime_type"), - "fileUri": file_data.get("file_uri"), - } - } + Part( + file_data=FileData( + mime_type=file_data.get("mime_type"), + file_uri=file_data.get("file_uri"), + ) + ) ) else: logger.warning(f"Unsupported content type: {str(part)[:80]}") else: logger.warning(f"Unsupported content type: {str(content)[:80]}") - messages.append({"role": role, "parts": parts}) + messages.append(Content(role=role, parts=parts)) return messages @@ -411,8 +433,8 @@ class GeminiVADParams(BaseModel): """ disabled: Optional[bool] = Field(default=None) - start_sensitivity: Optional[events.StartSensitivity] = Field(default=None) - end_sensitivity: Optional[events.EndSensitivity] = Field(default=None) + start_sensitivity: Optional[StartSensitivity] = Field(default=None) + end_sensitivity: Optional[EndSensitivity] = Field(default=None) prefix_padding_ms: Optional[int] = Field(default=None) silence_duration_ms: Optional[int] = Field(default=None) @@ -441,7 +463,7 @@ class InputParams(BaseModel): temperature: Sampling temperature (0.0-2.0). Defaults to None. top_k: Top-k sampling parameter. Must be >= 0. Defaults to None. top_p: Top-p sampling parameter (0.0-1.0). Defaults to None. - modalities: Response modalities. Defaults to AUDIO. + modalities: Response modalities. Defaults to "AUDIO". language: Language for generation. 
Defaults to EN_US. media_resolution: Media resolution setting. Defaults to UNSPECIFIED. vad: Voice activity detection parameters. Defaults to None. @@ -455,8 +477,8 @@ class InputParams(BaseModel): temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0) top_k: Optional[int] = Field(default=None, ge=0) top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0) - modalities: Optional[GeminiMultimodalModalities] = Field( - default=GeminiMultimodalModalities.AUDIO + modalities: Optional[Union[List[GeminiMultimodalModalities], GeminiMultimodalModalities]] = ( + Field(default=GeminiMultimodalModalities.AUDIO) ) language: Optional[Language] = Field(default=Language.EN_US) media_resolution: Optional[GeminiMediaResolution] = Field( @@ -492,6 +514,7 @@ class GeminiMultimodalLiveLLMService(LLMService): params: Optional[InputParams] = None, inference_on_context_initialization: bool = True, file_api_base_url: str = "https://generativelanguage.googleapis.com/v1beta/files", + http_options: Optional[HttpOptions] = None, **kwargs, ): """Initialize the Gemini Multimodal Live LLM service. @@ -509,6 +532,7 @@ class GeminiMultimodalLiveLLMService(LLMService): inference_on_context_initialization: Whether to generate a response when context is first set. Defaults to True. file_api_base_url: Base URL for the Gemini File API. Defaults to the official endpoint. + http_options: HTTP options for the client. **kwargs: Additional arguments passed to parent LLMService. 
""" super().__init__(base_url=base_url, **kwargs) @@ -516,7 +540,6 @@ class GeminiMultimodalLiveLLMService(LLMService): params = params or InputParams() self._last_sent_time = 0 - self._api_key = api_key self._base_url = base_url self.set_model_name(model) self._voice_id = voice_id @@ -530,12 +553,12 @@ class GeminiMultimodalLiveLLMService(LLMService): self._audio_input_paused = start_audio_paused self._video_input_paused = start_video_paused self._context = None - self._websocket = None - self._receive_task = None + self._create_client(api_key, http_options) + self._session: AsyncSession = None + self._connection_task = None self._disconnecting = False - self._api_session_ready = False - self._run_llm_when_api_session_ready = False + self._run_llm_when_session_ready = False self._user_is_speaking = False self._bot_is_speaking = False @@ -578,6 +601,9 @@ class GeminiMultimodalLiveLLMService(LLMService): self._search_result_buffer = "" self._accumulated_grounding_metadata = None + def _create_client(self, api_key: str, http_options: Optional[HttpOptions] = None): + self._client = Client(api_key=api_key, http_options=http_options) + def can_generate_metrics(self) -> bool: """Check if the service can generate usage metrics. @@ -613,7 +639,9 @@ class GeminiMultimodalLiveLLMService(LLMService): """ self._video_input_paused = paused - def set_model_modalities(self, modalities: GeminiMultimodalModalities): + def set_model_modalities( + self, modalities: Union[List[GeminiMultimodalModalities], GeminiMultimodalModalities] + ): """Set the model response modalities. Args: @@ -656,7 +684,7 @@ class GeminiMultimodalLiveLLMService(LLMService): # async def start(self, frame: StartFrame): - """Start the service and establish websocket connection. + """Start the service and establish connection. Args: frame: The start frame. 
@@ -701,10 +729,9 @@ class GeminiMultimodalLiveLLMService(LLMService): await self.start_ttfb_metrics() if self._needs_turn_complete_message: self._needs_turn_complete_message = False - evt = events.ClientContentMessage.model_validate( - {"clientContent": {"turnComplete": True}} - ) - await self.send_client_event(evt) + # NOTE: without this, the model ignores the context it's been + # seeded with before the user started speaking + await self._session.send_client_content(turn_complete=True) # # frame processing @@ -768,6 +795,11 @@ class GeminiMultimodalLiveLLMService(LLMService): # ignore this frame. Use the serverContent.turnComplete API message await self.push_frame(frame, direction) elif isinstance(frame, LLMMessagesAppendFrame): + # NOTE: handling LLMMessagesAppendFrame here in the LLMService is + # unusual - typically this would be handled in the user context + # aggregator. Leaving this handling here so that user code that + # uses this frame *without* a user context aggregator still works + # (we have an example that does just that, actually). await self._create_single_response(frame.messages) elif isinstance(frame, LLMUpdateSettingsFrame): await self._update_settings(frame.settings) @@ -776,199 +808,171 @@ class GeminiMultimodalLiveLLMService(LLMService): else: await self.push_frame(frame, direction) - # - # websocket communication - # - - async def send_client_event(self, event): - """Send a client event to the Gemini Live API. - - Args: - event: The event to send. - """ - await self._ws_send(event.model_dump(exclude_none=True)) - async def _connect(self): - """Establish WebSocket connection to Gemini Live API.""" - if self._websocket: - # Here we assume that if we have a websocket, we are connected. We + """Establish client connection to Gemini Live API.""" + if self._session: + # Here we assume that if we have a client, we are connected. We # handle disconnections in the send/recv code paths. 
return logger.info("Connecting to Gemini service") try: - logger.info(f"Connecting to wss://{self._base_url}") - uri = f"wss://{self._base_url}?key={self._api_key}" - self._websocket = await websocket_connect(uri=uri) - self._receive_task = self.create_task(self._receive_task_handler()) + # Handle modalities being specified as either a list or a single value + modalities = self._settings["modalities"] + if isinstance(modalities, GeminiMultimodalModalities): + modalities = [modalities] - # Create the basic configuration - config_data = { - "setup": { - "model": self._model_name, - "generation_config": { - "frequency_penalty": self._settings["frequency_penalty"], - "max_output_tokens": self._settings["max_tokens"], - "presence_penalty": self._settings["presence_penalty"], - "temperature": self._settings["temperature"], - "top_k": self._settings["top_k"], - "top_p": self._settings["top_p"], - "response_modalities": self._settings["modalities"].value, - "speech_config": { - "voice_config": { - "prebuilt_voice_config": {"voice_name": self._voice_id} - }, - "language_code": self._settings["language"], - }, - "media_resolution": self._settings["media_resolution"].value, - }, - "input_audio_transcription": {}, - "output_audio_transcription": {}, - } - } + # Assemble basic configuration + config = LiveConnectConfig( + generation_config=GenerationConfig( + frequency_penalty=self._settings["frequency_penalty"], + max_output_tokens=self._settings["max_tokens"], + presence_penalty=self._settings["presence_penalty"], + temperature=self._settings["temperature"], + top_k=self._settings["top_k"], + top_p=self._settings["top_p"], + response_modalities=[Modality(modality.value) for modality in modalities], + speech_config=SpeechConfig( + voice_config=VoiceConfig( + prebuilt_voice_config={"voice_name": self._voice_id} + ), + language_code=self._settings["language"], + ), + media_resolution=self._settings["media_resolution"].value, + ), + 
input_audio_transcription=AudioTranscriptionConfig(), + output_audio_transcription=AudioTranscriptionConfig(), + ) - # Add context window compression if enabled + # Add context window compression to configuration, if enabled if self._settings.get("context_window_compression", {}).get("enabled", False): - compression_config = {} + compression_config = ContextWindowCompressionConfig() + # Add sliding window (always true if compression is enabled) - compression_config["sliding_window"] = {} + compression_config.sliding_window = SlidingWindow() # Add trigger_tokens if specified trigger_tokens = self._settings.get("context_window_compression", {}).get( "trigger_tokens" ) if trigger_tokens is not None: - compression_config["trigger_tokens"] = trigger_tokens + compression_config.trigger_tokens = trigger_tokens - config_data["setup"]["context_window_compression"] = compression_config + config.context_window_compression = compression_config - # Add VAD configuration if provided + # Add VAD configuration to configuration, if provided if self._settings.get("vad"): - vad_config = {} + vad_config = AutomaticActivityDetection() vad_params = self._settings["vad"] + has_vad_settings = False # Only add parameters that are explicitly set if vad_params.disabled is not None: - vad_config["disabled"] = vad_params.disabled + vad_config.disabled = vad_params.disabled + has_vad_settings = True if vad_params.start_sensitivity: - vad_config["start_of_speech_sensitivity"] = vad_params.start_sensitivity.value + vad_config.start_of_speech_sensitivity = vad_params.start_sensitivity.value + has_vad_settings = True if vad_params.end_sensitivity: - vad_config["end_of_speech_sensitivity"] = vad_params.end_sensitivity.value + vad_config.end_of_speech_sensitivity = vad_params.end_sensitivity.value + has_vad_settings = True if vad_params.prefix_padding_ms is not None: - vad_config["prefix_padding_ms"] = vad_params.prefix_padding_ms + vad_config.prefix_padding_ms = vad_params.prefix_padding_ms + 
has_vad_settings = True if vad_params.silence_duration_ms is not None: - vad_config["silence_duration_ms"] = vad_params.silence_duration_ms + vad_config.silence_duration_ms = vad_params.silence_duration_ms + has_vad_settings = True # Only add automatic_activity_detection if we have VAD settings - if vad_config: - realtime_config = {"automatic_activity_detection": vad_config} + if has_vad_settings: + config.realtime_input_config = RealtimeInputConfig( + automatic_activity_detection=vad_config + ) - config_data["setup"]["realtime_input_config"] = realtime_config - - config = events.Config.model_validate(config_data) - - # Add system instruction if available + # Add system instruction to configuration, if provided system_instruction = self._system_instruction or "" if self._context and hasattr(self._context, "extract_system_instructions"): system_instruction += "\n" + self._context.extract_system_instructions() if system_instruction: logger.debug(f"Setting system instruction: {system_instruction}") - config.setup.system_instruction = events.SystemInstruction( - parts=[events.ContentPart(text=system_instruction)] - ) + config.system_instruction = system_instruction - # Add tools if available + # Add tools to configuration, if provided if self._tools: - logger.debug(f"Gemini is configuring to use tools{self._tools}") - config.setup.tools = self.get_llm_adapter().from_standard_tools(self._tools) + logger.debug(f"Setting tools: {self._tools}") + config.tools = self.get_llm_adapter().from_standard_tools(self._tools) - # Send the configuration - await self.send_client_event(config) + # Start the connection + self._connection_task = self.create_task(self._connection_task_handler(config=config)) except Exception as e: logger.error(f"{self} initialization error: {e}") - self._websocket = None + self._session = None + + async def _connection_task_handler(self, config: LiveConnectConfig): + async with self._client.aio.live.connect(model=self._model_name, config=config) as 
session: + logger.info("Connected to Gemini service") + + await self._handle_session_ready(session) + + while True: + try: + turn = self._session.receive() + async for message in turn: + if message.server_content and message.server_content.model_turn: + await self._handle_msg_model_turn(message) + elif ( + message.server_content + and message.server_content.turn_complete + and message.usage_metadata + ): + await self._handle_msg_turn_complete(message) + await self._handle_msg_usage_metadata(message) + elif message.server_content and message.server_content.input_transcription: + await self._handle_msg_input_transcription(message) + elif message.server_content and message.server_content.output_transcription: + await self._handle_msg_output_transcription(message) + elif message.server_content and message.server_content.grounding_metadata: + await self._handle_msg_grounding_metadata(message) + elif message.tool_call: + await self._handle_msg_tool_call(message) + except Exception as e: + logger.error(f"Error in receive loop: {type(e)}: {e}") + break async def _disconnect(self): """Disconnect from Gemini Live API and clean up resources.""" logger.info("Disconnecting from Gemini service") try: self._disconnecting = True - self._api_session_ready = False await self.stop_all_metrics() - if self._websocket: - await self._websocket.close() - self._websocket = None - if self._receive_task: - await self.cancel_task(self._receive_task, timeout=1.0) - self._receive_task = None + if self._connection_task: + await self.cancel_task(self._connection_task, timeout=1.0) + self._connection_task = None + if self._session: + await self._session.close() + self._session = None self._disconnecting = False except Exception as e: logger.error(f"{self} error disconnecting: {e}") - async def _ws_send(self, message): - """Send a message to the WebSocket connection.""" - # logger.debug(f"Sending message to websocket: {message}") - try: - if self._websocket: - await 
self._websocket.send(json.dumps(message)) - except Exception as e: - if self._disconnecting: - return - logger.error(f"Error sending message to websocket: {e}") - # In server-to-server contexts, a WebSocket error should be quite rare. Given how hard - # it is to recover from a send-side error with proper state management, and that exponential - # backoff for retries can have cost/stability implications for a service cluster, let's just - # treat a send-side error as fatal. - await self.push_error(ErrorFrame(error=f"Error sending client event: {e}", fatal=True)) - - # - # inbound server event handling - # todo: docs link here - # - - async def _receive_task_handler(self): - """Handle incoming messages from the WebSocket connection.""" - async for message in self._websocket: - evt = events.parse_server_event(message) - # logger.debug(f"Received event: {message[:500]}") - # logger.debug(f"Received event: {evt}") - - if evt.setupComplete: - await self._handle_evt_setup_complete(evt) - elif evt.serverContent and evt.serverContent.modelTurn: - await self._handle_evt_model_turn(evt) - elif evt.serverContent and evt.serverContent.turnComplete and evt.usageMetadata: - await self._handle_evt_turn_complete(evt) - await self._handle_evt_usage_metadata(evt) - elif evt.serverContent and evt.serverContent.inputTranscription: - await self._handle_evt_input_transcription(evt) - elif evt.serverContent and evt.serverContent.outputTranscription: - await self._handle_evt_output_transcription(evt) - elif evt.serverContent and evt.serverContent.groundingMetadata: - await self._handle_evt_grounding_metadata(evt) - elif evt.toolCall: - await self._handle_evt_tool_call(evt) - elif False: # !!! todo: error events? 
- await self._handle_evt_error(evt) - # errors are fatal, so exit the receive loop - return - - # - # - # - async def _send_user_audio(self, frame): """Send user audio frame to Gemini Live API.""" if self._audio_input_paused: return + + if not self._session: + return + # Send all audio to Gemini - evt = events.AudioInputMessage.from_raw_audio(frame.audio, frame.sample_rate) - await self.send_client_event(evt) + await self._session.send_realtime_input( + audio=Blob(data=frame.audio, mime_type=f"audio/pcm;rate={frame.sample_rate}") + ) + # Manage a buffer of audio to use for transcription audio = frame.audio if self._user_is_speaking: @@ -993,27 +997,34 @@ class GeminiMultimodalLiveLLMService(LLMService): Args: text: The text to send as user input. """ - evt = events.TextInputMessage.from_text(text) - await self.send_client_event(evt) + if not self._session: + return + + await self._session.send_realtime_input(text=text) async def _send_user_video(self, frame): """Send user video frame to Gemini Live API.""" if self._video_input_paused: return + if not self._session: + return + now = time.time() if now - self._last_sent_time < 1: return # Ignore if less than 1 second has passed self._last_sent_time = now # Update last sent time logger.debug(f"Sending video frame to Gemini: {frame}") - evt = events.VideoInputMessage.from_image_frame(frame) - await self.send_client_event(evt) + buffer = io.BytesIO() + Image.frombytes(frame.format, frame.size, frame.image).save(buffer, format="JPEG") + data = base64.b64encode(buffer.getvalue()).decode("utf-8") + await self._session.send_realtime_input(video=Blob(data=data, mime_type="image/jpeg")) async def _create_initial_response(self): """Create initial response based on context history.""" - if not self._api_session_ready: - self._run_llm_when_api_session_ready = True + if not self._session: + self._run_llm_when_session_ready = True return messages = self._context.get_messages_for_initializing_history() @@ -1024,70 +1035,31 @@ class 
GeminiMultimodalLiveLLMService(LLMService): await self.start_ttfb_metrics() - evt = events.ClientContentMessage.model_validate( - { - "clientContent": { - "turns": messages, - "turnComplete": self._inference_on_context_initialization, - } - } + await self._session.send_client_content( + turns=messages, turn_complete=self._inference_on_context_initialization ) - await self.send_client_event(evt) + + # If we're not generating a response right away upon initializing + # conversation history, set a flag saying that we need a turn complete + # message when the user stops speaking. if not self._inference_on_context_initialization: self._needs_turn_complete_message = True async def _create_single_response(self, messages_list): """Create a single response from a list of messages.""" - # Refactor to combine this logic with same logic in GeminiMultimodalLiveContext - messages = [] - for item in messages_list: - role = item.get("role") + # Create a throwaway context just for the purpose of getting messages + # in the right format + context = GeminiMultimodalLiveContext.upgrade(OpenAILLMContext(messages=messages_list)) + messages = context.get_messages_for_initializing_history() - if role == "system": - continue - - elif role == "assistant": - role = "model" - - content = item.get("content") - parts = [] - if isinstance(content, str): - parts = [{"text": content}] - elif isinstance(content, list): - for part in content: - if part.get("type") == "text": - parts.append({"text": part.get("text")}) - elif part.get("type") == "file_data": - file_data = part.get("file_data", {}) - - parts.append( - { - "fileData": { - "mimeType": file_data.get("mime_type"), - "fileUri": file_data.get("file_uri"), - } - } - ) - else: - logger.warning(f"Unsupported content type: {str(part)[:80]}") - else: - logger.warning(f"Unsupported content type: {str(content)[:80]}") - messages.append({"role": role, "parts": parts}) if not messages: return + logger.debug(f"Creating response: {messages}") await 
self.start_ttfb_metrics() - evt = events.ClientContentMessage.model_validate( - { - "clientContent": { - "turns": messages, - "turnComplete": True, - } - } - ) - await self.send_client_event(evt) + await self._session.send_client_content(turns=messages, turn_complete=True) @traced_gemini_live(operation="llm_tool_result") async def _tool_result(self, tool_result_message): @@ -1097,37 +1069,22 @@ class GeminiMultimodalLiveLLMService(LLMService): id = tool_result_message.get("tool_call_id") name = tool_result_message.get("tool_call_name") result = json.loads(tool_result_message.get("content") or "") - response_message = json.dumps( - { - "toolResponse": { - "functionResponses": [ - { - "id": id, - "name": name, - "response": { - "result": result, - }, - } - ], - } - } - ) - await self._websocket.send(response_message) - # await self._websocket.send(json.dumps({"clientContent": {"turnComplete": True}})) + response = FunctionResponse(name=name, id=id, response=result) + await self._session.send_tool_response(function_responses=response) @traced_gemini_live(operation="llm_setup") - async def _handle_evt_setup_complete(self, evt): - """Handle the setup complete event.""" - # If this is our first context frame, run the LLM - self._api_session_ready = True - # Now that we've configured the session, we can run the LLM if we need to. - if self._run_llm_when_api_session_ready: - self._run_llm_when_api_session_ready = False + async def _handle_session_ready(self, session: AsyncSession): + """Handle the session being ready.""" + self._session = session + # If we were just waiting for the session to be ready to run the LLM, + # do that now. 
+ if self._run_llm_when_session_ready: + self._run_llm_when_session_ready = False await self._create_initial_response() - async def _handle_evt_model_turn(self, evt): - """Handle the model turn event.""" - part = evt.serverContent.modelTurn.parts[0] + async def _handle_msg_model_turn(self, msg: LiveServerMessage): + """Handle the model turn message.""" + part = msg.server_content.model_turn.parts[0] if not part: return @@ -1144,17 +1101,17 @@ class GeminiMultimodalLiveLLMService(LLMService): await self.push_frame(LLMTextFrame(text=text)) # Check for grounding metadata in server content - if evt.serverContent and evt.serverContent.groundingMetadata: - self._accumulated_grounding_metadata = evt.serverContent.groundingMetadata + if msg.server_content and msg.server_content.grounding_metadata: + self._accumulated_grounding_metadata = msg.server_content.grounding_metadata - inline_data = part.inlineData + inline_data = part.inline_data if not inline_data: return - if inline_data.mimeType != f"audio/pcm;rate={self._sample_rate}": - logger.warning(f"Unrecognized server_content format {inline_data.mimeType}") + if inline_data.mime_type != f"audio/pcm;rate={self._sample_rate}": + logger.warning(f"Unrecognized server_content format {inline_data.mime_type}") return - audio = base64.b64decode(inline_data.data) + audio = inline_data.data if not audio: return @@ -1172,9 +1129,9 @@ class GeminiMultimodalLiveLLMService(LLMService): await self.push_frame(frame) @traced_gemini_live(operation="llm_tool_call") - async def _handle_evt_tool_call(self, evt): - """Handle tool call events.""" - function_calls = evt.toolCall.functionCalls + async def _handle_msg_tool_call(self, message: LiveServerMessage): + """Handle tool call messages.""" + function_calls = message.tool_call.function_calls if not function_calls: return if not self._context: @@ -1193,12 +1150,13 @@ class GeminiMultimodalLiveLLMService(LLMService): await self.run_function_calls(function_calls_llm) 
@traced_gemini_live(operation="llm_response") - async def _handle_evt_turn_complete(self, evt): - """Handle the turn complete event.""" + async def _handle_msg_turn_complete(self, message: LiveServerMessage): + """Handle the turn complete message.""" self._bot_is_speaking = False text = self._bot_text_buffer # Determine output and modality for tracing + # TODO: looks like there's a bug here - output_text and output_modality are unused if text: # TEXT modality output_text = text @@ -1239,17 +1197,17 @@ class GeminiMultimodalLiveLLMService(LLMService): """Handle a transcription result with tracing.""" pass - async def _handle_evt_input_transcription(self, evt): - """Handle the input transcription event. + async def _handle_msg_input_transcription(self, message: LiveServerMessage): + """Handle the input transcription message. Gemini Live sends user transcriptions in either single words or multi-word phrases. As a result, we have to aggregate the input transcription. This handler aggregates into sentences, splitting on the end of sentence markers. """ - if not evt.serverContent.inputTranscription: + if not message.server_content.input_transcription: return - text = evt.serverContent.inputTranscription.text + text = message.server_content.input_transcription.text if not text: return @@ -1282,20 +1240,20 @@ class GeminiMultimodalLiveLLMService(LLMService): text=complete_sentence, user_id="", timestamp=time_now_iso8601(), - result=evt, + result=message, ), FrameDirection.UPSTREAM, ) - async def _handle_evt_output_transcription(self, evt): - """Handle the output transcription event.""" - if not evt.serverContent.outputTranscription: + async def _handle_msg_output_transcription(self, message: LiveServerMessage): + """Handle the output transcription message.""" + if not message.server_content.output_transcription: return # This is the output transcription text when modalities is set to AUDIO. 
# In this case, we push LLMTextFrame and TTSTextFrame to be handled by the # downstream assistant context aggregator. - text = evt.serverContent.outputTranscription.text + text = message.server_content.output_transcription.text if not text: return @@ -1304,23 +1262,23 @@ class GeminiMultimodalLiveLLMService(LLMService): self._search_result_buffer += text # Check for grounding metadata in server content - if evt.serverContent and evt.serverContent.groundingMetadata: - self._accumulated_grounding_metadata = evt.serverContent.groundingMetadata + if message.server_content and message.server_content.grounding_metadata: + self._accumulated_grounding_metadata = message.server_content.grounding_metadata # Collect text for tracing self._llm_output_buffer += text await self.push_frame(LLMTextFrame(text=text)) await self.push_frame(TTSTextFrame(text=text)) - async def _handle_evt_grounding_metadata(self, evt): - """Handle dedicated grounding metadata events.""" - if evt.serverContent and evt.serverContent.groundingMetadata: - grounding_metadata = evt.serverContent.groundingMetadata + async def _handle_msg_grounding_metadata(self, message: LiveServerMessage): + """Handle dedicated grounding metadata messages.""" + if message.server_content and message.server_content.grounding_metadata: + grounding_metadata = message.server_content.grounding_metadata # Process the grounding metadata immediately await self._process_grounding_metadata(grounding_metadata, self._search_result_buffer) async def _process_grounding_metadata( - self, grounding_metadata: events.GroundingMetadata, search_result: str = "" + self, grounding_metadata: GroundingMetadata, search_result: str = "" ): """Process grounding metadata and emit LLMSearchResponseFrame.""" if not grounding_metadata: @@ -1329,19 +1287,19 @@ class GeminiMultimodalLiveLLMService(LLMService): # Extract rendered content for search suggestions rendered_content = None if ( - grounding_metadata.searchEntryPoint - and 
grounding_metadata.searchEntryPoint.renderedContent + grounding_metadata.search_entry_point + and grounding_metadata.search_entry_point.rendered_content ): - rendered_content = grounding_metadata.searchEntryPoint.renderedContent + rendered_content = grounding_metadata.search_entry_point.rendered_content # Convert grounding chunks and supports to LLMSearchOrigin format origins = [] - if grounding_metadata.groundingChunks and grounding_metadata.groundingSupports: + if grounding_metadata.grounding_chunks and grounding_metadata.grounding_supports: # Create a mapping of chunk indices to origins - chunk_to_origin = {} + chunk_to_origin: Dict[int, LLMSearchOrigin] = {} - for index, chunk in enumerate(grounding_metadata.groundingChunks): + for index, chunk in enumerate(grounding_metadata.grounding_chunks): if chunk.web: origin = LLMSearchOrigin( site_uri=chunk.web.uri, site_title=chunk.web.title, results=[] @@ -1350,13 +1308,13 @@ class GeminiMultimodalLiveLLMService(LLMService): origins.append(origin) # Add grounding support results to the appropriate origins - for support in grounding_metadata.groundingSupports: - if support.segment and support.groundingChunkIndices: + for support in grounding_metadata.grounding_supports: + if support.segment and support.grounding_chunk_indices: text = support.segment.text or "" - confidence_scores = support.confidenceScores or [] + confidence_scores = support.confidence_scores or [] # Add this result to all origins referenced by this support - for chunk_index in support.groundingChunkIndices: + for chunk_index in support.grounding_chunk_indices: if chunk_index in chunk_to_origin: result = LLMSearchResult(text=text, confidence=confidence_scores) chunk_to_origin[chunk_index].results.append(result) @@ -1368,17 +1326,17 @@ class GeminiMultimodalLiveLLMService(LLMService): await self.push_frame(search_frame) - async def _handle_evt_usage_metadata(self, evt): - """Handle the usage metadata event.""" - if not evt.usageMetadata: + async def 
_handle_msg_usage_metadata(self, message: LiveServerMessage): + """Handle the usage metadata message.""" + if not message.usage_metadata: return - usage = evt.usageMetadata + usage = message.usage_metadata # Ensure we have valid integers for all token counts - prompt_tokens = usage.promptTokenCount or 0 - completion_tokens = usage.responseTokenCount or 0 - total_tokens = usage.totalTokenCount or (prompt_tokens + completion_tokens) + prompt_tokens = usage.prompt_token_count or 0 + completion_tokens = usage.response_token_count or 0 + total_tokens = usage.total_token_count or (prompt_tokens + completion_tokens) tokens = LLMTokenUsage( prompt_tokens=prompt_tokens,