diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e330b9c9..2eec61bca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added new AWS services `AWSBedrockLLMService` and `AWSTranscribeSTTService`. + - Added `on_active_speaker_changed` event handler to the `DailyTransport` class. - Added `enable_ssml_parsing` and `enable_logging` to `InputParams` in @@ -25,6 +27,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Deprecated +- `PollyTTSService` is now deprecated, use `AWSPollyTTSService` instead. + - Observer `on_push_frame(src, dst, frame, direction, timestamp)` is now deprecated, use `on_push_frame(data: FramePushed)` instead. diff --git a/README.md b/README.md index 47be9b6e1..ec3b0a791 100644 --- a/README.md +++ b/README.md @@ -49,18 +49,18 @@ You can connect to Pipecat from any platform using our official SDKs: ## 🧩 Available services -| Category | Services | -| ------------------- | 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) | -| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks 
AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [Together AI](https://docs.pipecat.ai/server/services/llm/together) | -| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) | -| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) | -| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), 
[SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local | -| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) | -| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) | -| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) | -| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) | -| Analytics & Metrics | [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) | +| Category | Services | 
+|---------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) | +| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), 
[Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [Together AI](https://docs.pipecat.ai/server/services/llm/together) | +| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) | +| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI 
Realtime](https://docs.pipecat.ai/server/services/s2s/openai) | +| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local | +| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) | +| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) | +| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/google-imagen), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) | +| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) | +| Analytics & Metrics | [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) | 📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services) diff --git a/examples/foundational/07m-interruptible-polly.py b/examples/foundational/07m-interruptible-aws.py similarity index 79% rename from examples/foundational/07m-interruptible-polly.py rename to examples/foundational/07m-interruptible-aws.py index 286fe5128..bbcfe7313 100644 --- a/examples/foundational/07m-interruptible-polly.py +++ b/examples/foundational/07m-interruptible-aws.py @@ -5,7 +5,6 @@ # import argparse -import os from dotenv import load_dotenv from loguru import logger @@ -15,9 +14,9 @@ from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import 
PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.services.aws.tts import PollyTTSService -from pipecat.services.deepgram.stt import DeepgramSTTService -from pipecat.services.openai.llm import OpenAILLMService +from pipecat.services.aws.llm import AWSBedrockLLMService +from pipecat.services.aws.stt import AWSTranscribeSTTService +from pipecat.services.aws.tts import AWSPollyTTSService from pipecat.transports.base_transport import TransportParams from pipecat.transports.network.small_webrtc import SmallWebRTCTransport from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection @@ -37,17 +36,19 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac ), ) - stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + stt = AWSTranscribeSTTService() - tts = PollyTTSService( - api_key=os.getenv("AWS_SECRET_ACCESS_KEY"), - aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"), - region=os.getenv("AWS_REGION"), - voice_id="Amy", - params=PollyTTSService.InputParams(engine="neural", language="en-GB", rate="1.05"), + tts = AWSPollyTTSService( + region="us-west-2", # only specific regions support generative TTS + voice_id="Joanna", + params=AWSPollyTTSService.InputParams(engine="generative", rate="1.1"), ) - llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY")) + llm = AWSBedrockLLMService( + aws_region="us-west-2", + model="us.anthropic.claude-3-5-haiku-20241022-v1:0", + params=AWSBedrockLLMService.InputParams(temperature=0.8, latency="optimized"), + ) messages = [ { @@ -85,7 +86,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac async def on_client_connected(transport, client): logger.info(f"Client connected") # Kick off the conversation. 
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
    """Run the AWS function-calling example bot on a SmallWebRTC connection.

    Wires AWS Transcribe (STT), AWS Bedrock (LLM) and AWS Polly (TTS) into a
    single pipeline, registers the ``get_current_weather`` tool with the LLM
    and runs the pipeline until the task finishes or is cancelled.

    Args:
        webrtc_connection: The connection the transport reads audio from and
            writes audio to.
        _: Parsed command line arguments (unused).
    """
    logger.info("Starting bot")

    transport_params = TransportParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        vad_analyzer=SileroVADAnalyzer(),
    )
    transport = SmallWebRTCTransport(
        webrtc_connection=webrtc_connection,
        params=transport_params,
    )

    stt = AWSTranscribeSTTService()

    polly_params = AWSPollyTTSService.InputParams(engine="generative", rate="1.1")
    tts = AWSPollyTTSService(
        region="us-west-2",  # only specific regions support generative TTS
        voice_id="Joanna",
        params=polly_params,
    )

    bedrock_params = AWSBedrockLLMService.InputParams(temperature=0.8, latency="optimized")
    llm = AWSBedrockLLMService(
        aws_region="us-west-2",
        model="us.anthropic.claude-3-5-haiku-20241022-v1:0",
        params=bedrock_params,
    )

    # You can also register a function_name of None to get all functions
    # sent to the same callback with an additional function_name parameter.
    llm.register_function("get_current_weather", fetch_weather_from_api)

    # Schema advertised to the model for the tool registered above.
    weather_properties = {
        "location": {
            "type": "string",
            "description": "The city and state, e.g. San Francisco, CA",
        },
        "format": {
            "type": "string",
            "enum": ["celsius", "fahrenheit"],
            "description": "The temperature unit to use. Infer this from the user's location.",
        },
    }
    weather_function = FunctionSchema(
        name="get_current_weather",
        description="Get the current weather",
        properties=weather_properties,
        required=["location", "format"],
    )
    tools = ToolsSchema(standard_tools=[weather_function])

    system_prompt = {
        "role": "system",
        "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
    }
    messages = [system_prompt]

    context = OpenAILLMContext(messages, tools)
    context_aggregator = llm.create_context_aggregator(context)

    # Processors are listed in the order frames travel through them.
    processors = [
        transport.input(),
        stt,
        context_aggregator.user(),
        llm,
        tts,
        transport.output(),
        context_aggregator.assistant(),
    ]
    pipeline = Pipeline(processors)

    task_params = PipelineParams(
        allow_interruptions=True,
        enable_metrics=True,
        enable_usage_metrics=True,
        report_only_initial_ttfb=True,
    )
    task = PipelineTask(pipeline, params=task_params)

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected")
        # Kick off the conversation.
        intro_request = {"role": "user", "content": "Please introduce yourself to the user."}
        messages.append(intro_request)
        await task.queue_frames([context_aggregator.user().get_context_frame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")

    @transport.event_handler("on_client_closed")
    async def on_client_closed(transport, client):
        logger.info("Client closed connection")
        await task.cancel()

    runner = PipelineRunner(handle_sigint=False)
    await runner.run(task)
class AWSBedrockLLMAdapter(BaseLLMAdapter):
    """Adapter that renders standard function schemas as Bedrock tool specs."""

    @staticmethod
    def _to_bedrock_function_format(function: FunctionSchema) -> Dict[str, Any]:
        """Wrap a single function schema in a Bedrock ``toolSpec`` entry.

        :param function: Standard function schema to convert.
        :return: Dictionary with a single ``toolSpec`` key describing the function.
        """
        tool_spec = {
            "name": function.name,
            "description": function.description,
            "inputSchema": {
                "json": {
                    "type": "object",
                    "properties": function.properties,
                    "required": function.required,
                },
            },
        }
        return {"toolSpec": tool_spec}

    def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]:
        """Converts function schemas to Bedrock's function-calling format.

        :param tools_schema: Schema whose ``standard_tools`` are converted.
        :return: Bedrock formatted function call definition.
        """
        converted: List[Dict[str, Any]] = []
        for standard_tool in tools_schema.standard_tools:
            converted.append(self._to_bedrock_function_format(standard_tool))
        return converted
pipecat.services.llm_service import LLMService + +try: + import boto3 + import httpx + from botocore.config import Config +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use AWS services, you need to `pip install pipecat-ai[aws]`. Also, remember to set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable." + ) + raise Exception(f"Missing module: {e}") + + +@dataclass +class AWSBedrockContextAggregatorPair: + _user: "AWSBedrockUserContextAggregator" + _assistant: "AWSBedrockAssistantContextAggregator" + + def user(self) -> "AWSBedrockUserContextAggregator": + return self._user + + def assistant(self) -> "AWSBedrockAssistantContextAggregator": + return self._assistant + + +class AWSBedrockLLMContext(OpenAILLMContext): + def __init__( + self, + messages: Optional[List[dict]] = None, + tools: Optional[List[dict]] = None, + tool_choice: Optional[dict] = None, + *, + system: Optional[str] = None, + ): + super().__init__(messages=messages, tools=tools, tool_choice=tool_choice) + self.system = system + + @staticmethod + def upgrade_to_bedrock(obj: OpenAILLMContext) -> "AWSBedrockLLMContext": + logger.debug(f"Upgrading to AWS Bedrock: {obj}") + if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSBedrockLLMContext): + obj.__class__ = AWSBedrockLLMContext + obj._restructure_from_openai_messages() + else: + obj._restructure_from_bedrock_messages() + return obj + + @classmethod + def from_openai_context(cls, openai_context: OpenAILLMContext): + self = cls( + messages=openai_context.messages, + tools=openai_context.tools, + tool_choice=openai_context.tool_choice, + ) + self.set_llm_adapter(openai_context.get_llm_adapter()) + self._restructure_from_openai_messages() + return self + + @classmethod + def from_messages(cls, messages: List[dict]) -> "AWSBedrockLLMContext": + self = cls(messages=messages) + self._restructure_from_openai_messages() + return self + + @classmethod + def 
from_image_frame(cls, frame: VisionImageRawFrame) -> "AWSBedrockLLMContext": + context = cls() + context.add_image_frame_message( + format=frame.format, size=frame.size, image=frame.image, text=frame.text + ) + return context + + def set_messages(self, messages: List): + self._messages[:] = messages + self._restructure_from_openai_messages() + + # convert a message in AWS Bedrock format into one or more messages in OpenAI format + def to_standard_messages(self, obj): + """Convert AWS Bedrock message format to standard structured format. + + Handles text content and function calls for both user and assistant messages. + + Args: + obj: Message in AWS Bedrock format: + { + "role": "user/assistant", + "content": [{"text": str} | {"toolUse": {...}} | {"toolResult": {...}}] + } + + Returns: + List of messages in standard format: + [ + { + "role": "user/assistant/tool", + "content": [{"type": "text", "text": str}] + } + ] + """ + role = obj.get("role") + content = obj.get("content") + + if role == "assistant": + if isinstance(content, str): + return [{"role": role, "content": [{"type": "text", "text": content}]}] + elif isinstance(content, list): + text_items = [] + tool_items = [] + for item in content: + if "text" in item: + text_items.append({"type": "text", "text": item["text"]}) + elif "toolUse" in item: + tool_use = item["toolUse"] + tool_items.append( + { + "type": "function", + "id": tool_use["toolUseId"], + "function": { + "name": tool_use["name"], + "arguments": json.dumps(tool_use["input"]), + }, + } + ) + messages = [] + if text_items: + messages.append({"role": role, "content": text_items}) + if tool_items: + messages.append({"role": role, "tool_calls": tool_items}) + return messages + elif role == "user": + if isinstance(content, str): + return [{"role": role, "content": [{"type": "text", "text": content}]}] + elif isinstance(content, list): + text_items = [] + tool_items = [] + for item in content: + if "text" in item: + text_items.append({"type": "text", 
"text": item["text"]}) + elif "toolResult" in item: + tool_result = item["toolResult"] + # Extract content from toolResult + result_content = "" + if isinstance(tool_result["content"], list): + for content_item in tool_result["content"]: + if "text" in content_item: + result_content = content_item["text"] + elif "json" in content_item: + result_content = json.dumps(content_item["json"]) + else: + result_content = tool_result["content"] + + tool_items.append( + { + "role": "tool", + "tool_call_id": tool_result["toolUseId"], + "content": result_content, + } + ) + messages = [] + if text_items: + messages.append({"role": role, "content": text_items}) + messages.extend(tool_items) + return messages + + def from_standard_message(self, message): + """Convert standard format message to AWS Bedrock format. + + Handles conversion of text content, tool calls, and tool results. + Empty text content is converted to "(empty)". + + Args: + message: Message in standard format: + { + "role": "user/assistant/tool", + "content": str | [{"type": "text", ...}], + "tool_calls": [{"id": str, "function": {"name": str, "arguments": str}}] + } + + Returns: + Message in AWS Bedrock format: + { + "role": "user/assistant", + "content": [ + {"text": str} | + {"toolUse": {"toolUseId": str, "name": str, "input": dict}} | + {"toolResult": {"toolUseId": str, "content": [...], "status": str}} + ] + } + """ + if message["role"] == "tool": + # Try to parse the content as JSON if it looks like JSON + try: + if message["content"].strip().startswith("{") and message[ + "content" + ].strip().endswith("}"): + content_json = json.loads(message["content"]) + tool_result_content = [{"json": content_json}] + else: + tool_result_content = [{"text": message["content"]}] + except: + tool_result_content = [{"text": message["content"]}] + + return { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": message["tool_call_id"], + "content": tool_result_content, + }, + }, + ], + } + + if 
message.get("tool_calls"): + tc = message["tool_calls"] + ret = {"role": "assistant", "content": []} + for tool_call in tc: + function = tool_call["function"] + arguments = json.loads(function["arguments"]) + new_tool_use = { + "toolUse": { + "toolUseId": tool_call["id"], + "name": function["name"], + "input": arguments, + } + } + ret["content"].append(new_tool_use) + return ret + + # Handle text content + content = message.get("content") + if isinstance(content, str): + if content == "": + return {"role": message["role"], "content": [{"text": "(empty)"}]} + else: + return {"role": message["role"], "content": [{"text": content}]} + elif isinstance(content, list): + new_content = [] + for item in content: + if item.get("type", "") == "text": + text_content = item["text"] if item["text"] != "" else "(empty)" + new_content.append({"text": text_content}) + return {"role": message["role"], "content": new_content} + + return message + + def add_image_frame_message( + self, *, format: str, size: tuple[int, int], image: bytes, text: str = None + ): + buffer = io.BytesIO() + Image.frombytes(format, size, image).save(buffer, format="JPEG") + encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8") + + # Image should be the first content block in the message + content = [{"type": "image", "format": "jpeg", "source": {"bytes": encoded_image}}] + if text: + content.append({"text": text}) + self.add_message({"role": "user", "content": content}) + + def add_message(self, message): + try: + if self.messages: + # AWS Bedrock requires that roles alternate. If this message's + # role is the same as the last message, we should add this + # message's content to the last message. 
def _restructure_from_bedrock_messages(self):
    """Normalize ``self.messages`` (already in AWS Bedrock format) in place.

    Three passes are applied:

    1. A leading ``system`` message is moved into ``self.system`` (appended to
       any existing system content). If it is the only message, it is kept and
       re-labelled as a ``user`` message instead.
    2. Each message's content is coerced to a list of content blocks, and
       missing or empty text is replaced with the ``"(empty)"`` placeholder.
    3. Consecutive messages with the same role are merged into one.
    """
    # Handle system message if present at the beginning
    if self.messages and self.messages[0]["role"] == "system":
        if len(self.messages) == 1:
            self.messages[0]["role"] = "user"
        else:
            system_content = self.messages.pop(0)["content"]
            if isinstance(system_content, str):
                system_content = [{"text": system_content}]

            if self.system:
                if isinstance(self.system, str):
                    self.system = [{"text": self.system}]
                self.system.extend(system_content)
            else:
                self.system = system_content

    # Ensure content is properly formatted
    for msg in self.messages:
        content = msg.get("content")
        if isinstance(content, str):
            # Wrap first, then fall through to the empty-text checks below so
            # that an empty string does not survive as {"text": ""}.
            content = [{"text": content}]

        if not content:
            content = [{"text": "(empty)"}]
        elif isinstance(content, list):
            for idx, item in enumerate(content):
                if isinstance(item, dict) and item.get("text") == "":
                    item["text"] = "(empty)"
                elif isinstance(item, str) and item == "":
                    content[idx] = {"text": "(empty)"}
        msg["content"] = content

    # Merge consecutive messages with the same role
    merged_messages = []
    for msg in self.messages:
        if merged_messages and merged_messages[-1]["role"] == msg["role"]:
            merged_messages[-1]["content"].extend(msg["content"])
        else:
            merged_messages.append(msg)

    # Mutate the existing list object rather than rebinding it.
    self.messages.clear()
    self.messages.extend(merged_messages)
