Merge pull request #2662 from pelguetat/fix-vertex-ai-global-location-support

feat: add support for global location in Vertex AI base URL
This commit is contained in:
Mark Backman
2025-09-18 10:25:10 -07:00
committed by GitHub
2 changed files with 20 additions and 4 deletions

View File

@@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added support for global location in `GoogleVertexLLMService`. The service now
supports both regional locations (e.g., "us-east4") and the "global" location
for Vertex AI endpoints. When using "global" location, the service will use
`aiplatform.googleapis.com` as the API host instead of the regional format.
- Added `on_pipeline_finished` event to `PipelineTask`. This event will get
fired when the pipeline is done running. This can be the result of a
`StopFrame`, `CancelFrame` or `EndFrame`.

View File

@@ -83,14 +83,23 @@ class GoogleVertexLLMService(OpenAILLMService):
self._api_key = self._get_api_token(credentials, credentials_path)
super().__init__(
api_key=self._api_key, base_url=base_url, model=model, params=params, **kwargs
api_key=self._api_key,
base_url=base_url,
model=model,
params=params,
**kwargs,
)
@staticmethod
def _get_base_url(params: InputParams) -> str:
"""Construct the base URL for Vertex AI API."""
# Determine the correct API host based on location
if params.location == "global":
api_host = "aiplatform.googleapis.com"
else:
api_host = f"{params.location}-aiplatform.googleapis.com"
return (
f"https://{params.location}-aiplatform.googleapis.com/v1/"
f"https://{api_host}/v1/"
f"projects/{params.project_id}/locations/{params.location}/endpoints/openapi"
)
@@ -118,12 +127,14 @@ class GoogleVertexLLMService(OpenAILLMService):
if credentials:
# Parse and load credentials from JSON string
creds = service_account.Credentials.from_service_account_info(
json.loads(credentials), scopes=["https://www.googleapis.com/auth/cloud-platform"]
json.loads(credentials),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
elif credentials_path:
# Load credentials from JSON file
creds = service_account.Credentials.from_service_account_file(
credentials_path, scopes=["https://www.googleapis.com/auth/cloud-platform"]
credentials_path,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
try: