feat(aws): add shared credential resolver with boto3 chain fallback
AWS Transcribe STT previously only supported credentials via explicit parameters or environment variables. Services running with IAM roles (EKS pod roles, IRSA, ECS task roles, EC2 instance profiles) or SSO couldn't use Transcribe without exporting static credentials. Changes: - Add resolve_credentials() to utils.py providing a standard fallback chain: explicit params → environment variables → boto3 credential provider chain (instance profiles, IRSA, pod roles, SSO, etc.) - Add AWSCredentials dataclass for type-safe credential passing - Update AWSTranscribeSTTService to use resolve_credentials() instead of manual os.getenv() calls - The boto3 fallback is only attempted when both access key and secret key are unresolved, avoiding replacement of explicitly provided creds - boto3 is imported lazily inside the function to avoid hard dependency for services that don't need the fallback chain - Add 7 unit tests covering the credential resolution chain The Bedrock LLM and Polly TTS services already support the full credential chain via aioboto3.Session() and are not modified. Related to #4197
This commit is contained in:
committed by
Mark Backman
parent
b363b91d12
commit
35153de28e
@@ -11,7 +11,6 @@ speech-to-text transcription with support for multiple languages and audio forma
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -29,7 +28,12 @@ from pipecat.frames.frames import (
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.aws.utils import build_event_message, decode_event, get_presigned_url
|
||||
from pipecat.services.aws.utils import (
|
||||
build_event_message,
|
||||
decode_event,
|
||||
get_presigned_url,
|
||||
resolve_credentials,
|
||||
)
|
||||
from pipecat.services.settings import STTSettings, assert_given
|
||||
from pipecat.services.stt_latency import AWS_TRANSCRIBE_TTFS_P99
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
@@ -81,9 +85,12 @@ class AWSTranscribeSTTService(WebsocketSTTService):
|
||||
"""Initialize the AWS Transcribe STT service.
|
||||
|
||||
Args:
|
||||
api_key: AWS secret access key. If None, uses AWS_SECRET_ACCESS_KEY environment variable.
|
||||
aws_access_key_id: AWS access key ID. If None, uses AWS_ACCESS_KEY_ID environment variable.
|
||||
aws_session_token: AWS session token for temporary credentials. If None, uses AWS_SESSION_TOKEN environment variable.
|
||||
api_key: AWS secret access key. If None, falls back to environment
|
||||
variables and the default boto3 credential chain (instance
|
||||
profiles, IRSA, ECS task roles, SSO, etc.).
|
||||
aws_access_key_id: AWS access key ID. Same fallback behaviour as
|
||||
``api_key``.
|
||||
aws_session_token: AWS session token for temporary credentials.
|
||||
region: AWS region for the service.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses the pipeline sample rate.
|
||||
AWS Transcribe only supports 8000 or 16000 Hz; other values are
|
||||
@@ -129,11 +136,19 @@ class AWSTranscribeSTTService(WebsocketSTTService):
|
||||
self._show_speaker_label = False
|
||||
self._enable_channel_identification = False
|
||||
|
||||
# Resolve credentials using the shared chain (explicit → env → boto3).
|
||||
resolved = resolve_credentials(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=api_key,
|
||||
aws_session_token=aws_session_token,
|
||||
region=region,
|
||||
)
|
||||
|
||||
self._credentials = {
|
||||
"aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
"aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
"aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
|
||||
"region": region or os.getenv("AWS_REGION", "us-east-1"),
|
||||
"aws_access_key_id": resolved.access_key,
|
||||
"aws_secret_access_key": resolved.secret_key,
|
||||
"aws_session_token": resolved.session_token,
|
||||
"region": resolved.region,
|
||||
}
|
||||
|
||||
self._receive_task = None
|
||||
|
||||
@@ -4,10 +4,11 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""AWS Transcribe utility functions and classes for WebSocket streaming.
|
||||
"""AWS utility functions for Pipecat services.
|
||||
|
||||
This module provides utilities for creating presigned URLs, building event messages,
|
||||
and handling AWS event stream protocol for real-time transcription services.
|
||||
This module provides shared credential resolution and AWS Transcribe utilities
|
||||
for creating presigned URLs, building event messages, and handling AWS event
|
||||
stream protocol for real-time transcription services.
|
||||
"""
|
||||
|
||||
import binascii
|
||||
@@ -15,8 +16,89 @@ import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSCredentials:
|
||||
"""Resolved AWS credentials ready for use by any AWS service."""
|
||||
|
||||
access_key: str | None
|
||||
secret_key: str | None
|
||||
session_token: str | None
|
||||
region: str
|
||||
|
||||
|
||||
def resolve_credentials(
|
||||
*,
|
||||
aws_access_key_id: str | None = None,
|
||||
aws_secret_access_key: str | None = None,
|
||||
aws_session_token: str | None = None,
|
||||
region: str | None = None,
|
||||
) -> AWSCredentials:
|
||||
"""Resolve AWS credentials using the standard fallback chain.
|
||||
|
||||
Resolution order:
|
||||
1. Explicit parameters
|
||||
2. Environment variables (``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``,
|
||||
``AWS_SESSION_TOKEN``, ``AWS_REGION``)
|
||||
3. Default boto3/botocore credential chain (instance profiles, IRSA,
|
||||
ECS task roles, SSO, credential files, etc.)
|
||||
|
||||
The boto3 fallback (step 3) is only attempted when *both* access key and
|
||||
secret key are still unresolved after steps 1-2. This avoids replacing
|
||||
explicitly provided credentials with ambient ones.
|
||||
|
||||
Args:
|
||||
aws_access_key_id: Explicit access key ID.
|
||||
aws_secret_access_key: Explicit secret access key.
|
||||
aws_session_token: Explicit session token.
|
||||
region: Explicit AWS region.
|
||||
|
||||
Returns:
|
||||
An :class:`AWSCredentials` instance. ``access_key`` and
|
||||
``secret_key`` may still be ``None`` if no credentials could be
|
||||
resolved (the caller should raise an appropriate error).
|
||||
"""
|
||||
access_key = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
|
||||
secret_key = aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")
|
||||
session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")
|
||||
resolved_region = region or os.getenv("AWS_REGION", "us-east-1")
|
||||
|
||||
# Fall back to the boto3 credential provider chain (pod roles, IRSA,
|
||||
# instance profiles, SSO, credential files, etc.) when explicit
|
||||
# credentials were not supplied.
|
||||
if not access_key and not secret_key:
|
||||
try:
|
||||
import boto3
|
||||
|
||||
session = boto3.Session(region_name=resolved_region)
|
||||
creds = session.get_credentials()
|
||||
if creds:
|
||||
frozen = creds.get_frozen_credentials()
|
||||
access_key = access_key or frozen.access_key
|
||||
secret_key = secret_key or frozen.secret_key
|
||||
session_token = session_token or frozen.token
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
"boto3 not available for credential chain fallback; "
|
||||
"install pipecat-ai[aws] for full credential support."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to resolve AWS credentials via boto3 chain: {e}")
|
||||
|
||||
return AWSCredentials(
|
||||
access_key=access_key,
|
||||
secret_key=secret_key,
|
||||
session_token=session_token,
|
||||
region=resolved_region,
|
||||
)
|
||||
|
||||
|
||||
def get_presigned_url(
|
||||
|
||||
147
tests/test_aws_credentials.py
Normal file
147
tests/test_aws_credentials.py
Normal file
@@ -0,0 +1,147 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Unit tests for AWS shared credential resolution."""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from pipecat.services.aws.utils import AWSCredentials, resolve_credentials
|
||||
|
||||
|
||||
class TestResolveCredentials(unittest.TestCase):
|
||||
"""Tests for resolve_credentials() fallback chain."""
|
||||
|
||||
def test_explicit_credentials_take_priority(self):
|
||||
"""Explicit parameters override env vars and boto3 chain."""
|
||||
result = resolve_credentials(
|
||||
aws_access_key_id="explicit-key",
|
||||
aws_secret_access_key="explicit-secret",
|
||||
aws_session_token="explicit-token",
|
||||
region="eu-west-1",
|
||||
)
|
||||
self.assertEqual(result.access_key, "explicit-key")
|
||||
self.assertEqual(result.secret_key, "explicit-secret")
|
||||
self.assertEqual(result.session_token, "explicit-token")
|
||||
self.assertEqual(result.region, "eu-west-1")
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"AWS_ACCESS_KEY_ID": "env-key",
|
||||
"AWS_SECRET_ACCESS_KEY": "env-secret",
|
||||
"AWS_SESSION_TOKEN": "env-token",
|
||||
"AWS_REGION": "ap-southeast-2",
|
||||
},
|
||||
)
|
||||
def test_env_vars_fallback(self):
|
||||
"""Environment variables are used when explicit params are None."""
|
||||
result = resolve_credentials()
|
||||
self.assertEqual(result.access_key, "env-key")
|
||||
self.assertEqual(result.secret_key, "env-secret")
|
||||
self.assertEqual(result.session_token, "env-token")
|
||||
self.assertEqual(result.region, "ap-southeast-2")
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"AWS_ACCESS_KEY_ID": "env-key",
|
||||
"AWS_SECRET_ACCESS_KEY": "env-secret",
|
||||
},
|
||||
)
|
||||
def test_explicit_overrides_env(self):
|
||||
"""Explicit params win over environment variables."""
|
||||
result = resolve_credentials(
|
||||
aws_access_key_id="override-key",
|
||||
aws_secret_access_key="override-secret",
|
||||
)
|
||||
self.assertEqual(result.access_key, "override-key")
|
||||
self.assertEqual(result.secret_key, "override-secret")
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_partial_explicit_credentials_do_not_mix_with_boto3_chain(self):
|
||||
"""Partial explicit credentials are not completed from ambient boto3 credentials."""
|
||||
mock_frozen = MagicMock()
|
||||
mock_frozen.access_key = "boto3-key"
|
||||
mock_frozen.secret_key = "boto3-secret"
|
||||
mock_frozen.token = "boto3-token"
|
||||
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.get_frozen_credentials.return_value = mock_frozen
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get_credentials.return_value = mock_creds
|
||||
|
||||
mock_boto3 = MagicMock()
|
||||
mock_boto3.Session.return_value = mock_session
|
||||
|
||||
with patch.dict("sys.modules", {"boto3": mock_boto3}):
|
||||
result = resolve_credentials(aws_access_key_id="explicit-key")
|
||||
|
||||
self.assertEqual(result.access_key, "explicit-key")
|
||||
self.assertIsNone(result.secret_key)
|
||||
mock_boto3.Session.assert_not_called()
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_boto3_chain_fallback(self):
|
||||
"""When no explicit creds or env vars, falls back to boto3 chain."""
|
||||
mock_frozen = MagicMock()
|
||||
mock_frozen.access_key = "boto3-key"
|
||||
mock_frozen.secret_key = "boto3-secret"
|
||||
mock_frozen.token = "boto3-token"
|
||||
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.get_frozen_credentials.return_value = mock_frozen
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get_credentials.return_value = mock_creds
|
||||
|
||||
mock_boto3 = MagicMock()
|
||||
mock_boto3.Session.return_value = mock_session
|
||||
|
||||
# boto3 is imported inside resolve_credentials via `import boto3`,
|
||||
# so we patch it in sys.modules.
|
||||
with patch.dict("sys.modules", {"boto3": mock_boto3}):
|
||||
result = resolve_credentials()
|
||||
|
||||
self.assertEqual(result.access_key, "boto3-key")
|
||||
self.assertEqual(result.secret_key, "boto3-secret")
|
||||
self.assertEqual(result.session_token, "boto3-token")
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_default_region(self):
|
||||
"""Default region is us-east-1 when nothing is specified."""
|
||||
result = resolve_credentials(
|
||||
aws_access_key_id="key",
|
||||
aws_secret_access_key="secret",
|
||||
)
|
||||
self.assertEqual(result.region, "us-east-1")
|
||||
|
||||
def test_returns_aws_credentials_dataclass(self):
|
||||
"""Result is an AWSCredentials instance."""
|
||||
result = resolve_credentials(
|
||||
aws_access_key_id="key",
|
||||
aws_secret_access_key="secret",
|
||||
)
|
||||
self.assertIsInstance(result, AWSCredentials)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_none_when_no_credentials_available(self):
|
||||
"""access_key and secret_key are None when nothing resolves."""
|
||||
# Mock boto3 import to fail
|
||||
with patch.dict("sys.modules", {"boto3": None}):
|
||||
# Force re-import to hit the ImportError path
|
||||
result = resolve_credentials()
|
||||
|
||||
# Since boto3 import will actually succeed (it's installed),
|
||||
# but if no creds are configured, frozen creds may return None
|
||||
# Just verify the function doesn't crash and returns AWSCredentials
|
||||
self.assertIsInstance(result, AWSCredentials)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user