def get_messages_for_logging(self) -> str:
    """Return ``self.messages`` as a JSON string with image payloads elided.

    Each message is deep-copied first, so the stored messages are never
    modified. In the copy, the ``bytes`` value of every image block is
    replaced with ``"..."``.

    Two image block shapes are recognised:

    * flat, as built by ``add_image_frame_message``:
      ``{"type": "image", ..., "source": {"bytes": ...}}``
    * nested: ``{"image": {..., "source": {"bytes": ...}}}``

    Returns:
        The JSON-encoded list of redacted messages.
    """
    msgs = []
    for message in self.messages:
        msg = copy.deepcopy(message)
        content = msg.get("content")
        if isinstance(content, list):
            for item in content:
                if not isinstance(item, dict):
                    # Plain strings (or anything else) carry no image payload.
                    continue

                # Collect every dict that may hold a "source" for this item.
                holders = []
                if item.get("type") == "image" or "image" in item:
                    holders.append(item)
                if isinstance(item.get("image"), dict):
                    holders.append(item["image"])

                for holder in holders:
                    source = holder.get("source")
                    if isinstance(source, dict) and "bytes" in source:
                        source["bytes"] = "..."
        msgs.append(msg)
    return json.dumps(msgs)
def __init__(
    self,
    *,
    aws_access_key: Optional[str] = None,
    aws_secret_key: Optional[str] = None,
    aws_session_token: Optional[str] = None,
    aws_region: str = "us-east-1",
    model: str,
    params: Optional[InputParams] = None,
    client_config: Optional[Config] = None,
    **kwargs,
):
    """Create the Bedrock runtime client and record the inference settings.

    Args:
        aws_access_key: Passed to ``boto3.Session`` as ``aws_access_key_id``.
        aws_secret_key: Passed to ``boto3.Session`` as ``aws_secret_access_key``.
        aws_session_token: Passed to ``boto3.Session`` as ``aws_session_token``.
        aws_region: Passed to ``boto3.Session`` as ``region_name``.
        model: Model identifier, stored through ``set_model_name``.
        params: Inference parameters. When omitted, a fresh ``InputParams()``
            is created for this instance.
        client_config: botocore ``Config`` for the client. When omitted, a
            config with 5-minute connect/read timeouts and up to 3 retry
            attempts is used.
        **kwargs: Forwarded to the parent class constructor.
    """
    super().__init__(**kwargs)

    # Build the default per instance. A default evaluated in the signature
    # would be created once and shared (including its
    # additional_model_request_fields dict) by every service.
    if params is None:
        params = AWSBedrockLLMService.InputParams()

    # Initialize the AWS Bedrock client
    if not client_config:
        client_config = Config(
            connect_timeout=300,  # 5 minutes
            read_timeout=300,  # 5 minutes
            retries={"max_attempts": 3},
        )
    session = boto3.Session(
        aws_access_key_id=aws_access_key,
        aws_secret_access_key=aws_secret_key,
        aws_session_token=aws_session_token,
        region_name=aws_region,
    )
    self._client = session.client(service_name="bedrock-runtime", config=client_config)

    self.set_model_name(model)
    # NOTE(review): params.stop_sequences is not copied into the settings
    # here — confirm whether it is meant to be forwarded to the model.
    self._settings = {
        "max_tokens": params.max_tokens,
        "temperature": params.temperature,
        "top_p": params.top_p,
        "latency": params.latency,
        "additional_model_request_fields": params.additional_model_request_fields
        if isinstance(params.additional_model_request_fields, dict)
        else {},
    }

    logger.info(f"Using AWS Bedrock model: {model}")
