feat(aws): add shared credential resolver with boto3 chain fallback

AWS Transcribe STT previously only supported credentials via explicit
parameters or environment variables. Services running with IAM roles
(EKS pod roles, IRSA, ECS task roles, EC2 instance profiles) or SSO
couldn't use Transcribe without exporting static credentials.

Changes:
- Add resolve_credentials() to utils.py providing a standard fallback
  chain: explicit params → environment variables → boto3 credential
  provider chain (instance profiles, IRSA, pod roles, SSO, etc.)
- Add AWSCredentials dataclass for type-safe credential passing
- Update AWSTranscribeSTTService to use resolve_credentials() instead
  of manual os.getenv() calls
- The boto3 fallback is only attempted when both access key and secret
  key are unresolved, avoiding replacement of explicitly provided creds
- boto3 is imported lazily inside the function to avoid hard dependency
  for services that don't need the fallback chain
- Add 7 unit tests covering the credential resolution chain

The Bedrock LLM and Polly TTS services already support the full
credential chain via aioboto3.Session() and are not modified.

Related to #4197
This commit is contained in:
Daniel Wirjo
2026-04-19 12:55:41 +00:00
committed by Mark Backman
parent b363b91d12
commit 35153de28e
3 changed files with 256 additions and 12 deletions

View File

@@ -11,7 +11,6 @@ speech-to-text transcription with support for multiple languages and audio forma
"""
import json
import os
import random
import string
from collections.abc import AsyncGenerator
@@ -29,7 +28,12 @@ from pipecat.frames.frames import (
StartFrame,
TranscriptionFrame,
)
from pipecat.services.aws.utils import build_event_message, decode_event, get_presigned_url
from pipecat.services.aws.utils import (
build_event_message,
decode_event,
get_presigned_url,
resolve_credentials,
)
from pipecat.services.settings import STTSettings, assert_given
from pipecat.services.stt_latency import AWS_TRANSCRIBE_TTFS_P99
from pipecat.services.stt_service import WebsocketSTTService
@@ -81,9 +85,12 @@ class AWSTranscribeSTTService(WebsocketSTTService):
"""Initialize the AWS Transcribe STT service.
Args:
api_key: AWS secret access key. If None, uses AWS_SECRET_ACCESS_KEY environment variable.
aws_access_key_id: AWS access key ID. If None, uses AWS_ACCESS_KEY_ID environment variable.
aws_session_token: AWS session token for temporary credentials. If None, uses AWS_SESSION_TOKEN environment variable.
api_key: AWS secret access key. If None, falls back to environment
variables and the default boto3 credential chain (instance
profiles, IRSA, ECS task roles, SSO, etc.).
aws_access_key_id: AWS access key ID. Same fallback behaviour as
``api_key``.
aws_session_token: AWS session token for temporary credentials.
region: AWS region for the service.
sample_rate: Audio sample rate in Hz. If None, uses the pipeline sample rate.
AWS Transcribe only supports 8000 or 16000 Hz; other values are
@@ -129,11 +136,19 @@ class AWSTranscribeSTTService(WebsocketSTTService):
self._show_speaker_label = False
self._enable_channel_identification = False
# Resolve credentials using the shared chain (explicit → env → boto3).
resolved = resolve_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=api_key,
aws_session_token=aws_session_token,
region=region,
)
self._credentials = {
"aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
"aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
"region": region or os.getenv("AWS_REGION", "us-east-1"),
"aws_access_key_id": resolved.access_key,
"aws_secret_access_key": resolved.secret_key,
"aws_session_token": resolved.session_token,
"region": resolved.region,
}
self._receive_task = None

View File

