adding vertex llm

This commit is contained in:
Vaibhav159
2025-03-01 00:20:16 +05:30
committed by Vaibhav Lodha
parent 8b86f6991d
commit fa7da8f5f6
2 changed files with 76 additions and 0 deletions

View File

@@ -88,6 +88,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `AzureRealtimeBetaLLMService` to support Azure's OpenAI Realtime API. Added
foundational example `19a-azure-realtime-beta.py`.
- Introduced `GoogleVertexAIService`, a new class for integrating with Vertex AI
Gemini models.
### Changed
- Updated the default mode for `CartesiaTTSService` and

View File

@@ -71,6 +71,7 @@ try:
import google.generativeai as gai
from google import genai
from google.api_core.client_options import ClientOptions
from google.auth.transport.requests import Request
from google.cloud import speech_v2, texttospeech_v1
from google.cloud.speech_v2.types import cloud_speech
from google.genai import types
@@ -1333,6 +1334,78 @@ class GoogleLLMOpenAIBetaService(OpenAILLMService):
)
class GoogleVertexAIService(OpenAILLMService):
    """Implements inference with Google's AI models via Vertex AI while maintaining OpenAI API compatibility.

    The service builds the regional Vertex AI ``endpoints/openapi`` URL from the
    project and location carried by ``InputParams``, exchanges service account
    credentials for an OAuth access token, and passes both to
    ``OpenAILLMService`` as ``base_url`` and ``api_key``.

    Reference:
    https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-vertex-using-openai-library
    """

    class InputParams(OpenAILLMService.InputParams):
        """Input parameters specific to Vertex AI."""

        # Project identifier inserted into the endpoint path.
        project_id: str
        # Location used for both the API hostname prefix and the endpoint path.
        location: str

    def __init__(
        self,
        *,
        credentials: Optional[str] = None,
        credentials_path: Optional[str] = None,
        model: str = "google/gemini-1.5-flash",
        params: Optional[InputParams] = None,
        **kwargs,
    ):
        """Initializes the GoogleVertexAIService.

        Args:
            credentials (Optional[str]): JSON string of service account credentials.
                Takes precedence over ``credentials_path`` when both are given.
            credentials_path (Optional[str]): Path to the service account JSON file.
            model (str): Model identifier. Defaults to "google/gemini-1.5-flash".
            params (InputParams): Vertex AI input parameters. Required in practice,
                because ``project_id`` and ``location`` have no defaults.
            **kwargs: Additional arguments for OpenAILLMService.

        Raises:
            ValueError: If ``params`` is missing or lacks a non-empty ``project_id``
                or ``location``, or if no usable credentials were supplied.
        """
        # Validate before any network round-trip. The previous default (a base-class
        # InputParams instance created once at definition time) carried neither
        # field and failed later with an opaque AttributeError.
        if not getattr(params, "project_id", None) or not getattr(params, "location", None):
            raise ValueError(
                "GoogleVertexAIService requires `params` with `project_id` and `location` set."
            )

        base_url = self._get_base_url(params)
        # NOTE(review): the token is fetched once here and this class does not
        # refresh it afterwards; the upstream comment states a 1 hour lifetime.
        self._api_key = self._get_api_token(credentials, credentials_path)

        # Forward params so the inherited sampling settings are not silently dropped.
        # NOTE(review): assumes OpenAILLMService.__init__ accepts `params` — confirm.
        super().__init__(
            api_key=self._api_key, base_url=base_url, model=model, params=params, **kwargs
        )

    @staticmethod
    def _get_base_url(params: InputParams) -> str:
        """Constructs the base URL for Vertex AI API.

        Args:
            params (InputParams): Parameters providing ``location`` and ``project_id``.

        Returns:
            str: URL of the ``endpoints/openapi`` resource for that project and location.
        """
        return (
            f"https://{params.location}-aiplatform.googleapis.com/v1/"
            f"projects/{params.project_id}/locations/{params.location}/endpoints/openapi"
        )

    @staticmethod
    def _get_api_token(credentials: Optional[str], credentials_path: Optional[str]) -> str:
        """Retrieves an authentication token using Google service account credentials.

        Args:
            credentials (Optional[str]): JSON string of service account credentials.
            credentials_path (Optional[str]): Path to the service account JSON file.

        Returns:
            str: OAuth token for API authentication.

        Raises:
            ValueError: If neither argument is provided, or if ``credentials`` is
                not valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
        """
        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
        creds: Optional[service_account.Credentials] = None

        if credentials:
            # Parse and load credentials from JSON string
            creds = service_account.Credentials.from_service_account_info(
                json.loads(credentials), scopes=scopes
            )
        elif credentials_path:
            # Load credentials from JSON file
            creds = service_account.Credentials.from_service_account_file(
                credentials_path, scopes=scopes
            )

        if creds is None:
            raise ValueError("No valid credentials provided.")

        creds.refresh(Request())  # Ensure token is up-to-date, lifetime is 1 hour.
        return creds.token
class GoogleTTSService(TTSService):
class InputParams(BaseModel):
pitch: Optional[str] = None