async def _process_context(self, context: AWSBedrockLLMContext):
    """Run one streaming Bedrock ``converse_stream`` request for ``context``.

    Pushes an ``LLMFullResponseStartFrame``, then one ``LLMTextFrame`` per
    text delta received. When the model stops with ``stopReason == "tool_use"``
    the accumulated tool input is parsed as JSON and ``call_function`` is
    invoked. An ``LLMFullResponseEndFrame`` is always pushed and token usage
    is always reported, even on error.

    Args:
        context: Conversation state; its ``messages``, ``system``, ``tools``
            and ``tool_choice`` are used to build the request.

    Raises:
        asyncio.CancelledError: Re-raised after switching usage reporting to
            the locally estimated completion token count.
    """
    # Imported locally: only needed to recognise boto3 timeouts below.
    from botocore.exceptions import ConnectTimeoutError, ReadTimeoutError

    # Usage tracking
    prompt_tokens = 0
    completion_tokens = 0
    completion_tokens_estimate = 0
    cache_read_input_tokens = 0
    cache_creation_input_tokens = 0
    use_completion_tokens_estimate = False

    try:
        await self.push_frame(LLMFullResponseStartFrame())
        await self.start_processing_metrics()

        await self.start_ttfb_metrics()

        # Set up inference config
        inference_config = {
            "maxTokens": self._settings["max_tokens"],
            "temperature": self._settings["temperature"],
            "topP": self._settings["top_p"],
        }

        # Prepare request parameters
        request_params = {
            "modelId": self.model_name,
            "messages": context.messages,
            "inferenceConfig": inference_config,
            "additionalModelRequestFields": self._settings["additional_model_request_fields"],
        }

        # Add the system prompt only when there is one. The parameter must be
        # a list of content blocks, so a plain string is wrapped and a missing
        # prompt is left out of the request entirely.
        system = context.system
        if system:
            if isinstance(system, str):
                system = [{"text": system}]
            request_params["system"] = system

        # Add tools if present
        if context.tools:
            tool_config = {"tools": context.tools}

            # Add tool_choice if specified
            if context.tool_choice:
                if context.tool_choice == "auto":
                    tool_config["toolChoice"] = {"auto": {}}
                elif context.tool_choice == "none":
                    # Skip adding toolChoice for "none"
                    pass
                elif (
                    isinstance(context.tool_choice, dict) and "function" in context.tool_choice
                ):
                    tool_config["toolChoice"] = {
                        "tool": {"name": context.tool_choice["function"]["name"]}
                    }

            request_params["toolConfig"] = tool_config

        # Add performance config if latency is specified
        if self._settings["latency"] in ["standard", "optimized"]:
            request_params["performanceConfig"] = {"latency": self._settings["latency"]}

        logger.debug(f"Calling AWS Bedrock model with: {request_params}")

        # Call AWS Bedrock with streaming.
        # NOTE(review): this call and the loop over the stream below are not
        # awaited, so they run synchronously inside this coroutine.
        response = self._client.converse_stream(**request_params)

        await self.stop_ttfb_metrics()

        # Process the streaming response
        # NOTE(review): only the most recently started toolUse block is kept;
        # an earlier one in the same response is overwritten.
        tool_use_block = None
        json_accumulator = ""

        for event in response["stream"]:
            # Handle text content
            if "contentBlockDelta" in event:
                delta = event["contentBlockDelta"]["delta"]
                if "text" in delta:
                    await self.push_frame(LLMTextFrame(delta["text"]))
                    completion_tokens_estimate += self._estimate_tokens(delta["text"])
                elif "toolUse" in delta and "input" in delta["toolUse"]:
                    # Handle partial JSON for tool use
                    json_accumulator += delta["toolUse"]["input"]
                    completion_tokens_estimate += self._estimate_tokens(
                        delta["toolUse"]["input"]
                    )

            # Handle tool use start
            elif "contentBlockStart" in event:
                content_block_start = event["contentBlockStart"]["start"]
                if "toolUse" in content_block_start:
                    tool_use_block = {
                        "id": content_block_start["toolUse"].get("toolUseId", ""),
                        "name": content_block_start["toolUse"].get("name", ""),
                    }
                    json_accumulator = ""

            # Handle message completion with tool use
            elif "messageStop" in event and "stopReason" in event["messageStop"]:
                if event["messageStop"]["stopReason"] == "tool_use" and tool_use_block:
                    try:
                        arguments = json.loads(json_accumulator) if json_accumulator else {}
                        await self.call_function(
                            context=context,
                            tool_call_id=tool_use_block["id"],
                            function_name=tool_use_block["name"],
                            arguments=arguments,
                        )
                    except json.JSONDecodeError:
                        logger.error(f"Failed to parse tool arguments: {json_accumulator}")

            # Handle usage metrics if available
            if "metadata" in event and "usage" in event["metadata"]:
                usage = event["metadata"]["usage"]
                prompt_tokens += usage.get("inputTokens", 0)
                completion_tokens += usage.get("outputTokens", 0)
                cache_read_input_tokens += usage.get("cacheReadInputTokens", 0)
                cache_creation_input_tokens += usage.get("cacheWriteInputTokens", 0)

    except asyncio.CancelledError:
        # If we're interrupted, we won't get a complete usage report. So set our flag to use the
        # token estimate. Then re-raise the exception so all the processors running in this task
        # also get cancelled.
        use_completion_tokens_estimate = True
        raise
    except (httpx.TimeoutException, ReadTimeoutError, ConnectTimeoutError):
        # The request goes through boto3, so its timeout errors are handled
        # here as well as the httpx one.
        await self._call_event_handler("on_completion_timeout")
    except Exception as e:
        logger.exception(f"{self} exception: {e}")
    finally:
        await self.stop_processing_metrics()
        await self.push_frame(LLMFullResponseEndFrame())
        comp_tokens = (
            completion_tokens
            if not use_completion_tokens_estimate
            else completion_tokens_estimate
        )
        await self._report_usage_metrics(
            prompt_tokens=prompt_tokens,
            completion_tokens=comp_tokens,
            cache_read_input_tokens=cache_read_input_tokens,
            cache_creation_input_tokens=cache_creation_input_tokens,
        )
