AWS Transcribe STT previously only supported credentials via explicit parameters or environment variables. Services running with IAM roles (EKS pod roles, IRSA, ECS task roles, EC2 instance profiles) or SSO couldn't use Transcribe without exporting static credentials. Changes: - Add resolve_credentials() to utils.py providing a standard fallback chain: explicit params → environment variables → boto3 credential provider chain (instance profiles, IRSA, pod roles, SSO, etc.) - Add AWSCredentials dataclass for type-safe credential passing - Update AWSTranscribeSTTService to use resolve_credentials() instead of manual os.getenv() calls - The boto3 fallback is only attempted when both access key and secret key are unresolved, avoiding replacement of explicitly provided creds - boto3 is imported lazily inside the function to avoid hard dependency for services that don't need the fallback chain - Add 7 unit tests covering the credential resolution chain The Bedrock LLM and Polly TTS services already support the full credential chain via aioboto3.Session() and are not modified. Related to #4197
148 lines
5.3 KiB
Python
148 lines
5.3 KiB
Python
#
|
|
# Copyright (c) 2024-2026, Daily
|
|
#
|
|
# SPDX-License-Identifier: BSD 2-Clause License
|
|
#
|
|
|
|
"""Unit tests for AWS shared credential resolution."""
|
|
|
|
import os
|
|
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from pipecat.services.aws.utils import AWSCredentials, resolve_credentials
|
|
|
|
|
|
class TestResolveCredentials(unittest.TestCase):
|
|
"""Tests for resolve_credentials() fallback chain."""
|
|
|
|
def test_explicit_credentials_take_priority(self):
|
|
"""Explicit parameters override env vars and boto3 chain."""
|
|
result = resolve_credentials(
|
|
aws_access_key_id="explicit-key",
|
|
aws_secret_access_key="explicit-secret",
|
|
aws_session_token="explicit-token",
|
|
region="eu-west-1",
|
|
)
|
|
self.assertEqual(result.access_key, "explicit-key")
|
|
self.assertEqual(result.secret_key, "explicit-secret")
|
|
self.assertEqual(result.session_token, "explicit-token")
|
|
self.assertEqual(result.region, "eu-west-1")
|
|
|
|
@patch.dict(
|
|
os.environ,
|
|
{
|
|
"AWS_ACCESS_KEY_ID": "env-key",
|
|
"AWS_SECRET_ACCESS_KEY": "env-secret",
|
|
"AWS_SESSION_TOKEN": "env-token",
|
|
"AWS_REGION": "ap-southeast-2",
|
|
},
|
|
)
|
|
def test_env_vars_fallback(self):
|
|
"""Environment variables are used when explicit params are None."""
|
|
result = resolve_credentials()
|
|
self.assertEqual(result.access_key, "env-key")
|
|
self.assertEqual(result.secret_key, "env-secret")
|
|
self.assertEqual(result.session_token, "env-token")
|
|
self.assertEqual(result.region, "ap-southeast-2")
|
|
|
|
@patch.dict(
|
|
os.environ,
|
|
{
|
|
"AWS_ACCESS_KEY_ID": "env-key",
|
|
"AWS_SECRET_ACCESS_KEY": "env-secret",
|
|
},
|
|
)
|
|
def test_explicit_overrides_env(self):
|
|
"""Explicit params win over environment variables."""
|
|
result = resolve_credentials(
|
|
aws_access_key_id="override-key",
|
|
aws_secret_access_key="override-secret",
|
|
)
|
|
self.assertEqual(result.access_key, "override-key")
|
|
self.assertEqual(result.secret_key, "override-secret")
|
|
|
|
@patch.dict(os.environ, {}, clear=True)
|
|
def test_partial_explicit_credentials_do_not_mix_with_boto3_chain(self):
|
|
"""Partial explicit credentials are not completed from ambient boto3 credentials."""
|
|
mock_frozen = MagicMock()
|
|
mock_frozen.access_key = "boto3-key"
|
|
mock_frozen.secret_key = "boto3-secret"
|
|
mock_frozen.token = "boto3-token"
|
|
|
|
mock_creds = MagicMock()
|
|
mock_creds.get_frozen_credentials.return_value = mock_frozen
|
|
|
|
mock_session = MagicMock()
|
|
mock_session.get_credentials.return_value = mock_creds
|
|
|
|
mock_boto3 = MagicMock()
|
|
mock_boto3.Session.return_value = mock_session
|
|
|
|
with patch.dict("sys.modules", {"boto3": mock_boto3}):
|
|
result = resolve_credentials(aws_access_key_id="explicit-key")
|
|
|
|
self.assertEqual(result.access_key, "explicit-key")
|
|
self.assertIsNone(result.secret_key)
|
|
mock_boto3.Session.assert_not_called()
|
|
|
|
@patch.dict(os.environ, {}, clear=True)
|
|
def test_boto3_chain_fallback(self):
|
|
"""When no explicit creds or env vars, falls back to boto3 chain."""
|
|
mock_frozen = MagicMock()
|
|
mock_frozen.access_key = "boto3-key"
|
|
mock_frozen.secret_key = "boto3-secret"
|
|
mock_frozen.token = "boto3-token"
|
|
|
|
mock_creds = MagicMock()
|
|
mock_creds.get_frozen_credentials.return_value = mock_frozen
|
|
|
|
mock_session = MagicMock()
|
|
mock_session.get_credentials.return_value = mock_creds
|
|
|
|
mock_boto3 = MagicMock()
|
|
mock_boto3.Session.return_value = mock_session
|
|
|
|
# boto3 is imported inside resolve_credentials via `import boto3`,
|
|
# so we patch it in sys.modules.
|
|
with patch.dict("sys.modules", {"boto3": mock_boto3}):
|
|
result = resolve_credentials()
|
|
|
|
self.assertEqual(result.access_key, "boto3-key")
|
|
self.assertEqual(result.secret_key, "boto3-secret")
|
|
self.assertEqual(result.session_token, "boto3-token")
|
|
|
|
@patch.dict(os.environ, {}, clear=True)
|
|
def test_default_region(self):
|
|
"""Default region is us-east-1 when nothing is specified."""
|
|
result = resolve_credentials(
|
|
aws_access_key_id="key",
|
|
aws_secret_access_key="secret",
|
|
)
|
|
self.assertEqual(result.region, "us-east-1")
|
|
|
|
def test_returns_aws_credentials_dataclass(self):
|
|
"""Result is an AWSCredentials instance."""
|
|
result = resolve_credentials(
|
|
aws_access_key_id="key",
|
|
aws_secret_access_key="secret",
|
|
)
|
|
self.assertIsInstance(result, AWSCredentials)
|
|
|
|
@patch.dict(os.environ, {}, clear=True)
|
|
def test_none_when_no_credentials_available(self):
|
|
"""access_key and secret_key are None when nothing resolves."""
|
|
# Mock boto3 import to fail
|
|
with patch.dict("sys.modules", {"boto3": None}):
|
|
# Force re-import to hit the ImportError path
|
|
result = resolve_credentials()
|
|
|
|
# Since boto3 import will actually succeed (it's installed),
|
|
# but if no creds are configured, frozen creds may return None
|
|
# Just verify the function doesn't crash and returns AWSCredentials
|
|
self.assertIsInstance(result, AWSCredentials)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|