@@ -4,10 +4,11 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""AWS Transcribe utility functions and classes for WebSocket streaming.
"""AWS utility functions for Pipecat services.
This module provides utilities for creating presigned URLs, building event messages,
and handling AWS event stream protocol for real-time transcription services.
This module provides shared credential resolution and AWS Transcribe utilities
for creating presigned URLs, building event messages, and handling AWS event
stream protocol for real-time transcription services.
"""
import binascii
@@ -15,8 +16,89 @@ import datetime
import hashlib
import hmac
import json
import os
import struct
import urllib.parse
from dataclasses import dataclass
from typing import Any
from loguru import logger
@dataclass
class AWSCredentials:
"""Resolved AWS credentials ready for use by any AWS service."""
access_key: str | None
secret_key: str | None
session_token: str | None
region: str
def resolve_credentials(
*,
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
aws_session_token: str | None = None,
region: str | None = None,
) -> AWSCredentials:
"""Resolve AWS credentials using the standard fallback chain.
Resolution order:
1. Explicit parameters
2. Environment variables (``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``,
``AWS_SESSION_TOKEN``, ``AWS_REGION``)
3. Default boto3/botocore credential chain (instance profiles, IRSA,
ECS task roles, SSO, credential files, etc.)
The boto3 fallback (step 3) is only attempted when *both* access key and
secret key are still unresolved after steps 1-2. This avoids replacing
explicitly provided credentials with ambient ones.
Args:
aws_access_key_id: Explicit access key ID.
aws_secret_access_key: Explicit secret access key.
aws_session_token: Explicit session token.
region: Explicit AWS region.
Returns:
An :class:`AWSCredentials` instance. ``access_key`` and
``secret_key`` may still be ``None`` if no credentials could be
resolved (the caller should raise an appropriate error).
"""
access_key = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
secret_key = aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")
session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")
resolved_region = region or os.getenv("AWS_REGION", "us-east-1")
# Fall back to the boto3 credential provider chain (pod roles, IRSA,
# instance profiles, SSO, credential files, etc.) when explicit
# credentials were not supplied.
if not access_key and not secret_key:
try:
import boto3
session = boto3.Session(region_name=resolved_region)
creds = session.get_credentials()
if creds:
frozen = creds.get_frozen_credentials()
access_key = access_key or frozen.access_key
secret_key = secret_key or frozen.secret_key
session_token = session_token or frozen.token
except ImportError:
logger.debug(
"boto3 not available for credential chain fallback; "
"install pipecat-ai[aws] for full credential support."
)
except Exception as e:
logger.warning(f"Failed to resolve AWS credentials via boto3 chain: {e}")
return AWSCredentials(
access_key=access_key,
secret_key=secret_key,
session_token=session_token,
region=resolved_region,
)
def get_presigned_url(

View File

@@ -0,0 +1,147 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Unit tests for AWS shared credential resolution."""
import os
import unittest
from unittest.mock import MagicMock, patch
from pipecat.services.aws.utils import AWSCredentials, resolve_credentials
class TestResolveCredentials(unittest.TestCase):
"""Tests for resolve_credentials() fallback chain."""
def test_explicit_credentials_take_priority(self):
"""Explicit parameters override env vars and boto3 chain."""
result = resolve_credentials(
aws_access_key_id="explicit-key",
aws_secret_access_key="explicit-secret",
aws_session_token="explicit-token",
region="eu-west-1",
)
self.assertEqual(result.access_key, "explicit-key")
self.assertEqual(result.secret_key, "explicit-secret")
self.assertEqual(result.session_token, "explicit-token")
self.assertEqual(result.region, "eu-west-1")
@patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "env-key",
"AWS_SECRET_ACCESS_KEY": "env-secret",
"AWS_SESSION_TOKEN": "env-token",
"AWS_REGION": "ap-southeast-2",
},
)
def test_env_vars_fallback(self):
"""Environment variables are used when explicit params are None."""
result = resolve_credentials()
self.assertEqual(result.access_key, "env-key")
self.assertEqual(result.secret_key, "env-secret")
self.assertEqual(result.session_token, "env-token")
self.assertEqual(result.region, "ap-southeast-2")
@patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "env-key",
"AWS_SECRET_ACCESS_KEY": "env-secret",
},
)
def test_explicit_overrides_env(self):
"""Explicit params win over environment variables."""
result = resolve_credentials(
aws_access_key_id="override-key",
aws_secret_access_key="override-secret",
)
self.assertEqual(result.access_key, "override-key")
self.assertEqual(result.secret_key, "override-secret")
@patch.dict(os.environ, {}, clear=True)
def test_partial_explicit_credentials_do_not_mix_with_boto3_chain(self):
"""Partial explicit credentials are not completed from ambient boto3 credentials."""
mock_frozen = MagicMock()
mock_frozen.access_key = "boto3-key"
mock_frozen.secret_key = "boto3-secret"
mock_frozen.token = "boto3-token"
mock_creds = MagicMock()
mock_creds.get_frozen_credentials.return_value = mock_frozen
mock_session = MagicMock()
mock_session.get_credentials.return_value = mock_creds
mock_boto3 = MagicMock()
mock_boto3.Session.return_value = mock_session
with patch.dict("sys.modules", {"boto3": mock_boto3}):
result = resolve_credentials(aws_access_key_id="explicit-key")
self.assertEqual(result.access_key, "explicit-key")
self.assertIsNone(result.secret_key)
mock_boto3.Session.assert_not_called()
@patch.dict(os.environ, {}, clear=True)
def test_boto3_chain_fallback(self):
"""When no explicit creds or env vars, falls back to boto3 chain."""
mock_frozen = MagicMock()
mock_frozen.access_key = "boto3-key"
mock_frozen.secret_key = "boto3-secret"
mock_frozen.token = "boto3-token"
mock_creds = MagicMock()
mock_creds.get_frozen_credentials.return_value = mock_frozen
mock_session = MagicMock()
mock_session.get_credentials.return_value = mock_creds
mock_boto3 = MagicMock()
mock_boto3.Session.return_value = mock_session
# boto3 is imported inside resolve_credentials via `import boto3`,
# so we patch it in sys.modules.
with patch.dict("sys.modules", {"boto3": mock_boto3}):
result = resolve_credentials()
self.assertEqual(result.access_key, "boto3-key")
self.assertEqual(result.secret_key, "boto3-secret")
self.assertEqual(result.session_token, "boto3-token")
@patch.dict(os.environ, {}, clear=True)
def test_default_region(self):
"""Default region is us-east-1 when nothing is specified."""
result = resolve_credentials(
aws_access_key_id="key",
aws_secret_access_key="secret",
)
self.assertEqual(result.region, "us-east-1")
def test_returns_aws_credentials_dataclass(self):
"""Result is an AWSCredentials instance."""
result = resolve_credentials(
aws_access_key_id="key",
aws_secret_access_key="secret",
)
self.assertIsInstance(result, AWSCredentials)
@patch.dict(os.environ, {}, clear=True)
def test_none_when_no_credentials_available(self):
"""access_key and secret_key are None when nothing resolves."""
# Mock boto3 import to fail
with patch.dict("sys.modules", {"boto3": None}):
# Force re-import to hit the ImportError path
result = resolve_credentials()
# Since boto3 import will actually succeed (it's installed),
# but if no creds are configured, frozen creds may return None
# Just verify the function doesn't crash and returns AWSCredentials
self.assertIsInstance(result, AWSCredentials)
if __name__ == "__main__":
unittest.main()