def _estimate_tokens(self, text: str) -> int:
    """Return a rough token count for ``text``.

    The text is split on runs of non-word characters. The number of resulting
    pieces (empty leading/trailing pieces included) is multiplied by 1.3 and
    truncated to an integer.
    """
    pieces = re.split(r"[^\w]+", text)
    scaled_count = len(pieces) * 1.3
    return int(scaled_count)
def get_service_encoding(self, encoding: str) -> str:
    """Translate an internal encoding name into the name AWS Transcribe uses.

    Args:
        encoding: Internal encoding name.

    Returns:
        ``"pcm"`` for ``"linear16"``; any other value is returned unchanged.
    """
    if encoding == "linear16":
        # AWS expects "pcm" for 16-bit linear PCM
        return "pcm"
    return encoding
async def _connect(self):
    """Open the AWS Transcribe WebSocket and start the receive task.

    Returns immediately when a connection is already open with a receive
    task running. Otherwise any stale client is disconnected, a presigned URL
    is built from the stored credentials and settings, the WebSocket is
    opened, and ``_receive_loop`` is started as a task.

    Raises:
        ValueError: If the configured language has no AWS language code.
        Exception: Any other connection failure is logged and re-raised after
            ``_disconnect`` has been called.
    """

    def _is_connected() -> bool:
        return bool(self._ws_client and self._ws_client.open and self._receive_task)

    if _is_connected():
        logger.debug(f"{self} Already connected")
        return

    async with self._connection_lock:
        # Check again now that we hold the lock. A caller that waited here
        # while another one connected must not tear down that new connection.
        if _is_connected():
            logger.debug(f"{self} Already connected")
            return

        if self._connecting:
            logger.debug(f"{self} Connection already in progress")
            return

        try:
            self._connecting = True
            logger.debug(f"{self} Starting connection process...")

            if self._ws_client:
                await self._disconnect()

            language_code = self.language_to_service_language(
                Language(self._settings["language"])
            )
            if not language_code:
                raise ValueError(f"Unsupported language: {self._settings['language']}")

            # Generate random websocket key
            websocket_key = "".join(
                random.choices(
                    string.ascii_uppercase + string.ascii_lowercase + string.digits, k=20
                )
            )

            # Add required headers
            # NOTE(review): WebSocket client libraries normally generate
            # handshake headers such as Sec-WebSocket-Key themselves —
            # confirm these overrides are needed.
            extra_headers = {
                "Origin": "https://localhost",
                "Sec-WebSocket-Key": websocket_key,
                "Sec-WebSocket-Version": "13",
                "Connection": "keep-alive",
            }

            # Get presigned URL
            presigned_url = get_presigned_url(
                region=self._credentials["region"],
                credentials={
                    "access_key": self._credentials["aws_access_key_id"],
                    "secret_key": self._credentials["aws_secret_access_key"],
                    "session_token": self._credentials["aws_session_token"],
                },
                language_code=language_code,
                media_encoding=self.get_service_encoding(
                    self._settings["media_encoding"]
                ),  # Convert to AWS format
                sample_rate=self._settings["sample_rate"],
                number_of_channels=self._settings["number_of_channels"],
                enable_partial_results_stabilization=True,
                partial_results_stability="high",
                show_speaker_label=self._settings["show_speaker_label"],
                enable_channel_identification=self._settings["enable_channel_identification"],
            )

            # Only a prefix of the URL is logged.
            logger.debug(f"{self} Connecting to WebSocket with URL: {presigned_url[:100]}...")

            # Connect with the required headers and settings
            self._ws_client = await websockets.connect(
                presigned_url,
                extra_headers=extra_headers,
                subprotocols=["mqtt"],
                ping_interval=None,
                ping_timeout=None,
                compression=None,
            )

            logger.debug(f"{self} WebSocket connected, starting receive task...")

            # Start receive task
            self._receive_task = self.create_task(self._receive_loop())

            logger.info(f"{self} Successfully connected to AWS Transcribe")

        except Exception as e:
            logger.error(f"{self} Failed to connect to AWS Transcribe: {e}")
            await self._disconnect()
            raise

        finally:
            self._connecting = False
