Merge pull request #1753 from pipecat-ai/aleix/add-bedrock-support

Add support for Amazon Bedrock LLMs
This commit is contained in:
Aleix Conchillo Flaqué
2025-05-07 09:31:48 -07:00
committed by GitHub
15 changed files with 1725 additions and 62 deletions

View File

@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added new AWS services `AWSBedrockLLMService` and `AWSTranscribeSTTService`.
- Added `on_active_speaker_changed` event handler to the `DailyTransport` class.
- Added `enable_ssml_parsing` and `enable_logging` to `InputParams` in
@@ -25,6 +27,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Deprecated
- `PollyTTSService` is now deprecated, use `AWSPollyTTSService` instead.
- Observer `on_push_frame(src, dst, frame, direction, timestamp)` is now
deprecated, use `on_push_frame(data: FramePushed)` instead.

View File

@@ -49,18 +49,18 @@ You can connect to Pipecat from any platform using our official SDKs:
## 🧩 Available services
| Category | Services |
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) |
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) |
| Analytics & Metrics | [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) |
| Category | Services |
|---------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) |
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) |
| Analytics & Metrics | [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) |
📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services)

View File

@@ -5,7 +5,6 @@
#
import argparse
import os
from dotenv import load_dotenv
from loguru import logger
@@ -15,9 +14,9 @@ from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.aws.tts import PollyTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.aws.llm import AWSBedrockLLMService
from pipecat.services.aws.stt import AWSTranscribeSTTService
from pipecat.services.aws.tts import AWSPollyTTSService
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
@@ -37,17 +36,19 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
),
)
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
stt = AWSTranscribeSTTService()
tts = PollyTTSService(
api_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
region=os.getenv("AWS_REGION"),
voice_id="Amy",
params=PollyTTSService.InputParams(engine="neural", language="en-GB", rate="1.05"),
tts = AWSPollyTTSService(
region="us-west-2", # only specific regions support generative TTS
voice_id="Joanna",
params=AWSPollyTTSService.InputParams(engine="generative", rate="1.1"),
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
llm = AWSBedrockLLMService(
aws_region="us-west-2",
model="us.anthropic.claude-3-5-haiku-20241022-v1:0",
params=AWSBedrockLLMService.InputParams(temperature=0.8, latency="optimized"),
)
messages = [
{
@@ -85,7 +86,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
messages.append({"role": "user", "content": "Please introduce yourself to the user."})
await task.queue_frames([context_aggregator.user().get_context_frame()])
@transport.event_handler("on_client_disconnected")

View File

@@ -0,0 +1,139 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import argparse
import os
from dotenv import load_dotenv
from loguru import logger
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.aws.llm import AWSBedrockLLMService
from pipecat.services.aws.stt import AWSTranscribeSTTService
from pipecat.services.aws.tts import AWSPollyTTSService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
load_dotenv(override=True)
async def fetch_weather_from_api(params: FunctionCallParams):
await params.result_callback({"conditions": "nice", "temperature": "75"})
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
logger.info(f"Starting bot")
transport = SmallWebRTCTransport(
webrtc_connection=webrtc_connection,
params=TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
)
stt = AWSTranscribeSTTService()
tts = AWSPollyTTSService(
region="us-west-2", # only specific regions support generative TTS
voice_id="Joanna",
params=AWSPollyTTSService.InputParams(engine="generative", rate="1.1"),
)
llm = AWSBedrockLLMService(
aws_region="us-west-2",
model="us.anthropic.claude-3-5-haiku-20241022-v1:0",
params=AWSBedrockLLMService.InputParams(temperature=0.8, latency="optimized"),
)
# You can also register a function_name of None to get all functions
# sent to the same callback with an additional function_name parameter.
llm.register_function("get_current_weather", fetch_weather_from_api)
weather_function = FunctionSchema(
name="get_current_weather",
description="Get the current weather",
properties={
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the user's location.",
},
},
required=["location", "format"],
)
tools = ToolsSchema(standard_tools=[weather_function])
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = OpenAILLMContext(messages, tools)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(),
stt,
context_aggregator.user(),
llm,
tts,
transport.output(),
context_aggregator.assistant(),
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
report_only_initial_ttfb=True,
),
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
messages.append({"role": "user", "content": "Please introduce yourself to the user."})
await task.queue_frames([context_aggregator.user().get_context_frame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
@transport.event_handler("on_client_closed")
async def on_client_closed(transport, client):
logger.info(f"Client closed connection")
await task.cancel()
runner = PipelineRunner(handle_sigint=False)
await runner.run(task)
if __name__ == "__main__":
from run import main
main()

View File

@@ -41,7 +41,7 @@ Website = "https://pipecat.ai"
[project.optional-dependencies]
anthropic = [ "anthropic~=0.49.0" ]
assemblyai = [ "assemblyai~=0.37.0" ]
aws = [ "boto3~=1.37.16" ]
aws = [ "boto3~=1.37.16", "websockets~=13.1" ]
azure = [ "azure-cognitiveservices-speech~=1.42.0"]
cartesia = [ "cartesia~=1.4.0", "websockets~=13.1" ]
cerebras = []

View File

@@ -4,7 +4,7 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import Any, Dict, List, Union
from typing import Any, Dict, List
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.function_schema import FunctionSchema

View File

@@ -0,0 +1,38 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import Any, Dict, List
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
class AWSBedrockLLMAdapter(BaseLLMAdapter):
@staticmethod
def _to_bedrock_function_format(function: FunctionSchema) -> Dict[str, Any]:
return {
"toolSpec": {
"name": function.name,
"description": function.description,
"inputSchema": {
"json": {
"type": "object",
"properties": function.properties,
"required": function.required,
},
},
}
}
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]:
"""Converts function schemas to Bedrock's function-calling format.
:return: Bedrock formatted function call definition.
"""
functions_schema = tools_schema.standard_tools
return [self._to_bedrock_function_format(func) for func in functions_schema]

View File

@@ -8,6 +8,8 @@ import sys
from pipecat.services import DeprecatedModuleProxy
from .llm import *
from .stt import *
from .tts import *
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.tts")
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.[llm,stt,tts]")

View File

@@ -0,0 +1,785 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import base64
import copy
import io
import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from loguru import logger
from PIL import Image
from pydantic import BaseModel, Field
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
from pipecat.frames.frames import (
Frame,
FunctionCallCancelFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMTextFrame,
LLMUpdateSettingsFrame,
UserImageRawFrame,
VisionImageRawFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMAssistantContextAggregator,
LLMUserAggregatorParams,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import LLMService
try:
import boto3
import httpx
from botocore.config import Config
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use AWS services, you need to `pip install pipecat-ai[aws]`. Also, remember to set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable."
)
raise Exception(f"Missing module: {e}")
@dataclass
class AWSBedrockContextAggregatorPair:
_user: "AWSBedrockUserContextAggregator"
_assistant: "AWSBedrockAssistantContextAggregator"
def user(self) -> "AWSBedrockUserContextAggregator":
return self._user
def assistant(self) -> "AWSBedrockAssistantContextAggregator":
return self._assistant
class AWSBedrockLLMContext(OpenAILLMContext):
def __init__(
self,
messages: Optional[List[dict]] = None,
tools: Optional[List[dict]] = None,
tool_choice: Optional[dict] = None,
*,
system: Optional[str] = None,
):
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
self.system = system
@staticmethod
def upgrade_to_bedrock(obj: OpenAILLMContext) -> "AWSBedrockLLMContext":
logger.debug(f"Upgrading to AWS Bedrock: {obj}")
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSBedrockLLMContext):
obj.__class__ = AWSBedrockLLMContext
obj._restructure_from_openai_messages()
else:
obj._restructure_from_bedrock_messages()
return obj
@classmethod
def from_openai_context(cls, openai_context: OpenAILLMContext):
self = cls(
messages=openai_context.messages,
tools=openai_context.tools,
tool_choice=openai_context.tool_choice,
)
self.set_llm_adapter(openai_context.get_llm_adapter())
self._restructure_from_openai_messages()
return self
@classmethod
def from_messages(cls, messages: List[dict]) -> "AWSBedrockLLMContext":
self = cls(messages=messages)
self._restructure_from_openai_messages()
return self
@classmethod
def from_image_frame(cls, frame: VisionImageRawFrame) -> "AWSBedrockLLMContext":
context = cls()
context.add_image_frame_message(
format=frame.format, size=frame.size, image=frame.image, text=frame.text
)
return context
def set_messages(self, messages: List):
self._messages[:] = messages
self._restructure_from_openai_messages()
# convert a message in AWS Bedrock format into one or more messages in OpenAI format
def to_standard_messages(self, obj):
"""Convert AWS Bedrock message format to standard structured format.
Handles text content and function calls for both user and assistant messages.
Args:
obj: Message in AWS Bedrock format:
{
"role": "user/assistant",
"content": [{"text": str} | {"toolUse": {...}} | {"toolResult": {...}}]
}
Returns:
List of messages in standard format:
[
{
"role": "user/assistant/tool",
"content": [{"type": "text", "text": str}]
}
]
"""
role = obj.get("role")
content = obj.get("content")
if role == "assistant":
if isinstance(content, str):
return [{"role": role, "content": [{"type": "text", "text": content}]}]
elif isinstance(content, list):
text_items = []
tool_items = []
for item in content:
if "text" in item:
text_items.append({"type": "text", "text": item["text"]})
elif "toolUse" in item:
tool_use = item["toolUse"]
tool_items.append(
{
"type": "function",
"id": tool_use["toolUseId"],
"function": {
"name": tool_use["name"],
"arguments": json.dumps(tool_use["input"]),
},
}
)
messages = []
if text_items:
messages.append({"role": role, "content": text_items})
if tool_items:
messages.append({"role": role, "tool_calls": tool_items})
return messages
elif role == "user":
if isinstance(content, str):
return [{"role": role, "content": [{"type": "text", "text": content}]}]
elif isinstance(content, list):
text_items = []
tool_items = []
for item in content:
if "text" in item:
text_items.append({"type": "text", "text": item["text"]})
elif "toolResult" in item:
tool_result = item["toolResult"]
# Extract content from toolResult
result_content = ""
if isinstance(tool_result["content"], list):
for content_item in tool_result["content"]:
if "text" in content_item:
result_content = content_item["text"]
elif "json" in content_item:
result_content = json.dumps(content_item["json"])
else:
result_content = tool_result["content"]
tool_items.append(
{
"role": "tool",
"tool_call_id": tool_result["toolUseId"],
"content": result_content,
}
)
messages = []
if text_items:
messages.append({"role": role, "content": text_items})
messages.extend(tool_items)
return messages
def from_standard_message(self, message):
"""Convert standard format message to AWS Bedrock format.
Handles conversion of text content, tool calls, and tool results.
Empty text content is converted to "(empty)".
Args:
message: Message in standard format:
{
"role": "user/assistant/tool",
"content": str | [{"type": "text", ...}],
"tool_calls": [{"id": str, "function": {"name": str, "arguments": str}}]
}
Returns:
Message in AWS Bedrock format:
{
"role": "user/assistant",
"content": [
{"text": str} |
{"toolUse": {"toolUseId": str, "name": str, "input": dict}} |
{"toolResult": {"toolUseId": str, "content": [...], "status": str}}
]
}
"""
if message["role"] == "tool":
# Try to parse the content as JSON if it looks like JSON
try:
if message["content"].strip().startswith("{") and message[
"content"
].strip().endswith("}"):
content_json = json.loads(message["content"])
tool_result_content = [{"json": content_json}]
else:
tool_result_content = [{"text": message["content"]}]
except:
tool_result_content = [{"text": message["content"]}]
return {
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": message["tool_call_id"],
"content": tool_result_content,
},
},
],
}
if message.get("tool_calls"):
tc = message["tool_calls"]
ret = {"role": "assistant", "content": []}
for tool_call in tc:
function = tool_call["function"]
arguments = json.loads(function["arguments"])
new_tool_use = {
"toolUse": {
"toolUseId": tool_call["id"],
"name": function["name"],
"input": arguments,
}
}
ret["content"].append(new_tool_use)
return ret
# Handle text content
content = message.get("content")
if isinstance(content, str):
if content == "":
return {"role": message["role"], "content": [{"text": "(empty)"}]}
else:
return {"role": message["role"], "content": [{"text": content}]}
elif isinstance(content, list):
new_content = []
for item in content:
if item.get("type", "") == "text":
text_content = item["text"] if item["text"] != "" else "(empty)"
new_content.append({"text": text_content})
return {"role": message["role"], "content": new_content}
return message
def add_image_frame_message(
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
):
buffer = io.BytesIO()
Image.frombytes(format, size, image).save(buffer, format="JPEG")
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
# Image should be the first content block in the message
content = [{"type": "image", "format": "jpeg", "source": {"bytes": encoded_image}}]
if text:
content.append({"text": text})
self.add_message({"role": "user", "content": content})
def add_message(self, message):
try:
if self.messages:
# AWS Bedrock requires that roles alternate. If this message's
# role is the same as the last message, we should add this
# message's content to the last message.
if self.messages[-1]["role"] == message["role"]:
# if the last message has just a content string, convert it to a list
# in the proper format
if isinstance(self.messages[-1]["content"], str):
self.messages[-1]["content"] = [{"text": self.messages[-1]["content"]}]
# if this message has just a content string, convert it to a list
# in the proper format
if isinstance(message["content"], str):
message["content"] = [{"text": message["content"]}]
# append the content of this message to the last message
self.messages[-1]["content"].extend(message["content"])
else:
self.messages.append(message)
else:
self.messages.append(message)
except Exception as e:
logger.error(f"Error adding message: {e}")
def _restructure_from_bedrock_messages(self):
"""Restructure messages in AWS Bedrock format by handling system
messages, merging consecutive messages with the same role, and ensuring
proper content formatting.
"""
# Handle system message if present at the beginning
if self.messages and self.messages[0]["role"] == "system":
if len(self.messages) == 1:
self.messages[0]["role"] = "user"
else:
system_content = self.messages.pop(0)["content"]
if isinstance(system_content, str):
system_content = [{"text": system_content}]
if self.system:
if isinstance(self.system, str):
self.system = [{"text": self.system}]
self.system.extend(system_content)
else:
self.system = system_content
# Ensure content is properly formatted
for msg in self.messages:
if isinstance(msg["content"], str):
msg["content"] = [{"text": msg["content"]}]
elif not msg["content"]:
msg["content"] = [{"text": "(empty)"}]
elif isinstance(msg["content"], list):
for idx, item in enumerate(msg["content"]):
if isinstance(item, dict) and "text" in item and item["text"] == "":
item["text"] = "(empty)"
elif isinstance(item, str) and item == "":
msg["content"][idx] = {"text": "(empty)"}
# Merge consecutive messages with the same role
merged_messages = []
for msg in self.messages:
if merged_messages and merged_messages[-1]["role"] == msg["role"]:
merged_messages[-1]["content"].extend(msg["content"])
else:
merged_messages.append(msg)
self.messages.clear()
self.messages.extend(merged_messages)
def _restructure_from_openai_messages(self):
# first, map across self._messages calling self.from_standard_message(m) to modify messages in place
try:
self._messages[:] = [self.from_standard_message(m) for m in self._messages]
except Exception as e:
logger.error(f"Error mapping messages: {e}")
# See if we should pull the system message out of our context.messages list. (For
# compatibility with Open AI messages format.)
if self.messages and self.messages[0]["role"] == "system":
self.system = self.messages[0]["content"]
self.messages.pop(0)
# Merge consecutive messages with the same role.
i = 0
while i < len(self.messages) - 1:
current_message = self.messages[i]
next_message = self.messages[i + 1]
if current_message["role"] == next_message["role"]:
# Convert content to list of dictionaries if it's a string
if isinstance(current_message["content"], str):
current_message["content"] = [
{"type": "text", "text": current_message["content"]}
]
if isinstance(next_message["content"], str):
next_message["content"] = [{"type": "text", "text": next_message["content"]}]
# Concatenate the content
current_message["content"].extend(next_message["content"])
# Remove the next message from the list
self.messages.pop(i + 1)
else:
i += 1
# Avoid empty content in messages
for message in self.messages:
if isinstance(message["content"], str) and message["content"] == "":
message["content"] = "(empty)"
elif isinstance(message["content"], list) and len(message["content"]) == 0:
message["content"] = [{"type": "text", "text": "(empty)"}]
def get_messages_for_persistent_storage(self):
messages = super().get_messages_for_persistent_storage()
if self.system:
messages.insert(0, {"role": "system", "content": self.system})
return messages
def get_messages_for_logging(self) -> str:
msgs = []
for message in self.messages:
msg = copy.deepcopy(message)
if "content" in msg:
if isinstance(msg["content"], list):
for item in msg["content"]:
if item.get("image"):
item["source"]["bytes"] = "..."
msgs.append(msg)
return json.dumps(msgs)
class AWSBedrockUserContextAggregator(LLMUserContextAggregator):
pass
class AWSBedrockAssistantContextAggregator(LLMAssistantContextAggregator):
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
# Format tool use according to AWS Bedrock API
self._context.add_message(
{
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": frame.tool_call_id,
"name": frame.function_name,
"input": frame.arguments if frame.arguments else {},
}
}
],
}
)
self._context.add_message(
{
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": frame.tool_call_id,
"content": [{"text": "IN_PROGRESS"}],
}
}
],
}
)
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
if frame.result:
result = json.dumps(frame.result)
await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
else:
await self._update_function_call_result(
frame.function_name, frame.tool_call_id, "COMPLETED"
)
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
await self._update_function_call_result(
frame.function_name, frame.tool_call_id, "CANCELLED"
)
async def _update_function_call_result(
self, function_name: str, tool_call_id: str, result: Any
):
for message in self._context.messages:
if message["role"] == "user":
for content in message["content"]:
if (
isinstance(content, dict)
and content.get("toolResult")
and content["toolResult"]["toolUseId"] == tool_call_id
):
content["toolResult"]["content"] = [{"text": result}]
async def handle_user_image_frame(self, frame: UserImageRawFrame):
await self._update_function_call_result(
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
)
self._context.add_image_frame_message(
format=frame.format,
size=frame.size,
image=frame.image,
text=frame.request.context,
)
class AWSBedrockLLMService(LLMService):
"""This class implements inference with AWS Bedrock models including Amazon
Nova and Anthropic Claude.
Requires AWS credentials to be configured in the environment or through
boto3 configuration.
"""
# Overriding the default adapter to use the Anthropic one.
adapter_class = AWSBedrockLLMAdapter
class InputParams(BaseModel):
max_tokens: Optional[int] = Field(default_factory=lambda: 4096, ge=1)
temperature: Optional[float] = Field(default_factory=lambda: 0.7, ge=0.0, le=1.0)
top_p: Optional[float] = Field(default_factory=lambda: 0.999, ge=0.0, le=1.0)
stop_sequences: Optional[List[str]] = Field(default_factory=lambda: [])
latency: Optional[str] = Field(default_factory=lambda: "standard")
additional_model_request_fields: Optional[Dict[str, Any]] = Field(default_factory=dict)
def __init__(
self,
*,
aws_access_key: Optional[str] = None,
aws_secret_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region: str = "us-east-1",
model: str,
params: InputParams = InputParams(),
client_config: Optional[Config] = None,
**kwargs,
):
super().__init__(**kwargs)
# Initialize the AWS Bedrock client
if not client_config:
client_config = Config(
connect_timeout=300, # 5 minutes
read_timeout=300, # 5 minutes
retries={"max_attempts": 3},
)
session = boto3.Session(
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key,
aws_session_token=aws_session_token,
region_name=aws_region,
)
self._client = session.client(service_name="bedrock-runtime", config=client_config)
self.set_model_name(model)
self._settings = {
"max_tokens": params.max_tokens,
"temperature": params.temperature,
"top_p": params.top_p,
"latency": params.latency,
"additional_model_request_fields": params.additional_model_request_fields
if isinstance(params.additional_model_request_fields, dict)
else {},
}
logger.info(f"Using AWS Bedrock model: {model}")
def can_generate_metrics(self) -> bool:
return True
def create_context_aggregator(
self,
context: OpenAILLMContext,
*,
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> AWSBedrockContextAggregatorPair:
"""Create an instance of AWSBedrockContextAggregatorPair from an
OpenAILLMContext. Constructor keyword arguments for both the user and
assistant aggregators can be provided.
Args:
context (OpenAILLMContext): The LLM context.
user_params (LLMUserAggregatorParams, optional): User aggregator
parameters.
assistant_params (LLMAssistantAggregatorParams, optional): User
aggregator parameters.
Returns:
AWSBedrockContextAggregatorPair: A pair of context aggregators, one
for the user and one for the assistant, encapsulated in an
AWSBedrockContextAggregatorPair.
"""
context.set_llm_adapter(self.get_llm_adapter())
if isinstance(context, OpenAILLMContext):
context = AWSBedrockLLMContext.from_openai_context(context)
user = AWSBedrockUserContextAggregator(context, params=user_params)
assistant = AWSBedrockAssistantContextAggregator(context, params=assistant_params)
return AWSBedrockContextAggregatorPair(_user=user, _assistant=assistant)
async def _process_context(self, context: AWSBedrockLLMContext):
# Usage tracking
prompt_tokens = 0
completion_tokens = 0
completion_tokens_estimate = 0
cache_read_input_tokens = 0
cache_creation_input_tokens = 0
use_completion_tokens_estimate = False
try:
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()
await self.start_ttfb_metrics()
# Set up inference config
inference_config = {
"maxTokens": self._settings["max_tokens"],
"temperature": self._settings["temperature"],
"topP": self._settings["top_p"],
}
# Prepare request parameters
request_params = {
"modelId": self.model_name,
"messages": context.messages,
"inferenceConfig": inference_config,
"additionalModelRequestFields": self._settings["additional_model_request_fields"],
}
# Add system message
request_params["system"] = context.system
# Add tools if present
if context.tools:
tool_config = {"tools": context.tools}
# Add tool_choice if specified
if context.tool_choice:
if context.tool_choice == "auto":
tool_config["toolChoice"] = {"auto": {}}
elif context.tool_choice == "none":
# Skip adding toolChoice for "none"
pass
elif (
isinstance(context.tool_choice, dict) and "function" in context.tool_choice
):
tool_config["toolChoice"] = {
"tool": {"name": context.tool_choice["function"]["name"]}
}
request_params["toolConfig"] = tool_config
# Add performance config if latency is specified
if self._settings["latency"] in ["standard", "optimized"]:
request_params["performanceConfig"] = {"latency": self._settings["latency"]}
logger.debug(f"Calling AWS Bedrock model with: {request_params}")
# Call AWS Bedrock with streaming
response = self._client.converse_stream(**request_params)
await self.stop_ttfb_metrics()
# Process the streaming response
tool_use_block = None
json_accumulator = ""
for event in response["stream"]:
# Handle text content
if "contentBlockDelta" in event:
delta = event["contentBlockDelta"]["delta"]
if "text" in delta:
await self.push_frame(LLMTextFrame(delta["text"]))
completion_tokens_estimate += self._estimate_tokens(delta["text"])
elif "toolUse" in delta and "input" in delta["toolUse"]:
# Handle partial JSON for tool use
json_accumulator += delta["toolUse"]["input"]
completion_tokens_estimate += self._estimate_tokens(
delta["toolUse"]["input"]
)
# Handle tool use start
elif "contentBlockStart" in event:
content_block_start = event["contentBlockStart"]["start"]
if "toolUse" in content_block_start:
tool_use_block = {
"id": content_block_start["toolUse"].get("toolUseId", ""),
"name": content_block_start["toolUse"].get("name", ""),
}
json_accumulator = ""
# Handle message completion with tool use
elif "messageStop" in event and "stopReason" in event["messageStop"]:
if event["messageStop"]["stopReason"] == "tool_use" and tool_use_block:
try:
arguments = json.loads(json_accumulator) if json_accumulator else {}
await self.call_function(
context=context,
tool_call_id=tool_use_block["id"],
function_name=tool_use_block["name"],
arguments=arguments,
)
except json.JSONDecodeError:
logger.error(f"Failed to parse tool arguments: {json_accumulator}")
# Handle usage metrics if available
if "metadata" in event and "usage" in event["metadata"]:
usage = event["metadata"]["usage"]
prompt_tokens += usage.get("inputTokens", 0)
completion_tokens += usage.get("outputTokens", 0)
cache_read_input_tokens += usage.get("cacheReadInputTokens", 0)
cache_creation_input_tokens += usage.get("cacheWriteInputTokens", 0)
except asyncio.CancelledError:
# If we're interrupted, we won't get a complete usage report. So set our flag to use the
# token estimate. The reraise the exception so all the processors running in this task
# also get cancelled.
use_completion_tokens_estimate = True
raise
except httpx.TimeoutException:
await self._call_event_handler("on_completion_timeout")
except Exception as e:
logger.exception(f"{self} exception: {e}")
finally:
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
comp_tokens = (
completion_tokens
if not use_completion_tokens_estimate
else completion_tokens_estimate
)
await self._report_usage_metrics(
prompt_tokens=prompt_tokens,
completion_tokens=comp_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
context = None
if isinstance(frame, OpenAILLMContextFrame):
context = AWSBedrockLLMContext.upgrade_to_bedrock(frame.context)
elif isinstance(frame, LLMMessagesFrame):
context = AWSBedrockLLMContext.from_messages(frame.messages)
elif isinstance(frame, VisionImageRawFrame):
# This is only useful in very simple pipelines because it creates
# a new context. Generally we want a context manager to catch
# UserImageRawFrames coming through the pipeline and add them
# to the context.
context = AWSBedrockLLMContext.from_image_frame(frame)
elif isinstance(frame, LLMUpdateSettingsFrame):
await self._update_settings(frame.settings)
else:
await self.push_frame(frame, direction)
if context:
await self._process_context(context)
def _estimate_tokens(self, text: str) -> int:
return int(len(re.split(r"[^\w]+", text)) * 1.3)
async def _report_usage_metrics(
self,
prompt_tokens: int,
completion_tokens: int,
cache_read_input_tokens: int,
cache_creation_input_tokens: int,
):
if prompt_tokens or completion_tokens:
tokens = LLMTokenUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
)
await self.start_llm_usage_metrics(tokens)

View File

@@ -0,0 +1,329 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import json
import os
import random
import string
from typing import AsyncGenerator, Optional
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
)
from pipecat.services.aws.utils import build_event_message, decode_event, get_presigned_url
from pipecat.services.stt_service import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
try:
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.")
raise Exception(f"Missing module: {e}")
class AWSTranscribeSTTService(STTService):
def __init__(
self,
*,
api_key: Optional[str] = None,
aws_access_key_id: Optional[str] = None,
aws_session_token: Optional[str] = None,
region: Optional[str] = "us-east-1",
sample_rate: int = 16000,
language: Language = Language.EN,
**kwargs,
):
super().__init__(**kwargs)
self._settings = {
"sample_rate": sample_rate,
"language": language,
"media_encoding": "linear16", # AWS expects raw PCM
"number_of_channels": 1,
"show_speaker_label": False,
"enable_channel_identification": False,
}
# Validate sample rate - AWS Transcribe only supports 8000 Hz or 16000 Hz
if sample_rate not in [8000, 16000]:
logger.warning(
f"AWS Transcribe only supports 8000 Hz or 16000 Hz sample rates. Converting from {sample_rate} Hz to 16000 Hz."
)
self._settings["sample_rate"] = 16000
self._credentials = {
"aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
"aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
"region": region or os.getenv("AWS_REGION", "us-east-1"),
}
self._ws_client = None
self._connection_lock = asyncio.Lock()
self._connecting = False
self._receive_task = None
def get_service_encoding(self, encoding: str) -> str:
"""Convert internal encoding format to AWS Transcribe format."""
encoding_map = {
"linear16": "pcm", # AWS expects "pcm" for 16-bit linear PCM
}
return encoding_map.get(encoding, encoding)
async def start(self, frame: StartFrame):
"""Initialize the connection when the service starts."""
await super().start(frame)
logger.info("Starting AWS Transcribe service...")
retry_count = 0
max_retries = 3
while retry_count < max_retries:
try:
await self._connect()
if self._ws_client and self._ws_client.open:
logger.info("Successfully established WebSocket connection")
return
logger.warning("WebSocket connection not established after connect")
except Exception as e:
logger.error(f"Failed to connect (attempt {retry_count + 1}/{max_retries}): {e}")
retry_count += 1
if retry_count < max_retries:
await asyncio.sleep(1) # Wait before retrying
raise RuntimeError("Failed to establish WebSocket connection after multiple attempts")
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Process audio data and send to AWS Transcribe"""
try:
# Ensure WebSocket is connected
if not self._ws_client or not self._ws_client.open:
logger.debug("WebSocket not connected, attempting to reconnect...")
try:
await self._connect()
except Exception as e:
logger.error(f"Failed to reconnect: {e}")
yield ErrorFrame("Failed to reconnect to AWS Transcribe", fatal=False)
return
# Format the audio data according to AWS event stream format
event_message = build_event_message(audio)
# Send the formatted event message
try:
await self._ws_client.send(event_message)
# Start metrics after first chunk sent
await self.start_processing_metrics()
await self.start_ttfb_metrics()
except websockets.exceptions.ConnectionClosed as e:
logger.warning(f"Connection closed while sending: {e}")
await self._disconnect()
# Don't yield error here - we'll retry on next frame
except Exception as e:
logger.error(f"Error sending audio: {e}")
yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
await self._disconnect()
except Exception as e:
logger.error(f"Error in run_stt: {e}")
yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
await self._disconnect()
async def _connect(self):
"""Connect to AWS Transcribe with connection state management."""
if self._ws_client and self._ws_client.open and self._receive_task:
logger.debug(f"{self} Already connected")
return
async with self._connection_lock:
if self._connecting:
logger.debug(f"{self} Connection already in progress")
return
try:
self._connecting = True
logger.debug(f"{self} Starting connection process...")
if self._ws_client:
await self._disconnect()
language_code = self.language_to_service_language(
Language(self._settings["language"])
)
if not language_code:
raise ValueError(f"Unsupported language: {self._settings['language']}")
# Generate random websocket key
websocket_key = "".join(
random.choices(
string.ascii_uppercase + string.ascii_lowercase + string.digits, k=20
)
)
# Add required headers
extra_headers = {
"Origin": "https://localhost",
"Sec-WebSocket-Key": websocket_key,
"Sec-WebSocket-Version": "13",
"Connection": "keep-alive",
}
# Get presigned URL
presigned_url = get_presigned_url(
region=self._credentials["region"],
credentials={
"access_key": self._credentials["aws_access_key_id"],
"secret_key": self._credentials["aws_secret_access_key"],
"session_token": self._credentials["aws_session_token"],
},
language_code=language_code,
media_encoding=self.get_service_encoding(
self._settings["media_encoding"]
), # Convert to AWS format
sample_rate=self._settings["sample_rate"],
number_of_channels=self._settings["number_of_channels"],
enable_partial_results_stabilization=True,
partial_results_stability="high",
show_speaker_label=self._settings["show_speaker_label"],
enable_channel_identification=self._settings["enable_channel_identification"],
)
logger.debug(f"{self} Connecting to WebSocket with URL: {presigned_url[:100]}...")
# Connect with the required headers and settings
self._ws_client = await websockets.connect(
presigned_url,
extra_headers=extra_headers,
subprotocols=["mqtt"],
ping_interval=None,
ping_timeout=None,
compression=None,
)
logger.debug(f"{self} WebSocket connected, starting receive task...")
# Start receive task
self._receive_task = self.create_task(self._receive_loop())
logger.info(f"{self} Successfully connected to AWS Transcribe")
except Exception as e:
logger.error(f"{self} Failed to connect to AWS Transcribe: {e}")
await self._disconnect()
raise
finally:
self._connecting = False
async def _disconnect(self):
"""Disconnect from AWS Transcribe."""
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None
try:
if self._ws_client and self._ws_client.open:
# Send end-stream message
end_stream = {"message-type": "event", "event": "end"}
await self._ws_client.send(json.dumps(end_stream))
await self._ws_client.close()
except Exception as e:
logger.warning(f"{self} Error closing WebSocket connection: {e}")
finally:
self._ws_client = None
def language_to_service_language(self, language: Language) -> str | None:
"""Convert internal language enum to AWS Transcribe language code."""
language_map = {
Language.EN: "en-US",
Language.ES: "es-US",
Language.FR: "fr-FR",
Language.DE: "de-DE",
Language.IT: "it-IT",
Language.PT: "pt-BR",
Language.JA: "ja-JP",
Language.KO: "ko-KR",
Language.ZH: "zh-CN",
}
return language_map.get(language)
async def _receive_loop(self):
"""Background task to receive and process messages from AWS Transcribe."""
while True:
if not self._ws_client or not self._ws_client.open:
logger.warning(f"{self} WebSocket closed in receive loop")
break
try:
response = await self._ws_client.recv()
headers, payload = decode_event(response)
if headers.get(":message-type") == "event":
# Process transcription results
results = payload.get("Transcript", {}).get("Results", [])
if results:
result = results[0]
alternatives = result.get("Alternatives", [])
if alternatives:
transcript = alternatives[0].get("Transcript", "")
is_final = not result.get("IsPartial", True)
if transcript:
await self.stop_ttfb_metrics()
if is_final:
await self.push_frame(
TranscriptionFrame(
transcript,
"",
time_now_iso8601(),
self._settings["language"],
)
)
await self.stop_processing_metrics()
else:
await self.push_frame(
InterimTranscriptionFrame(
transcript,
"",
time_now_iso8601(),
self._settings["language"],
)
)
elif headers.get(":message-type") == "exception":
error_msg = payload.get("Message", "Unknown error")
logger.error(f"{self} Exception from AWS: {error_msg}")
await self.push_frame(
ErrorFrame(f"AWS Transcribe error: {error_msg}", fatal=False)
)
else:
logger.debug(f"{self} Other message type received: {headers}")
logger.debug(f"{self} Payload: {payload}")
except websockets.exceptions.ConnectionClosed as e:
logger.error(
f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}"
)
break
except Exception as e:
logger.error(f"{self} Unexpected error in receive loop: {e}")
break

View File

@@ -5,6 +5,7 @@
#
import asyncio
import os
from typing import AsyncGenerator, Optional
from loguru import logger
@@ -26,9 +27,7 @@ try:
from botocore.exceptions import BotoCoreError, ClientError
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Deepgram, you need to `pip install pipecat-ai[aws]`. Also, set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable."
)
logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.")
raise Exception(f"Missing module: {e}")
@@ -108,7 +107,7 @@ def language_to_aws_language(language: Language) -> Optional[str]:
return language_map.get(language)
class PollyTTSService(TTSService):
class AWSPollyTTSService(TTSService):
class InputParams(BaseModel):
engine: Optional[str] = None
language: Optional[Language] = Language.EN
@@ -151,6 +150,24 @@ class PollyTTSService(TTSService):
self.set_voice(voice_id)
# Get credentials from environment variables if not provided
self._credentials = {
"aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
"aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
"region": region or os.getenv("AWS_REGION", "us-east-1"),
}
# Validate that we have the required credentials
if (
not self._credentials["aws_access_key_id"]
or not self._credentials["aws_secret_access_key"]
):
raise ValueError(
"AWS credentials not found. Please provide them either through constructor parameters "
"or set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."
)
def can_generate_metrics(self) -> bool:
return True
@@ -165,18 +182,17 @@ class PollyTTSService(TTSService):
prosody_attrs = []
# Prosody tags are only supported for standard and neural engines
if self._settings["engine"] != "generative":
if self._settings["rate"]:
prosody_attrs.append(f"rate='{self._settings['rate']}'")
if self._settings["engine"] == "standard":
if self._settings["pitch"]:
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
if self._settings["volume"]:
prosody_attrs.append(f"volume='{self._settings['volume']}'")
if prosody_attrs:
ssml += f"<prosody {' '.join(prosody_attrs)}>"
else:
logger.warning("Prosody tags are not supported for generative engine. Ignoring.")
if self._settings["rate"]:
prosody_attrs.append(f"rate='{self._settings['rate']}'")
if self._settings["volume"]:
prosody_attrs.append(f"volume='{self._settings['volume']}'")
if prosody_attrs:
ssml += f"<prosody {' '.join(prosody_attrs)}>"
ssml += text
@@ -187,6 +203,8 @@ class PollyTTSService(TTSService):
ssml += "</speak>"
logger.trace(f"{self} SSML: {ssml}")
return ssml
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
@@ -248,3 +266,17 @@ class PollyTTSService(TTSService):
finally:
yield TTSStoppedFrame()
class PollyTTSService(AWSPollyTTSService):
def __init__(self, **kwargs):
super().__init__(**kwargs)
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"'PollyTTSService' is deprecated, use 'AWSPollyTTSService' instead.",
DeprecationWarning,
)

View File

@@ -0,0 +1,261 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import binascii
import datetime
import hashlib
import hmac
import json
import struct
import urllib.parse
from typing import Dict, Optional
def get_presigned_url(
*,
region: str,
credentials: Dict[str, Optional[str]],
language_code: str,
media_encoding: str = "pcm",
sample_rate: int = 16000,
number_of_channels: int = 1,
enable_partial_results_stabilization: bool = True,
partial_results_stability: str = "high",
vocabulary_name: Optional[str] = None,
vocabulary_filter_name: Optional[str] = None,
show_speaker_label: bool = False,
enable_channel_identification: bool = False,
) -> str:
"""Create a presigned URL for AWS Transcribe streaming."""
access_key = credentials.get("access_key")
secret_key = credentials.get("secret_key")
session_token = credentials.get("session_token")
if not access_key or not secret_key:
raise ValueError("AWS credentials are required")
# Initialize the URL generator
url_generator = AWSTranscribePresignedURL(
access_key=access_key, secret_key=secret_key, session_token=session_token, region=region
)
# Get the presigned URL
return url_generator.get_request_url(
sample_rate=sample_rate,
language_code=language_code,
media_encoding=media_encoding,
vocabulary_name=vocabulary_name,
vocabulary_filter_name=vocabulary_filter_name,
show_speaker_label=show_speaker_label,
enable_channel_identification=enable_channel_identification,
number_of_channels=number_of_channels,
enable_partial_results_stabilization=enable_partial_results_stabilization,
partial_results_stability=partial_results_stability,
)
class AWSTranscribePresignedURL:
def __init__(
self, access_key: str, secret_key: str, session_token: str, region: str = "us-east-1"
):
self.access_key = access_key
self.secret_key = secret_key
self.session_token = session_token
self.method = "GET"
self.service = "transcribe"
self.region = region
self.endpoint = ""
self.host = ""
self.amz_date = ""
self.datestamp = ""
self.canonical_uri = "/stream-transcription-websocket"
self.canonical_headers = ""
self.signed_headers = "host"
self.algorithm = "AWS4-HMAC-SHA256"
self.credential_scope = ""
self.canonical_querystring = ""
self.payload_hash = ""
self.canonical_request = ""
self.string_to_sign = ""
self.signature = ""
self.request_url = ""
def get_request_url(
self,
sample_rate: int,
language_code: str = "",
media_encoding: str = "pcm",
vocabulary_name: str = "",
vocabulary_filter_name: str = "",
show_speaker_label: bool = False,
enable_channel_identification: bool = False,
number_of_channels: int = 1,
enable_partial_results_stabilization: bool = False,
partial_results_stability: str = "",
) -> str:
self.endpoint = f"wss://transcribestreaming.{self.region}.amazonaws.com:8443"
self.host = f"transcribestreaming.{self.region}.amazonaws.com:8443"
now = datetime.datetime.utcnow()
self.amz_date = now.strftime("%Y%m%dT%H%M%SZ")
self.datestamp = now.strftime("%Y%m%d")
self.canonical_headers = f"host:{self.host}\n"
self.credential_scope = f"{self.datestamp}%2F{self.region}%2F{self.service}%2Faws4_request"
# Create canonical querystring
self.canonical_querystring = "X-Amz-Algorithm=" + self.algorithm
self.canonical_querystring += (
"&X-Amz-Credential=" + self.access_key + "%2F" + self.credential_scope
)
self.canonical_querystring += "&X-Amz-Date=" + self.amz_date
self.canonical_querystring += "&X-Amz-Expires=300"
if self.session_token:
self.canonical_querystring += "&X-Amz-Security-Token=" + urllib.parse.quote(
self.session_token, safe=""
)
self.canonical_querystring += "&X-Amz-SignedHeaders=" + self.signed_headers
if enable_channel_identification:
self.canonical_querystring += "&enable-channel-identification=true"
if enable_partial_results_stabilization:
self.canonical_querystring += "&enable-partial-results-stabilization=true"
if language_code:
self.canonical_querystring += "&language-code=" + language_code
if media_encoding:
self.canonical_querystring += "&media-encoding=" + media_encoding
if number_of_channels > 1:
self.canonical_querystring += "&number-of-channels=" + str(number_of_channels)
if partial_results_stability:
self.canonical_querystring += "&partial-results-stability=" + partial_results_stability
if sample_rate:
self.canonical_querystring += "&sample-rate=" + str(sample_rate)
if show_speaker_label:
self.canonical_querystring += "&show-speaker-label=true"
if vocabulary_filter_name:
self.canonical_querystring += "&vocabulary-filter-name=" + vocabulary_filter_name
if vocabulary_name:
self.canonical_querystring += "&vocabulary-name=" + vocabulary_name
# Create payload hash
self.payload_hash = hashlib.sha256("".encode("utf-8")).hexdigest()
# Create canonical request
self.canonical_request = f"{self.method}\n{self.canonical_uri}\n{self.canonical_querystring}\n{self.canonical_headers}\n{self.signed_headers}\n{self.payload_hash}"
# Create string to sign
credential_scope = f"{self.datestamp}/{self.region}/{self.service}/aws4_request"
string_to_sign = (
f"{self.algorithm}\n{self.amz_date}\n{credential_scope}\n"
+ hashlib.sha256(self.canonical_request.encode("utf-8")).hexdigest()
)
# Calculate signature
k_date = hmac.new(
f"AWS4{self.secret_key}".encode("utf-8"), self.datestamp.encode("utf-8"), hashlib.sha256
).digest()
k_region = hmac.new(k_date, self.region.encode("utf-8"), hashlib.sha256).digest()
k_service = hmac.new(k_region, self.service.encode("utf-8"), hashlib.sha256).digest()
k_signing = hmac.new(k_service, b"aws4_request", hashlib.sha256).digest()
self.signature = hmac.new(
k_signing, string_to_sign.encode("utf-8"), hashlib.sha256
).hexdigest()
# Add signature to query string
self.canonical_querystring += "&X-Amz-Signature=" + self.signature
# Create request URL
self.request_url = self.endpoint + self.canonical_uri + "?" + self.canonical_querystring
return self.request_url
def get_headers(header_name: str, header_value: str) -> bytearray:
"""Build a header following AWS event stream format."""
name = header_name.encode("utf-8")
name_byte_length = bytes([len(name)])
value_type = bytes([7]) # 7 represents a string
value = header_value.encode("utf-8")
value_byte_length = struct.pack(">H", len(value))
# Construct the header
header_list = bytearray()
header_list.extend(name_byte_length)
header_list.extend(name)
header_list.extend(value_type)
header_list.extend(value_byte_length)
header_list.extend(value)
return header_list
def build_event_message(payload: bytes) -> bytes:
"""
Build an event message for AWS Transcribe streaming.
Matches AWS sample: https://github.com/aws-samples/amazon-transcribe-streaming-python-websockets/blob/main/eventstream.py
"""
# Build headers
content_type_header = get_headers(":content-type", "application/octet-stream")
event_type_header = get_headers(":event-type", "AudioEvent")
message_type_header = get_headers(":message-type", "event")
headers = bytearray()
headers.extend(content_type_header)
headers.extend(event_type_header)
headers.extend(message_type_header)
# Calculate total byte length and headers byte length
# 16 accounts for 8 byte prelude, 2x 4 byte CRCs
total_byte_length = struct.pack(">I", len(headers) + len(payload) + 16)
headers_byte_length = struct.pack(">I", len(headers))
# Build the prelude
prelude = bytearray([0] * 8)
prelude[:4] = total_byte_length
prelude[4:] = headers_byte_length
# Calculate checksum for prelude
prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF)
# Construct the message
message_as_list = bytearray()
message_as_list.extend(prelude)
message_as_list.extend(prelude_crc)
message_as_list.extend(headers)
message_as_list.extend(payload)
# Calculate checksum for message
message = bytes(message_as_list)
message_crc = struct.pack(">I", binascii.crc32(message) & 0xFFFFFFFF)
# Add message checksum
message_as_list.extend(message_crc)
return bytes(message_as_list)
def decode_event(message):
# Extract the prelude, headers, payload and CRC
prelude = message[:8]
total_length, headers_length = struct.unpack(">II", prelude)
prelude_crc = struct.unpack(">I", message[8:12])[0]
headers = message[12 : 12 + headers_length]
payload = message[12 + headers_length : -4]
message_crc = struct.unpack(">I", message[-4:])[0]
# Check the CRCs
assert prelude_crc == binascii.crc32(prelude) & 0xFFFFFFFF, "Prelude CRC check failed"
assert message_crc == binascii.crc32(message[:-4]) & 0xFFFFFFFF, "Message CRC check failed"
# Parse the headers
headers_dict = {}
while headers:
name_len = headers[0]
name = headers[1 : 1 + name_len].decode("utf-8")
value_type = headers[1 + name_len]
value_len = struct.unpack(">H", headers[2 + name_len : 4 + name_len])[0]
value = headers[4 + name_len : 4 + name_len + value_len].decode("utf-8")
headers_dict[name] = value
headers = headers[4 + name_len + value_len :]
return headers_dict, json.loads(payload)

View File

@@ -1 +1 @@
-e ".[anthropic,google,langchain]"
-e ".[anthropic,aws,google,langchain]"

View File

@@ -40,6 +40,11 @@ from pipecat.services.anthropic.llm import (
AnthropicLLMContext,
AnthropicUserContextAggregator,
)
from pipecat.services.aws.llm import (
AWSBedrockAssistantContextAggregator,
AWSBedrockLLMContext,
AWSBedrockUserContextAggregator,
)
from pipecat.services.google.llm import (
GoogleAssistantContextAggregator,
GoogleLLMContext,
@@ -669,26 +674,6 @@ class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.Isola
AGGREGATOR_CLASS = LLMUserContextAggregator
#
# OpenAI
#
class TestOpenAIUserContextAggregator(
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = OpenAIUserContextAggregator
class TestOpenAIAssistantContextAggregator(
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
#
# Anthropic
#
@@ -724,6 +709,43 @@ class TestAnthropicAssistantContextAggregator(
assert context.messages[index]["content"][0]["content"] == json.dumps(content)
#
# AWS (Bedrock)
#
class TestAWSBedrockUserContextAggregator(
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = AWSBedrockLLMContext
AGGREGATOR_CLASS = AWSBedrockUserContextAggregator
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
messages = context.messages[content_index]
assert messages["content"][index]["text"] == content
class TestAWSBedrockAssistantContextAggregator(
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = AWSBedrockLLMContext
AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
messages = context.messages[content_index]
assert messages["content"][index]["text"] == content
def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
assert context.messages[index]["content"][0]["toolResult"]["content"][0][
"text"
] == json.dumps(content)
#
# Google
#
@@ -766,3 +788,23 @@ class TestGoogleAssistantContextAggregator(
def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
obj = glm.Content.to_dict(context.messages[index])
assert obj["parts"][0]["function_response"]["response"]["value"] == json.dumps(content)
#
# OpenAI
#
class TestOpenAIUserContextAggregator(
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = OpenAIUserContextAggregator
class TestOpenAIAssistantContextAggregator(
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]

View File

@@ -11,6 +11,7 @@ from openai.types.chat import ChatCompletionToolParam
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
@@ -174,3 +175,32 @@ class TestFunctionAdapters(unittest.TestCase):
tools_def = self.tools_def
tools_def.custom_tools = {AdapterType.GEMINI: [search_tool]}
assert GeminiLLMAdapter().to_provider_tools_format(tools_def) == expected
def test_bedrock_adapter(self):
"""Test AWS Bedrock adapter format transformation."""
expected = [
{
"toolSpec": {
"name": "get_weather",
"description": "Get the weather in a given location",
"inputSchema": {
"json": {
"type": "object",
"properties": {
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use.",
},
"location": {
"type": "string",
"description": "The city, e.g. San Francisco",
},
},
"required": ["location", "format"],
}
},
}
}
]
assert AWSBedrockLLMAdapter().to_provider_tools_format(self.tools_def) == expected