class GoogleVertexAIService(OpenAILLMService):
    """Implements inference with Google's AI models via Vertex AI while maintaining OpenAI API compatibility.

    A service-account credential is exchanged for an OAuth access token, which
    is handed to ``OpenAILLMService`` as the API key together with a base URL
    built from the project and location carried by ``InputParams``.

    Reference:
    https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-vertex-using-openai-library
    """

    class InputParams(OpenAILLMService.InputParams):
        """Input parameters specific to Vertex AI."""

        # Inserted into the ``projects/<project_id>`` segment of the endpoint URL.
        project_id: str
        # Used both as the host-name prefix and in the ``locations/<location>`` segment.
        location: str

    def __init__(
        self,
        *,
        credentials: Optional[str] = None,
        credentials_path: Optional[str] = None,
        model: str = "google/gemini-1.5-flash",
        params: Optional[InputParams] = None,
        **kwargs,
    ):
        """Initializes the GoogleVertexAIService.

        Args:
            credentials (Optional[str]): JSON string of service account credentials.
            credentials_path (Optional[str]): Path to the service account JSON file.
            model (str): Model identifier. Defaults to "google/gemini-1.5-flash".
            params (InputParams): Vertex AI input parameters. Required in practice,
                because ``project_id`` and ``location`` have no defaults.
            **kwargs: Additional arguments for OpenAILLMService.

        Raises:
            ValueError: If ``params`` is missing or lacks ``project_id``/``location``,
                or if neither ``credentials`` nor ``credentials_path`` is given.
        """
        # The endpoint URL cannot be built without a project and a location, so
        # fail early with an actionable message instead of an AttributeError
        # raised from inside the URL helper.
        project_id = getattr(params, "project_id", None)
        location = getattr(params, "location", None)
        if not project_id or not location:
            raise ValueError(
                "GoogleVertexAIService requires `params` to be a "
                "GoogleVertexAIService.InputParams with `project_id` and `location` set."
            )

        base_url = self._get_base_url(params)
        # NOTE: the token is fetched once, here. Nothing in this class refreshes
        # it again after construction, so long-lived instances may outlive it.
        self._api_key = self._get_api_token(credentials, credentials_path)

        # Forward ``params`` so the caller's settings on InputParams reach the
        # parent service instead of being silently replaced by its defaults.
        super().__init__(
            api_key=self._api_key, base_url=base_url, model=model, params=params, **kwargs
        )

    @staticmethod
    def _get_base_url(params: InputParams) -> str:
        """Constructs the base URL for Vertex AI API.

        Args:
            params (InputParams): Parameters providing ``project_id`` and ``location``.

        Returns:
            str: URL of the OpenAI-compatible endpoint for that project and location.
        """
        return (
            f"https://{params.location}-aiplatform.googleapis.com/v1/"
            f"projects/{params.project_id}/locations/{params.location}/endpoints/openapi"
        )

    @staticmethod
    def _get_api_token(credentials: Optional[str], credentials_path: Optional[str]) -> str:
        """Retrieves an authentication token using Google service account credentials.

        ``credentials`` takes precedence over ``credentials_path`` when both are given.

        Args:
            credentials (Optional[str]): JSON string of service account credentials.
            credentials_path (Optional[str]): Path to the service account JSON file.

        Returns:
            str: OAuth token for API authentication.

        Raises:
            ValueError: If neither argument is provided (an invalid JSON string
                also surfaces as a ValueError subclass from ``json.loads``).
        """
        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
        creds: Optional[service_account.Credentials] = None

        if credentials:
            # Parse and load credentials from JSON string
            creds = service_account.Credentials.from_service_account_info(
                json.loads(credentials), scopes=scopes
            )
        elif credentials_path:
            # Load credentials from JSON file
            creds = service_account.Credentials.from_service_account_file(
                credentials_path, scopes=scopes
            )

        if not creds:
            raise ValueError("No valid credentials provided.")

        creds.refresh(Request())  # Ensure token is up-to-date, lifetime is 1 hour.

        return creds.token