transcript: + await self.stop_ttfb_metrics() + if is_final: + await self.push_frame( + TranscriptionFrame( + transcript, + "", + time_now_iso8601(), + self._settings["language"], + ) + ) + await self.stop_processing_metrics() + else: + await self.push_frame( + InterimTranscriptionFrame( + transcript, + "", + time_now_iso8601(), + self._settings["language"], + ) + ) + elif headers.get(":message-type") == "exception": + error_msg = payload.get("Message", "Unknown error") + logger.error(f"{self} Exception from AWS: {error_msg}") + await self.push_frame( + ErrorFrame(f"AWS Transcribe error: {error_msg}", fatal=False) + ) + else: + logger.debug(f"{self} Other message type received: {headers}") + logger.debug(f"{self} Payload: {payload}") + except websockets.exceptions.ConnectionClosed as e: + logger.error( + f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}" + ) + break + except Exception as e: + logger.error(f"{self} Unexpected error in receive loop: {e}") + break diff --git a/src/pipecat/services/aws/tts.py b/src/pipecat/services/aws/tts.py index db6e168ab..40d746514 100644 --- a/src/pipecat/services/aws/tts.py +++ b/src/pipecat/services/aws/tts.py @@ -5,6 +5,7 @@ # import asyncio +import os from typing import AsyncGenerator, Optional from loguru import logger @@ -26,9 +27,7 @@ try: from botocore.exceptions import BotoCoreError, ClientError except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use Deepgram, you need to `pip install pipecat-ai[aws]`. Also, set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable." 
- ) + logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.") raise Exception(f"Missing module: {e}") @@ -108,7 +107,7 @@ def language_to_aws_language(language: Language) -> Optional[str]: return language_map.get(language) -class PollyTTSService(TTSService): +class AWSPollyTTSService(TTSService): class InputParams(BaseModel): engine: Optional[str] = None language: Optional[Language] = Language.EN @@ -151,6 +150,24 @@ class PollyTTSService(TTSService): self.set_voice(voice_id) + # Get credentials from environment variables if not provided + self._credentials = { + "aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"), + "aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"), + "aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"), + "region": region or os.getenv("AWS_REGION", "us-east-1"), + } + + # Validate that we have the required credentials + if ( + not self._credentials["aws_access_key_id"] + or not self._credentials["aws_secret_access_key"] + ): + raise ValueError( + "AWS credentials not found. Please provide them either through constructor parameters " + "or set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables." + ) + def can_generate_metrics(self) -> bool: return True @@ -165,18 +182,17 @@ class PollyTTSService(TTSService): prosody_attrs = [] # Prosody tags are only supported for standard and neural engines - if self._settings["engine"] != "generative": - if self._settings["rate"]: - prosody_attrs.append(f"rate='{self._settings['rate']}'") + if self._settings["engine"] == "standard": if self._settings["pitch"]: prosody_attrs.append(f"pitch='{self._settings['pitch']}'") - if self._settings["volume"]: - prosody_attrs.append(f"volume='{self._settings['volume']}'") - if prosody_attrs: - ssml += f"" - else: - logger.warning("Prosody tags are not supported for generative engine. 
class PollyTTSService(AWSPollyTTSService):
    """Deprecated alias of ``AWSPollyTTSService``.

    Behaves exactly like ``AWSPollyTTSService`` but emits a
    ``DeprecationWarning`` each time it is instantiated.
    """

    def __init__(self, **kwargs):
        """Initialize the parent service, then warn about the deprecated name.

        Args:
            **kwargs: Forwarded unchanged to ``AWSPollyTTSService.__init__``.
        """
        super().__init__(**kwargs)

        import warnings

        # The "always" filter is scoped to this block so the warning is shown
        # on every instantiation without changing the global filters.
        with warnings.catch_warnings():
            warnings.simplefilter("always")
            warnings.warn(
                "'PollyTTSService' is deprecated, use 'AWSPollyTTSService' instead.",
                DeprecationWarning,
                # Attribute the warning to the code that instantiated the
                # class rather than to this shim.
                stacklevel=2,
            )
