diff --git a/src/pipecat/services/aws/stt.py b/src/pipecat/services/aws/stt.py index c482b1b00..6a427d707 100644 --- a/src/pipecat/services/aws/stt.py +++ b/src/pipecat/services/aws/stt.py @@ -11,7 +11,6 @@ speech-to-text transcription with support for multiple languages and audio forma """ import json -import os import random import string from collections.abc import AsyncGenerator @@ -29,7 +28,12 @@ from pipecat.frames.frames import ( StartFrame, TranscriptionFrame, ) -from pipecat.services.aws.utils import build_event_message, decode_event, get_presigned_url +from pipecat.services.aws.utils import ( + build_event_message, + decode_event, + get_presigned_url, + resolve_credentials, +) from pipecat.services.settings import STTSettings, assert_given from pipecat.services.stt_latency import AWS_TRANSCRIBE_TTFS_P99 from pipecat.services.stt_service import WebsocketSTTService @@ -81,9 +85,12 @@ class AWSTranscribeSTTService(WebsocketSTTService): """Initialize the AWS Transcribe STT service. Args: - api_key: AWS secret access key. If None, uses AWS_SECRET_ACCESS_KEY environment variable. - aws_access_key_id: AWS access key ID. If None, uses AWS_ACCESS_KEY_ID environment variable. - aws_session_token: AWS session token for temporary credentials. If None, uses AWS_SESSION_TOKEN environment variable. + api_key: AWS secret access key. If None, falls back to environment + variables and the default boto3 credential chain (instance + profiles, IRSA, ECS task roles, SSO, etc.). + aws_access_key_id: AWS access key ID. Same fallback behaviour as + ``api_key``. + aws_session_token: AWS session token for temporary credentials. region: AWS region for the service. sample_rate: Audio sample rate in Hz. If None, uses the pipeline sample rate. AWS Transcribe only supports 8000 or 16000 Hz; other values are @@ -129,11 +136,19 @@ class AWSTranscribeSTTService(WebsocketSTTService): self._show_speaker_label = False self._enable_channel_identification = False + # Resolve credentials using the shared chain (explicit → env → boto3). + resolved = resolve_credentials( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=api_key, + aws_session_token=aws_session_token, + region=region, + ) + self._credentials = { - "aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"), - "aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"), - "aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"), - "region": region or os.getenv("AWS_REGION", "us-east-1"), + "aws_access_key_id": resolved.access_key, + "aws_secret_access_key": resolved.secret_key, + "aws_session_token": resolved.session_token, + "region": resolved.region, } self._receive_task = None diff --git a/src/pipecat/services/aws/utils.py b/src/pipecat/services/aws/utils.py index c14f1f20d..e48f0dd10 100644 --- a/src/pipecat/services/aws/utils.py +++ b/src/pipecat/services/aws/utils.py @@ -4,10 +4,11 @@ # SPDX-License-Identifier: BSD 2-Clause License # -"""AWS Transcribe utility functions and classes for WebSocket streaming. +"""AWS utility functions for Pipecat services. -This module provides utilities for creating presigned URLs, building event messages, -and handling AWS event stream protocol for real-time transcription services. +This module provides shared credential resolution and AWS Transcribe utilities +for creating presigned URLs, building event messages, and handling AWS event +stream protocol for real-time transcription services. """ import binascii @@ -15,8 +16,89 @@ import datetime import hashlib import hmac import json +import os import struct import urllib.parse +from dataclasses import dataclass +from typing import Any + +from loguru import logger + + +@dataclass +class AWSCredentials: + """Resolved AWS credentials ready for use by any AWS service.""" + + access_key: str | None + secret_key: str | None + session_token: str | None + region: str + + +def resolve_credentials( + *, + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + aws_session_token: str | None = None, + region: str | None = None, +) -> AWSCredentials: + """Resolve AWS credentials using the standard fallback chain. + + Resolution order: + 1. Explicit parameters + 2. Environment variables (``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``, + ``AWS_SESSION_TOKEN``, ``AWS_REGION``) + 3. Default boto3/botocore credential chain (instance profiles, IRSA, + ECS task roles, SSO, credential files, etc.) + + The boto3 fallback (step 3) is only attempted when *both* access key and + secret key are still unresolved after steps 1-2. This avoids replacing + explicitly provided credentials with ambient ones. + + Args: + aws_access_key_id: Explicit access key ID. + aws_secret_access_key: Explicit secret access key. + aws_session_token: Explicit session token. + region: Explicit AWS region. + + Returns: + An :class:`AWSCredentials` instance. ``access_key`` and + ``secret_key`` may still be ``None`` if no credentials could be + resolved (the caller should raise an appropriate error). + """ + access_key = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID") + secret_key = aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY") + session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN") + resolved_region = region or os.getenv("AWS_REGION", "us-east-1") + + # Fall back to the boto3 credential provider chain (pod roles, IRSA, + # instance profiles, SSO, credential files, etc.) when explicit + # credentials were not supplied. + if not access_key and not secret_key: + try: + import boto3 + + session = boto3.Session(region_name=resolved_region) + creds = session.get_credentials() + if creds: + frozen = creds.get_frozen_credentials() + access_key = access_key or frozen.access_key + secret_key = secret_key or frozen.secret_key + session_token = session_token or frozen.token + except ImportError: + logger.debug( + "boto3 not available for credential chain fallback; " + "install pipecat-ai[aws] for full credential support." + ) + except Exception as e: + logger.warning(f"Failed to resolve AWS credentials via boto3 chain: {e}") + + return AWSCredentials( + access_key=access_key, + secret_key=secret_key, + session_token=session_token, + region=resolved_region, + ) def get_presigned_url( diff --git a/tests/test_aws_credentials.py b/tests/test_aws_credentials.py new file mode 100644 index 000000000..112347697 --- /dev/null +++ b/tests/test_aws_credentials.py @@ -0,0 +1,147 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Unit tests for AWS shared credential resolution.""" + +import os +import unittest +from unittest.mock import MagicMock, patch + +from pipecat.services.aws.utils import AWSCredentials, resolve_credentials + + +class TestResolveCredentials(unittest.TestCase): + """Tests for resolve_credentials() fallback chain.""" + + def test_explicit_credentials_take_priority(self): + """Explicit parameters override env vars and boto3 chain.""" + result = resolve_credentials( + aws_access_key_id="explicit-key", + aws_secret_access_key="explicit-secret", + aws_session_token="explicit-token", + region="eu-west-1", + ) + self.assertEqual(result.access_key, "explicit-key") + self.assertEqual(result.secret_key, "explicit-secret") + self.assertEqual(result.session_token, "explicit-token") + self.assertEqual(result.region, "eu-west-1") + + @patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "env-key", + "AWS_SECRET_ACCESS_KEY": "env-secret", + "AWS_SESSION_TOKEN": "env-token", + "AWS_REGION": "ap-southeast-2", + }, + ) + def test_env_vars_fallback(self): + """Environment variables are used when explicit params are None.""" + result = resolve_credentials() + self.assertEqual(result.access_key, "env-key") + self.assertEqual(result.secret_key, "env-secret") + self.assertEqual(result.session_token, "env-token") + self.assertEqual(result.region, "ap-southeast-2") + + @patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "env-key", + "AWS_SECRET_ACCESS_KEY": "env-secret", + }, + ) + def test_explicit_overrides_env(self): + """Explicit params win over environment variables.""" + result = resolve_credentials( + aws_access_key_id="override-key", + aws_secret_access_key="override-secret", + ) + self.assertEqual(result.access_key, "override-key") + self.assertEqual(result.secret_key, "override-secret") + + @patch.dict(os.environ, {}, clear=True) + def test_partial_explicit_credentials_do_not_mix_with_boto3_chain(self): + """Partial explicit credentials are not completed from ambient boto3 credentials.""" + mock_frozen = MagicMock() + mock_frozen.access_key = "boto3-key" + mock_frozen.secret_key = "boto3-secret" + mock_frozen.token = "boto3-token" + + mock_creds = MagicMock() + mock_creds.get_frozen_credentials.return_value = mock_frozen + + mock_session = MagicMock() + mock_session.get_credentials.return_value = mock_creds + + mock_boto3 = MagicMock() + mock_boto3.Session.return_value = mock_session + + with patch.dict("sys.modules", {"boto3": mock_boto3}): + result = resolve_credentials(aws_access_key_id="explicit-key") + + self.assertEqual(result.access_key, "explicit-key") + self.assertIsNone(result.secret_key) + mock_boto3.Session.assert_not_called() + + @patch.dict(os.environ, {}, clear=True) + def test_boto3_chain_fallback(self): + """When no explicit creds or env vars, falls back to boto3 chain.""" + mock_frozen = MagicMock() + mock_frozen.access_key = "boto3-key" + mock_frozen.secret_key = "boto3-secret" + mock_frozen.token = "boto3-token" + + mock_creds = MagicMock() + mock_creds.get_frozen_credentials.return_value = mock_frozen + + mock_session = MagicMock() + mock_session.get_credentials.return_value = mock_creds + + mock_boto3 = MagicMock() + mock_boto3.Session.return_value = mock_session + + # boto3 is imported inside resolve_credentials via `import boto3`, + # so we patch it in sys.modules. + with patch.dict("sys.modules", {"boto3": mock_boto3}): + result = resolve_credentials() + + self.assertEqual(result.access_key, "boto3-key") + self.assertEqual(result.secret_key, "boto3-secret") + self.assertEqual(result.session_token, "boto3-token") + + @patch.dict(os.environ, {}, clear=True) + def test_default_region(self): + """Default region is us-east-1 when nothing is specified.""" + result = resolve_credentials( + aws_access_key_id="key", + aws_secret_access_key="secret", + ) + self.assertEqual(result.region, "us-east-1") + + def test_returns_aws_credentials_dataclass(self): + """Result is an AWSCredentials instance.""" + result = resolve_credentials( + aws_access_key_id="key", + aws_secret_access_key="secret", + ) + self.assertIsInstance(result, AWSCredentials) + + @patch.dict(os.environ, {}, clear=True) + def test_none_when_no_credentials_available(self): + """access_key and secret_key are None when nothing resolves.""" + # Mock boto3 import to fail + with patch.dict("sys.modules", {"boto3": None}): + # Force re-import to hit the ImportError path + result = resolve_credentials() + + # Since boto3 import will actually succeed (it's installed), + # but if no creds are configured, frozen creds may return None + # Just verify the function doesn't crash and returns AWSCredentials + self.assertIsInstance(result, AWSCredentials) + + +if __name__ == "__main__": + unittest.main()