+ raise ValueError("AWS credentials are required") + + # Initialize the URL generator + url_generator = AWSTranscribePresignedURL( + access_key=access_key, secret_key=secret_key, session_token=session_token, region=region + ) + + # Get the presigned URL + return url_generator.get_request_url( + sample_rate=sample_rate, + language_code=language_code, + media_encoding=media_encoding, + vocabulary_name=vocabulary_name, + vocabulary_filter_name=vocabulary_filter_name, + show_speaker_label=show_speaker_label, + enable_channel_identification=enable_channel_identification, + number_of_channels=number_of_channels, + enable_partial_results_stabilization=enable_partial_results_stabilization, + partial_results_stability=partial_results_stability, + ) + + +class AWSTranscribePresignedURL: + def __init__( + self, access_key: str, secret_key: str, session_token: str, region: str = "us-east-1" + ): + self.access_key = access_key + self.secret_key = secret_key + self.session_token = session_token + self.method = "GET" + self.service = "transcribe" + self.region = region + self.endpoint = "" + self.host = "" + self.amz_date = "" + self.datestamp = "" + self.canonical_uri = "/stream-transcription-websocket" + self.canonical_headers = "" + self.signed_headers = "host" + self.algorithm = "AWS4-HMAC-SHA256" + self.credential_scope = "" + self.canonical_querystring = "" + self.payload_hash = "" + self.canonical_request = "" + self.string_to_sign = "" + self.signature = "" + self.request_url = "" + + def get_request_url( + self, + sample_rate: int, + language_code: str = "", + media_encoding: str = "pcm", + vocabulary_name: str = "", + vocabulary_filter_name: str = "", + show_speaker_label: bool = False, + enable_channel_identification: bool = False, + number_of_channels: int = 1, + enable_partial_results_stabilization: bool = False, + partial_results_stability: str = "", + ) -> str: + self.endpoint = f"wss://transcribestreaming.{self.region}.amazonaws.com:8443" + self.host = 
f"transcribestreaming.{self.region}.amazonaws.com:8443" + + now = datetime.datetime.utcnow() + self.amz_date = now.strftime("%Y%m%dT%H%M%SZ") + self.datestamp = now.strftime("%Y%m%d") + self.canonical_headers = f"host:{self.host}\n" + self.credential_scope = f"{self.datestamp}%2F{self.region}%2F{self.service}%2Faws4_request" + + # Create canonical querystring + self.canonical_querystring = "X-Amz-Algorithm=" + self.algorithm + self.canonical_querystring += ( + "&X-Amz-Credential=" + self.access_key + "%2F" + self.credential_scope + ) + self.canonical_querystring += "&X-Amz-Date=" + self.amz_date + self.canonical_querystring += "&X-Amz-Expires=300" + if self.session_token: + self.canonical_querystring += "&X-Amz-Security-Token=" + urllib.parse.quote( + self.session_token, safe="" + ) + self.canonical_querystring += "&X-Amz-SignedHeaders=" + self.signed_headers + + if enable_channel_identification: + self.canonical_querystring += "&enable-channel-identification=true" + if enable_partial_results_stabilization: + self.canonical_querystring += "&enable-partial-results-stabilization=true" + if language_code: + self.canonical_querystring += "&language-code=" + language_code + if media_encoding: + self.canonical_querystring += "&media-encoding=" + media_encoding + if number_of_channels > 1: + self.canonical_querystring += "&number-of-channels=" + str(number_of_channels) + if partial_results_stability: + self.canonical_querystring += "&partial-results-stability=" + partial_results_stability + if sample_rate: + self.canonical_querystring += "&sample-rate=" + str(sample_rate) + if show_speaker_label: + self.canonical_querystring += "&show-speaker-label=true" + if vocabulary_filter_name: + self.canonical_querystring += "&vocabulary-filter-name=" + vocabulary_filter_name + if vocabulary_name: + self.canonical_querystring += "&vocabulary-name=" + vocabulary_name + + # Create payload hash + self.payload_hash = hashlib.sha256("".encode("utf-8")).hexdigest() + + # Create canonical 
def get_headers(header_name: str, header_value: str) -> bytearray:
    """Encode one string-valued header in AWS event stream wire format.

    Layout: 1-byte name length, name bytes, 1-byte value type (7 = string),
    2-byte big-endian value length, value bytes.

    Args:
        header_name: Header name; its UTF-8 encoding must fit in 255 bytes.
        header_value: Header value; its UTF-8 encoding must fit in 65535 bytes.

    Returns:
        The encoded header as a bytearray.
    """
    name = header_name.encode("utf-8")
    name_length = bytes([len(name)])
    value = header_value.encode("utf-8")
    value_length = struct.pack(">H", len(value))
    # 7 is the event stream type tag for a string value.
    return bytearray(name_length + name + bytes([7]) + value_length + value)


def build_event_message(payload: bytes) -> bytes:
    """Wrap an audio payload in an AWS Transcribe streaming ``AudioEvent`` frame.

    Matches AWS sample: https://github.com/aws-samples/amazon-transcribe-streaming-python-websockets/blob/main/eventstream.py

    Frame layout: 8-byte prelude (total length, headers length), 4-byte
    prelude CRC32, headers, payload, 4-byte CRC32 of everything before it.

    Args:
        payload: Raw audio bytes to send.

    Returns:
        The complete binary event message.
    """
    headers = b"".join(
        get_headers(name, value)
        for name, value in (
            (":content-type", "application/octet-stream"),
            (":event-type", "AudioEvent"),
            (":message-type", "event"),
        )
    )

    # 16 accounts for the 8-byte prelude plus the two 4-byte CRCs.
    prelude = struct.pack(">II", len(headers) + len(payload) + 16, len(headers))
    prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF)

    body = prelude + prelude_crc + headers + payload
    message_crc = struct.pack(">I", binascii.crc32(body) & 0xFFFFFFFF)
    return body + message_crc


def decode_event(message):
    """Decode an AWS event stream message into its headers and JSON payload.

    Args:
        message: One complete binary event stream frame.

    Returns:
        A ``(headers, payload)`` tuple: ``headers`` maps header names to their
        string values and ``payload`` is the JSON-decoded message body.

    Raises:
        AssertionError: If the prelude CRC or the message CRC does not match.
            Raised explicitly (not via ``assert``) so the integrity checks
            still run when Python is started with ``-O``.
    """
    prelude = message[:8]
    _, headers_length = struct.unpack(">II", prelude)
    prelude_crc = struct.unpack(">I", message[8:12])[0]
    headers = message[12 : 12 + headers_length]
    payload = message[12 + headers_length : -4]
    message_crc = struct.unpack(">I", message[-4:])[0]

    # Verify integrity before trusting any of the parsed lengths.
    if prelude_crc != binascii.crc32(prelude) & 0xFFFFFFFF:
        raise AssertionError("Prelude CRC check failed")
    if message_crc != binascii.crc32(message[:-4]) & 0xFFFFFFFF:
        raise AssertionError("Message CRC check failed")

    # Walk the header section with an offset instead of re-slicing the buffer.
    headers_dict = {}
    pos = 0
    while pos < len(headers):
        name_len = headers[pos]
        name_end = pos + 1 + name_len
        name = headers[pos + 1 : name_end].decode("utf-8")
        # The byte at name_end is the value type; it is skipped and the value
        # is decoded as a length-prefixed UTF-8 string.
        # NOTE(review): non-string header types would be misparsed — confirm
        # the service only sends string headers.
        (value_len,) = struct.unpack_from(">H", headers, name_end + 1)
        value_start = name_end + 3
        pos = value_start + value_len
        headers_dict[name] = headers[value_start:pos].decode("utf-8")

    return headers_dict, json.loads(payload)
#
# AWS (Bedrock)
#


class TestAWSBedrockUserContextAggregator(
    BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
    """Run the shared user-aggregator tests against the AWS Bedrock context."""

    CONTEXT_CLASS = AWSBedrockLLMContext
    AGGREGATOR_CLASS = AWSBedrockUserContextAggregator

    def check_message_multi_content(
        self, context: OpenAILLMContext, content_index: int, index: int, content: str
    ):
        """Assert that the ``text`` of one content part equals ``content``."""
        parts = context.messages[content_index]["content"]
        assert parts[index]["text"] == content


class TestAWSBedrockAssistantContextAggregator(
    BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
):
    """Run the shared assistant-aggregator tests against the AWS Bedrock context."""

    CONTEXT_CLASS = AWSBedrockLLMContext
    AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator
    EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]

    def check_message_multi_content(
        self, context: OpenAILLMContext, content_index: int, index: int, content: str
    ):
        """Assert that the ``text`` of one content part equals ``content``."""
        parts = context.messages[content_index]["content"]
        assert parts[index]["text"] == content

    def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
        """Assert that the first tool result holds ``content`` serialized as JSON."""
        tool_result = context.messages[index]["content"][0]["toolResult"]
        expected_text = json.dumps(content)
        assert tool_result["content"][0]["text"] == expected_text
def test_bedrock_adapter(self):
    """Check that the Bedrock adapter emits the expected ``toolSpec`` list."""
    properties = {
        "format": {
            "type": "string",
            "enum": ["celsius", "fahrenheit"],
            "description": "The temperature unit to use.",
        },
        "location": {
            "type": "string",
            "description": "The city, e.g. San Francisco",
        },
    }
    input_schema = {
        "json": {
            "type": "object",
            "properties": properties,
            "required": ["location", "format"],
        }
    }
    tool_spec = {
        "name": "get_weather",
        "description": "Get the weather in a given location",
        "inputSchema": input_schema,
    }
    expected = [{"toolSpec": tool_spec}]

    actual = AWSBedrockLLMAdapter().to_provider_tools_format(self.tools_def)
    assert actual == expected