Merge pull request #1753 from pipecat-ai/aleix/add-bedrock-support
Add support for Amazon Bedrock LLMs
This commit is contained in:
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added new AWS services `AWSBedrockLLMService` and `AWSTranscribeSTTService`.
|
||||
|
||||
- Added `on_active_speaker_changed` event handler to the `DailyTransport` class.
|
||||
|
||||
- Added `enable_ssml_parsing` and `enable_logging` to `InputParams` in
|
||||
@@ -25,6 +27,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Deprecated
|
||||
|
||||
- `PollyTTSService` is now deprecated, use `AWSPollyTTSService` instead.
|
||||
|
||||
- Observer `on_push_frame(src, dst, frame, direction, timestamp)` is now
|
||||
deprecated, use `on_push_frame(data: FramePushed)` instead.
|
||||
|
||||
|
||||
24
README.md
24
README.md
@@ -49,18 +49,18 @@ You can connect to Pipecat from any platform using our official SDKs:
|
||||
|
||||
## 🧩 Available services
|
||||
|
||||
| Category | Services |
|
||||
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
|
||||
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
|
||||
| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
|
||||
| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) |
|
||||
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
|
||||
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) |
|
||||
| Analytics & Metrics | [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) |
|
||||
| Category | Services |
|
||||
|---------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
|
||||
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
|
||||
| Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
|
||||
| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) |
|
||||
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/google-imagen), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
|
||||
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) |
|
||||
| Analytics & Metrics | [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) |
|
||||
|
||||
📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services)
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
@@ -15,9 +14,9 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.aws.tts import PollyTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
from pipecat.services.aws.stt import AWSTranscribeSTTService
|
||||
from pipecat.services.aws.tts import AWSPollyTTSService
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
@@ -37,17 +36,19 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
|
||||
),
|
||||
)
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
stt = AWSTranscribeSTTService()
|
||||
|
||||
tts = PollyTTSService(
|
||||
api_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
region=os.getenv("AWS_REGION"),
|
||||
voice_id="Amy",
|
||||
params=PollyTTSService.InputParams(engine="neural", language="en-GB", rate="1.05"),
|
||||
tts = AWSPollyTTSService(
|
||||
region="us-west-2", # only specific regions support generative TTS
|
||||
voice_id="Joanna",
|
||||
params=AWSPollyTTSService.InputParams(engine="generative", rate="1.1"),
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
llm = AWSBedrockLLMService(
|
||||
aws_region="us-west-2",
|
||||
model="us.anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
params=AWSBedrockLLMService.InputParams(temperature=0.8, latency="optimized"),
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
@@ -85,7 +86,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
messages.append({"role": "user", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
139
examples/foundational/14r-function-calling-aws.py
Normal file
139
examples/foundational/14r-function-calling-aws.py
Normal file
@@ -0,0 +1,139 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
from pipecat.services.aws.stt import AWSTranscribeSTTService
|
||||
from pipecat.services.aws.tts import AWSPollyTTSService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
    """Answer a `get_current_weather` function call with a canned result.

    This example does not contact a real weather API: it always reports the
    same fixed conditions through the call's result callback.

    Args:
        params: Function call parameters; only `result_callback` is used.
    """
    canned_weather = {"conditions": "nice", "temperature": "75"}
    await params.result_callback(canned_weather)
|
||||
|
||||
|
||||
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
    """Run a voice bot that demonstrates function calling with AWS services.

    Builds a pipeline of transport input -> AWS Transcribe STT -> user context
    aggregator -> AWS Bedrock LLM -> AWS Polly TTS -> transport output ->
    assistant context aggregator, registers a single `get_current_weather`
    tool, and runs the pipeline until the task finishes or is cancelled.

    Args:
        webrtc_connection: Connection handed to the `SmallWebRTCTransport`.
        _: Parsed command line arguments (unused by this example).
    """
    logger.info("Starting bot")

    transport = SmallWebRTCTransport(
        webrtc_connection=webrtc_connection,
        params=TransportParams(
            audio_in_enabled=True,
            audio_out_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
        ),
    )

    stt = AWSTranscribeSTTService()

    tts = AWSPollyTTSService(
        region="us-west-2",  # only specific regions support generative TTS
        voice_id="Joanna",
        params=AWSPollyTTSService.InputParams(engine="generative", rate="1.1"),
    )

    llm = AWSBedrockLLMService(
        aws_region="us-west-2",
        model="us.anthropic.claude-3-5-haiku-20241022-v1:0",
        params=AWSBedrockLLMService.InputParams(temperature=0.8, latency="optimized"),
    )

    # You can also register a function_name of None to get all functions
    # sent to the same callback with an additional function_name parameter.
    llm.register_function("get_current_weather", fetch_weather_from_api)

    weather_function = FunctionSchema(
        name="get_current_weather",
        description="Get the current weather",
        properties={
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            },
            "format": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
                "description": "The temperature unit to use. Infer this from the user's location.",
            },
        },
        required=["location", "format"],
    )
    tools = ToolsSchema(standard_tools=[weather_function])

    messages = [
        {
            "role": "system",
            "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
        },
    ]

    context = OpenAILLMContext(messages, tools)
    context_aggregator = llm.create_context_aggregator(context)

    pipeline = Pipeline(
        [
            transport.input(),
            stt,
            context_aggregator.user(),
            llm,
            tts,
            transport.output(),
            context_aggregator.assistant(),
        ]
    )

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            allow_interruptions=True,
            enable_metrics=True,
            enable_usage_metrics=True,
            report_only_initial_ttfb=True,
        ),
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected")
        # Kick off the conversation.
        messages.append({"role": "user", "content": "Please introduce yourself to the user."})
        await task.queue_frames([context_aggregator.user().get_context_frame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")

    @transport.event_handler("on_client_closed")
    async def on_client_closed(transport, client):
        logger.info("Client closed connection")
        # Stop the pipeline once the client has closed its connection.
        await task.cancel()

    runner = PipelineRunner(handle_sigint=False)

    await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # `run` is only imported when this file is executed as a script, so
    # importing this module elsewhere does not require it.
    import run

    run.main()
|
||||
@@ -41,7 +41,7 @@ Website = "https://pipecat.ai"
|
||||
[project.optional-dependencies]
|
||||
anthropic = [ "anthropic~=0.49.0" ]
|
||||
assemblyai = [ "assemblyai~=0.37.0" ]
|
||||
aws = [ "boto3~=1.37.16" ]
|
||||
aws = [ "boto3~=1.37.16", "websockets~=13.1" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.42.0"]
|
||||
cartesia = [ "cartesia~=1.4.0", "websockets~=13.1" ]
|
||||
cerebras = []
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
|
||||
38
src/pipecat/adapters/services/bedrock_adapter.py
Normal file
38
src/pipecat/adapters/services/bedrock_adapter.py
Normal file
@@ -0,0 +1,38 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
|
||||
|
||||
class AWSBedrockLLMAdapter(BaseLLMAdapter):
    """Adapter that renders standard tool schemas as Bedrock `toolSpec` entries."""

    @staticmethod
    def _to_bedrock_function_format(function: FunctionSchema) -> Dict[str, Any]:
        """Wrap a single function schema in Bedrock's `toolSpec` envelope.

        :param function: Function schema providing name, description,
            properties and required fields.
        :return: Dictionary with a single `toolSpec` key.
        """
        tool_spec: Dict[str, Any] = {
            "name": function.name,
            "description": function.description,
        }
        # The function's parameters are exposed as a JSON schema object.
        tool_spec["inputSchema"] = {
            "json": {
                "type": "object",
                "properties": function.properties,
                "required": function.required,
            },
        }
        return {"toolSpec": tool_spec}

    def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]:
        """Converts function schemas to Bedrock's function-calling format.

        :param tools_schema: Tools schema whose `standard_tools` are converted.
        :return: Bedrock formatted function call definition.
        """
        return list(map(self._to_bedrock_function_format, tools_schema.standard_tools))
|
||||
@@ -8,6 +8,8 @@ import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.tts")
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.[llm,stt,tts]")
|
||||
|
||||
785
src/pipecat/services/aws/llm.py
Normal file
785
src/pipecat/services/aws/llm.py
Normal file
@@ -0,0 +1,785 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import copy
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallCancelFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
UserImageRawFrame,
|
||||
VisionImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import LLMService
|
||||
|
||||
try:
|
||||
import boto3
|
||||
import httpx
|
||||
from botocore.config import Config
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use AWS services, you need to `pip install pipecat-ai[aws]`. Also, remember to set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@dataclass
class AWSBedrockContextAggregatorPair:
    """Bundle of the user-side and assistant-side context aggregators."""

    # Aggregator returned by `user()`.
    _user: "AWSBedrockUserContextAggregator"
    # Aggregator returned by `assistant()`.
    _assistant: "AWSBedrockAssistantContextAggregator"

    def user(self) -> "AWSBedrockUserContextAggregator":
        """Return the user context aggregator of this pair."""
        return self._user

    def assistant(self) -> "AWSBedrockAssistantContextAggregator":
        """Return the assistant context aggregator of this pair."""
        return self._assistant
|
||||
|
||||
|
||||
class AWSBedrockLLMContext(OpenAILLMContext):
|
||||
def __init__(
    self,
    messages: Optional[List[dict]] = None,
    tools: Optional[List[dict]] = None,
    tool_choice: Optional[dict] = None,
    *,
    system: Optional[str] = None,
):
    """Create a Bedrock context on top of the OpenAI-style base context.

    Args:
        messages: Initial messages, forwarded to the base class.
        tools: Tool definitions, forwarded to the base class.
        tool_choice: Tool choice setting, forwarded to the base class.
        system: Optional system prompt, stored separately on `self.system`.
    """
    super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
    # The system prompt is stored on its own attribute, separate from the
    # message list.
    self.system = system
|
||||
|
||||
@staticmethod
def upgrade_to_bedrock(obj: OpenAILLMContext) -> "AWSBedrockLLMContext":
    """Turn an existing context into an `AWSBedrockLLMContext` in place.

    A plain `OpenAILLMContext` is re-classed to `AWSBedrockLLMContext` and its
    messages are restructured from the OpenAI layout. Any other object
    (including one that already is a Bedrock context) only has its messages
    restructured from the Bedrock layout.

    Args:
        obj: The context to upgrade; it is mutated, not copied.

    Returns:
        The same object that was passed in.
    """
    logger.debug(f"Upgrading to AWS Bedrock: {obj}")

    is_bedrock_already = isinstance(obj, AWSBedrockLLMContext)
    if is_bedrock_already or not isinstance(obj, OpenAILLMContext):
        obj._restructure_from_bedrock_messages()
        return obj

    # Swap the class in place so existing references see the Bedrock
    # behavior without copying the context.
    obj.__class__ = AWSBedrockLLMContext
    obj._restructure_from_openai_messages()
    return obj
|
||||
|
||||
@classmethod
def from_openai_context(cls, openai_context: OpenAILLMContext):
    """Build a new Bedrock context from an OpenAI context.

    The new context receives the source context's messages, tools, tool
    choice and LLM adapter, and its messages are then restructured from the
    OpenAI layout.

    Args:
        openai_context: The context to copy settings from.

    Returns:
        A new instance of `cls`.
    """
    context = cls(
        messages=openai_context.messages,
        tools=openai_context.tools,
        tool_choice=openai_context.tool_choice,
    )
    adapter = openai_context.get_llm_adapter()
    context.set_llm_adapter(adapter)
    context._restructure_from_openai_messages()
    return context
|
||||
|
||||
@classmethod
def from_messages(cls, messages: List[dict]) -> "AWSBedrockLLMContext":
    """Build a Bedrock context from a list of OpenAI-style messages.

    Args:
        messages: Messages used to initialize the context.

    Returns:
        A new context whose messages have been restructured from the OpenAI
        layout.
    """
    context = cls(messages=messages)
    context._restructure_from_openai_messages()
    return context
|
||||
|
||||
@classmethod
def from_image_frame(cls, frame: VisionImageRawFrame) -> "AWSBedrockLLMContext":
    """Build a context containing a single image message taken from a frame.

    Args:
        frame: Frame supplying the image `format`, `size`, raw `image` bytes
            and accompanying `text`.

    Returns:
        A new, otherwise empty context with the image message added.
    """
    image_context = cls()
    image_context.add_image_frame_message(
        format=frame.format,
        size=frame.size,
        image=frame.image,
        text=frame.text,
    )
    return image_context
|
||||
|
||||
def set_messages(self, messages: List):
    """Replace the stored messages and restructure them for Bedrock.

    Args:
        messages: New messages, given in OpenAI layout.
    """
    # Slice assignment swaps the contents while keeping the same list
    # object, so existing references to the message list stay valid.
    self._messages[:] = messages
    self._restructure_from_openai_messages()
|
||||
|
||||
# convert a message in AWS Bedrock format into one or more messages in OpenAI format
|
||||
def to_standard_messages(self, obj):
    """Convert one AWS Bedrock message into standard-format messages.

    Text blocks become `{"type": "text"}` items. For assistant messages,
    `toolUse` blocks are gathered into a separate message carrying
    `tool_calls`. For user messages, every `toolResult` block becomes its own
    `"tool"` role message.

    Args:
        obj: Message in AWS Bedrock format:
            {
                "role": "user/assistant",
                "content": str | [{"text": str} | {"toolUse": {...}} | {"toolResult": {...}}]
            }

    Returns:
        A list of standard-format messages, or None when the role is neither
        "user" nor "assistant" or the content is neither a string nor a list.
    """
    role = obj.get("role")
    content = obj.get("content")

    # Only user and assistant messages are converted.
    if role != "assistant" and role != "user":
        return None

    if isinstance(content, str):
        return [{"role": role, "content": [{"type": "text", "text": content}]}]
    if not isinstance(content, list):
        return None

    text_parts = []
    tool_parts = []
    for block in content:
        if "text" in block:
            text_parts.append({"type": "text", "text": block["text"]})
        elif role == "assistant" and "toolUse" in block:
            use = block["toolUse"]
            tool_parts.append(
                {
                    "type": "function",
                    "id": use["toolUseId"],
                    "function": {
                        "name": use["name"],
                        # Standard format carries arguments as a JSON string.
                        "arguments": json.dumps(use["input"]),
                    },
                }
            )
        elif role == "user" and "toolResult" in block:
            result = block["toolResult"]
            payload = result["content"]
            if isinstance(payload, list):
                # When several parts are present, the last text/json part wins.
                flattened = ""
                for part in payload:
                    if "text" in part:
                        flattened = part["text"]
                    elif "json" in part:
                        flattened = json.dumps(part["json"])
            else:
                flattened = payload
            tool_parts.append(
                {
                    "role": "tool",
                    "tool_call_id": result["toolUseId"],
                    "content": flattened,
                }
            )

    converted = []
    if text_parts:
        converted.append({"role": role, "content": text_parts})
    if role == "assistant":
        # All tool calls travel together in one assistant message.
        if tool_parts:
            converted.append({"role": role, "tool_calls": tool_parts})
    else:
        # Each tool result is already a complete "tool" message.
        converted.extend(tool_parts)
    return converted
|
||||
|
||||
def from_standard_message(self, message):
    """Convert standard format message to AWS Bedrock format.

    Handles conversion of text content, tool calls, and tool results.
    Empty text content is converted to "(empty)".

    Args:
        message: Message in standard format:
            {
                "role": "user/assistant/tool",
                "content": str | [{"type": "text", ...}],
                "tool_calls": [{"id": str, "function": {"name": str, "arguments": str}}]
            }

    Returns:
        Message in AWS Bedrock format:
            {
                "role": "user/assistant",
                "content": [
                    {"text": str} |
                    {"toolUse": {"toolUseId": str, "name": str, "input": dict}} |
                    {"toolResult": {"toolUseId": str, "content": [...]}}
                ]
            }
        The message is returned unchanged when it has no tool calls and its
        content is neither a string nor a list.

    Raises:
        KeyError: If a required key ("role", or "content"/"tool_call_id" for a
            tool message) is missing.
        json.JSONDecodeError: If a tool call's "arguments" is not valid JSON.
    """
    if message["role"] == "tool":
        raw_result = message["content"]
        # Default to passing the result through as text; upgrade to a JSON
        # block only when the string both looks like and parses as an object.
        tool_result_content = [{"text": raw_result}]
        if isinstance(raw_result, str):
            stripped = raw_result.strip()
            if stripped.startswith("{") and stripped.endswith("}"):
                try:
                    tool_result_content = [{"json": json.loads(raw_result)}]
                except (ValueError, RecursionError):
                    # Brace-delimited but not valid JSON: keep it as text.
                    # Only parsing failures are handled here so interrupts
                    # and exit requests are never swallowed.
                    pass

        # Bedrock carries tool results inside a user message.
        return {
            "role": "user",
            "content": [
                {
                    "toolResult": {
                        "toolUseId": message["tool_call_id"],
                        "content": tool_result_content,
                    },
                },
            ],
        }

    if message.get("tool_calls"):
        tool_uses = []
        for tool_call in message["tool_calls"]:
            function = tool_call["function"]
            # Standard format stores arguments as a JSON string; Bedrock
            # expects the decoded object.
            arguments = json.loads(function["arguments"])
            tool_uses.append(
                {
                    "toolUse": {
                        "toolUseId": tool_call["id"],
                        "name": function["name"],
                        "input": arguments,
                    }
                }
            )
        return {"role": "assistant", "content": tool_uses}

    # Handle text content
    content = message.get("content")
    if isinstance(content, str):
        return {"role": message["role"], "content": [{"text": content or "(empty)"}]}
    if isinstance(content, list):
        # Only items typed "text" are carried over; other items are dropped.
        new_content = [
            {"text": item["text"] if item["text"] != "" else "(empty)"}
            for item in content
            if item.get("type", "") == "text"
        ]
        return {"role": message["role"], "content": new_content}

    return message
|
||||
|
||||
def add_image_frame_message(
    self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
):
    """Add a user message holding a JPEG-encoded image and optional text.

    Args:
        format: Pixel mode handed to ``Image.frombytes``.
        size: Image size handed to ``Image.frombytes``.
        image: Raw pixel bytes.
        text: Optional text appended after the image block.
    """
    # Re-encode the raw pixels as JPEG in memory, then base64 the result.
    raw_picture = Image.frombytes(format, size, image)
    jpeg_stream = io.BytesIO()
    raw_picture.save(jpeg_stream, format="JPEG")
    jpeg_b64 = base64.b64encode(jpeg_stream.getvalue()).decode("utf-8")

    # Image should be the first content block in the message
    image_block = {"type": "image", "format": "jpeg", "source": {"bytes": jpeg_b64}}
    blocks = [image_block] + ([{"text": text}] if text else [])
    self.add_message({"role": "user", "content": blocks})
|
||||
|
||||
def add_message(self, message):
    """Add a message to the context, merging it into the previous message
    when both have the same role.

    AWS Bedrock requires that roles alternate, so a message whose role equals
    the last stored message's role has its content appended to that message
    instead of being stored separately. Plain-string content is converted to
    the ``[{"text": ...}]`` list form before merging. The caller's ``message``
    dict is never modified when it is merged.

    Any exception is logged and swallowed (best effort).

    Args:
        message: A dict with ``"role"`` and ``"content"`` keys.
    """
    try:
        # First message, or a role change: store the message as-is.
        if not self.messages or self.messages[-1]["role"] != message["role"]:
            self.messages.append(message)
            return

        previous = self.messages[-1]
        # If the last message has just a content string, convert it to a list
        # in the proper format.
        if isinstance(previous["content"], str):
            previous["content"] = [{"text": previous["content"]}]

        # Normalize into a local so the caller's dict is left untouched.
        incoming = message["content"]
        if isinstance(incoming, str):
            incoming = [{"text": incoming}]

        previous["content"].extend(incoming)
    except Exception as e:
        logger.error(f"Error adding message: {e}")
|
||||
|
||||
def _restructure_from_bedrock_messages(self):
    """Restructure messages in AWS Bedrock format by handling system
    messages, merging consecutive messages with the same role, and ensuring
    proper content formatting.

    Steps, all applied in place to ``self.messages`` / ``self.system``:

    1. A leading ``system`` message is moved into ``self.system`` (appended
       if a system prompt already exists). If it is the only message it is
       kept and relabelled as a ``user`` message instead.
    2. Every message's content is turned into a list of blocks; empty
       content and empty text blocks become ``"(empty)"``.
    3. Consecutive messages with the same role are merged.
    """
    # Handle system message if present at the beginning
    if self.messages and self.messages[0]["role"] == "system":
        if len(self.messages) == 1:
            self.messages[0]["role"] = "user"
        else:
            system_content = self.messages.pop(0)["content"]
            if isinstance(system_content, str):
                system_content = [{"text": system_content}]

            if self.system:
                if isinstance(self.system, str):
                    self.system = [{"text": self.system}]
                self.system.extend(system_content)
            else:
                self.system = system_content

    # Ensure content is properly formatted
    for msg in self.messages:
        content = msg["content"]
        if isinstance(content, str):
            content = [{"text": content}]

        if not content:
            msg["content"] = [{"text": "(empty)"}]
            continue

        # Checked after the string conversion above, so that an empty string
        # also ends up as "(empty)" rather than as an empty text block.
        if isinstance(content, list):
            for idx, item in enumerate(content):
                if isinstance(item, dict) and item.get("text") == "":
                    item["text"] = "(empty)"
                elif isinstance(item, str) and item == "":
                    content[idx] = {"text": "(empty)"}
        msg["content"] = content

    # Merge consecutive messages with the same role
    merged_messages = []
    for msg in self.messages:
        if merged_messages and merged_messages[-1]["role"] == msg["role"]:
            merged_messages[-1]["content"].extend(msg["content"])
        else:
            merged_messages.append(msg)

    # Replace the contents in place so the list object itself is preserved.
    self.messages.clear()
    self.messages.extend(merged_messages)
|
||||
|
||||
def _restructure_from_openai_messages(self):
    """Convert OpenAI-style messages in this context to AWS Bedrock format.

    Steps, all applied in place:

    1. Map every message through ``self.from_standard_message``. A failure
       is logged and the messages are left as they were.
    2. Pull a leading ``system`` message out into ``self.system``.
    3. Merge consecutive messages with the same role. String content is
       wrapped as ``{"text": ...}`` blocks, the shape used for text blocks
       elsewhere in this context, with ``"(empty)"`` for empty strings.
    4. Replace empty content with an ``"(empty)"`` placeholder.
    """
    # first, map across self._messages calling self.from_standard_message(m) to modify messages in place
    try:
        self._messages[:] = [self.from_standard_message(m) for m in self._messages]
    except Exception as e:
        logger.error(f"Error mapping messages: {e}")

    # See if we should pull the system message out of our context.messages list. (For
    # compatibility with Open AI messages format.)
    if self.messages and self.messages[0]["role"] == "system":
        self.system = self.messages[0]["content"]
        self.messages.pop(0)

    # Merge consecutive messages with the same role (single pass).
    merged = []
    for message in self.messages:
        if merged and merged[-1]["role"] == message["role"]:
            previous = merged[-1]
            # Convert content to a list of text blocks if it's a string
            if isinstance(previous["content"], str):
                previous["content"] = [{"text": previous["content"] or "(empty)"}]
            addition = message["content"]
            if isinstance(addition, str):
                addition = [{"text": addition or "(empty)"}]
            # Concatenate the content
            previous["content"].extend(addition)
        else:
            merged.append(message)
    # Slice-assign so the underlying list object is preserved.
    self.messages[:] = merged

    # Avoid empty content in messages
    for message in self.messages:
        content = message["content"]
        if isinstance(content, str) and content == "":
            message["content"] = "(empty)"
        elif isinstance(content, list) and len(content) == 0:
            message["content"] = [{"text": "(empty)"}]
|
||||
|
||||
def get_messages_for_persistent_storage(self):
    """Return the parent class's persistent-storage messages, with the
    system prompt inserted first as a ``system`` message when one is set.
    """
    stored = super().get_messages_for_persistent_storage()
    if not self.system:
        return stored
    stored.insert(0, {"role": "system", "content": self.system})
    return stored
|
||||
|
||||
def get_messages_for_logging(self) -> str:
    """Return the context messages as a JSON string with image bytes elided.

    Works on deep copies, so the stored messages are not modified. Image
    payloads are replaced with ``"..."`` in both shapes an image block can
    take here: the flat ``{"type": "image", "source": {"bytes": ...}}`` block
    built by ``add_image_frame_message`` and a nested
    ``{"image": {"source": {"bytes": ...}}}`` block. Non-dict content items
    are left untouched.

    Returns:
        The JSON-encoded list of (redacted) messages.
    """
    msgs = []
    for message in self.messages:
        msg = copy.deepcopy(message)
        content = msg.get("content")
        if isinstance(content, list):
            for item in content:
                if not isinstance(item, dict):
                    continue
                # Locate the dict that carries the "source" entry.
                nested = item.get("image")
                if isinstance(nested, dict):
                    holder = nested
                elif item.get("type") == "image":
                    holder = item
                else:
                    continue
                source = holder.get("source")
                if isinstance(source, dict) and "bytes" in source:
                    source["bytes"] = "..."
        msgs.append(msg)
    return json.dumps(msgs)
|
||||
|
||||
|
||||
class AWSBedrockUserContextAggregator(LLMUserContextAggregator):
    """User-side context aggregator for AWS Bedrock.

    Adds nothing of its own; all behavior is inherited from
    ``LLMUserContextAggregator``.
    """

    pass
|
||||
|
||||
|
||||
class AWSBedrockAssistantContextAggregator(LLMAssistantContextAggregator):
    """Assistant-side context aggregator that records function calls and
    their results in the context as AWS Bedrock ``toolUse`` / ``toolResult``
    content blocks.
    """

    async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
        """Record a started function call.

        Adds an assistant ``toolUse`` message followed by a user
        ``toolResult`` message whose text is the ``IN_PROGRESS`` placeholder.
        """
        # Format tool use according to AWS Bedrock API
        tool_use = {
            "toolUseId": frame.tool_call_id,
            "name": frame.function_name,
            "input": frame.arguments or {},
        }
        self._context.add_message({"role": "assistant", "content": [{"toolUse": tool_use}]})

        placeholder = {
            "toolUseId": frame.tool_call_id,
            "content": [{"text": "IN_PROGRESS"}],
        }
        self._context.add_message({"role": "user", "content": [{"toolResult": placeholder}]})

    async def handle_function_call_result(self, frame: FunctionCallResultFrame):
        """Store the JSON-encoded result, or ``COMPLETED`` when the result is falsy."""
        outcome = json.dumps(frame.result) if frame.result else "COMPLETED"
        await self._update_function_call_result(
            frame.function_name, frame.tool_call_id, outcome
        )

    async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
        """Mark the function call's result as ``CANCELLED``."""
        await self._update_function_call_result(
            frame.function_name, frame.tool_call_id, "CANCELLED"
        )

    async def _update_function_call_result(
        self, function_name: str, tool_call_id: str, result: Any
    ):
        """Overwrite the content of every user ``toolResult`` block whose
        ``toolUseId`` equals ``tool_call_id`` with a text block holding ``result``.
        """
        for message in self._context.messages:
            if message["role"] != "user":
                continue
            for block in message["content"]:
                if not isinstance(block, dict):
                    continue
                tool_result = block.get("toolResult")
                if tool_result and tool_result["toolUseId"] == tool_call_id:
                    tool_result["content"] = [{"text": result}]

    async def handle_user_image_frame(self, frame: UserImageRawFrame):
        """Mark the originating function call as ``COMPLETED`` and add the
        image (plus the request's context text) to the context.
        """
        request = frame.request
        await self._update_function_call_result(
            request.function_name, request.tool_call_id, "COMPLETED"
        )
        self._context.add_image_frame_message(
            format=frame.format,
            size=frame.size,
            image=frame.image,
            text=frame.request.context,
        )
|
||||
|
||||
|
||||
class AWSBedrockLLMService(LLMService):
    """This class implements inference with AWS Bedrock models including Amazon
    Nova and Anthropic Claude.

    Requires AWS credentials to be configured in the environment or through
    boto3 configuration.

    """

    # Overriding the default adapter to use the AWS Bedrock one.
    adapter_class = AWSBedrockLLMAdapter

    class InputParams(BaseModel):
        max_tokens: Optional[int] = Field(default_factory=lambda: 4096, ge=1)
        temperature: Optional[float] = Field(default_factory=lambda: 0.7, ge=0.0, le=1.0)
        top_p: Optional[float] = Field(default_factory=lambda: 0.999, ge=0.0, le=1.0)
        stop_sequences: Optional[List[str]] = Field(default_factory=lambda: [])
        latency: Optional[str] = Field(default_factory=lambda: "standard")
        additional_model_request_fields: Optional[Dict[str, Any]] = Field(default_factory=dict)

    def __init__(
        self,
        *,
        aws_access_key: Optional[str] = None,
        aws_secret_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_region: str = "us-east-1",
        model: str,
        params: Optional[InputParams] = None,
        client_config: Optional[Config] = None,
        **kwargs,
    ):
        """Create the Bedrock runtime client and store the inference settings.

        Args:
            aws_access_key: AWS access key id passed to ``boto3.Session``.
            aws_secret_key: AWS secret access key passed to ``boto3.Session``.
            aws_session_token: AWS session token passed to ``boto3.Session``.
            aws_region: AWS region name.
            model: Bedrock model id.
            params: Inference parameters; a fresh ``InputParams()`` is
                created per instance when omitted.
            client_config: botocore ``Config``; when omitted, one with
                300-second connect/read timeouts and 3 retry attempts is used.
            **kwargs: Forwarded to ``LLMService``.
        """
        super().__init__(**kwargs)

        # Build the default per instance rather than sharing one model object
        # created at function-definition time.
        if params is None:
            params = AWSBedrockLLMService.InputParams()

        # Initialize the AWS Bedrock client
        if not client_config:
            client_config = Config(
                connect_timeout=300,  # 5 minutes
                read_timeout=300,  # 5 minutes
                retries={"max_attempts": 3},
            )
        session = boto3.Session(
            aws_access_key_id=aws_access_key,
            aws_secret_access_key=aws_secret_key,
            aws_session_token=aws_session_token,
            region_name=aws_region,
        )
        self._client = session.client(service_name="bedrock-runtime", config=client_config)

        self.set_model_name(model)
        self._settings = {
            "max_tokens": params.max_tokens,
            "temperature": params.temperature,
            "top_p": params.top_p,
            "stop_sequences": params.stop_sequences,
            "latency": params.latency,
            "additional_model_request_fields": params.additional_model_request_fields
            if isinstance(params.additional_model_request_fields, dict)
            else {},
        }

        logger.info(f"Using AWS Bedrock model: {model}")

    def can_generate_metrics(self) -> bool:
        """Return True: this service reports metrics."""
        return True

    def create_context_aggregator(
        self,
        context: OpenAILLMContext,
        *,
        user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
        assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
    ) -> AWSBedrockContextAggregatorPair:
        """Create an instance of AWSBedrockContextAggregatorPair from an
        OpenAILLMContext. Constructor keyword arguments for both the user and
        assistant aggregators can be provided.

        Args:
            context (OpenAILLMContext): The LLM context.
            user_params (LLMUserAggregatorParams, optional): User aggregator
                parameters.
            assistant_params (LLMAssistantAggregatorParams, optional): Assistant
                aggregator parameters.

        Returns:
            AWSBedrockContextAggregatorPair: A pair of context aggregators, one
            for the user and one for the assistant, encapsulated in an
            AWSBedrockContextAggregatorPair.
        """
        context.set_llm_adapter(self.get_llm_adapter())

        if isinstance(context, OpenAILLMContext):
            context = AWSBedrockLLMContext.from_openai_context(context)

        user = AWSBedrockUserContextAggregator(context, params=user_params)
        assistant = AWSBedrockAssistantContextAggregator(context, params=assistant_params)
        return AWSBedrockContextAggregatorPair(_user=user, _assistant=assistant)

    async def _process_context(self, context: AWSBedrockLLMContext):
        """Run one streamed inference for ``context``.

        Pushes ``LLMFullResponseStartFrame``, one ``LLMTextFrame`` per text
        delta and ``LLMFullResponseEndFrame``; invokes ``call_function`` when
        the message stops with ``tool_use``; reports token usage at the end.
        """
        # Usage tracking
        prompt_tokens = 0
        completion_tokens = 0
        completion_tokens_estimate = 0
        cache_read_input_tokens = 0
        cache_creation_input_tokens = 0
        use_completion_tokens_estimate = False

        try:
            await self.push_frame(LLMFullResponseStartFrame())
            await self.start_processing_metrics()

            await self.start_ttfb_metrics()

            # Set up inference config
            inference_config = {
                "maxTokens": self._settings["max_tokens"],
                "temperature": self._settings["temperature"],
                "topP": self._settings["top_p"],
            }
            # Only send stop sequences when some were configured.
            if self._settings.get("stop_sequences"):
                inference_config["stopSequences"] = self._settings["stop_sequences"]

            # Prepare request parameters
            request_params = {
                "modelId": self.model_name,
                "messages": context.messages,
                "inferenceConfig": inference_config,
                "additionalModelRequestFields": self._settings["additional_model_request_fields"],
            }

            # Add system message. The request expects a list of content
            # blocks, so wrap a plain string and omit the field when no
            # system prompt is set.
            system = context.system
            if isinstance(system, str):
                system = [{"text": system}]
            if system:
                request_params["system"] = system

            # Add tools if present
            if context.tools:
                tool_config = {"tools": context.tools}

                # Add tool_choice if specified
                if context.tool_choice:
                    if context.tool_choice == "auto":
                        tool_config["toolChoice"] = {"auto": {}}
                    elif context.tool_choice == "none":
                        # Skip adding toolChoice for "none"
                        pass
                    elif (
                        isinstance(context.tool_choice, dict) and "function" in context.tool_choice
                    ):
                        tool_config["toolChoice"] = {
                            "tool": {"name": context.tool_choice["function"]["name"]}
                        }

                request_params["toolConfig"] = tool_config

            # Add performance config if latency is specified
            if self._settings["latency"] in ["standard", "optimized"]:
                request_params["performanceConfig"] = {"latency": self._settings["latency"]}

            logger.debug(f"Calling AWS Bedrock model with: {request_params}")

            # Call AWS Bedrock with streaming
            # NOTE(review): this call and the stream iteration below are
            # synchronous; they presumably block the event loop while waiting
            # on the network - confirm and consider offloading.
            response = self._client.converse_stream(**request_params)

            await self.stop_ttfb_metrics()

            # Process the streaming response
            tool_use_block = None
            json_accumulator = ""

            for event in response["stream"]:
                # Handle text content
                if "contentBlockDelta" in event:
                    delta = event["contentBlockDelta"]["delta"]
                    if "text" in delta:
                        await self.push_frame(LLMTextFrame(delta["text"]))
                        completion_tokens_estimate += self._estimate_tokens(delta["text"])
                    elif "toolUse" in delta and "input" in delta["toolUse"]:
                        # Handle partial JSON for tool use
                        json_accumulator += delta["toolUse"]["input"]
                        completion_tokens_estimate += self._estimate_tokens(
                            delta["toolUse"]["input"]
                        )

                # Handle tool use start
                elif "contentBlockStart" in event:
                    content_block_start = event["contentBlockStart"]["start"]
                    if "toolUse" in content_block_start:
                        tool_use_block = {
                            "id": content_block_start["toolUse"].get("toolUseId", ""),
                            "name": content_block_start["toolUse"].get("name", ""),
                        }
                        json_accumulator = ""

                # Handle message completion with tool use
                # NOTE(review): only the most recently started tool-use block
                # is invoked here; earlier ones in the same message are not.
                elif "messageStop" in event and "stopReason" in event["messageStop"]:
                    if event["messageStop"]["stopReason"] == "tool_use" and tool_use_block:
                        try:
                            arguments = json.loads(json_accumulator) if json_accumulator else {}
                            await self.call_function(
                                context=context,
                                tool_call_id=tool_use_block["id"],
                                function_name=tool_use_block["name"],
                                arguments=arguments,
                            )
                        except json.JSONDecodeError:
                            logger.error(f"Failed to parse tool arguments: {json_accumulator}")

                # Handle usage metrics if available
                if "metadata" in event and "usage" in event["metadata"]:
                    usage = event["metadata"]["usage"]
                    prompt_tokens += usage.get("inputTokens", 0)
                    completion_tokens += usage.get("outputTokens", 0)
                    cache_read_input_tokens += usage.get("cacheReadInputTokens", 0)
                    cache_creation_input_tokens += usage.get("cacheWriteInputTokens", 0)

        except asyncio.CancelledError:
            # If we're interrupted, we won't get a complete usage report. So set our flag to use the
            # token estimate. Then reraise the exception so all the processors running in this task
            # also get cancelled.
            use_completion_tokens_estimate = True
            raise
        except httpx.TimeoutException:
            # NOTE(review): the request is made through the boto3 client;
            # confirm that an httpx timeout can actually be raised here.
            await self._call_event_handler("on_completion_timeout")
        except Exception as e:
            logger.exception(f"{self} exception: {e}")
        finally:
            await self.stop_processing_metrics()
            await self.push_frame(LLMFullResponseEndFrame())
            comp_tokens = (
                completion_tokens
                if not use_completion_tokens_estimate
                else completion_tokens_estimate
            )
            await self._report_usage_metrics(
                prompt_tokens=prompt_tokens,
                completion_tokens=comp_tokens,
                cache_read_input_tokens=cache_read_input_tokens,
                cache_creation_input_tokens=cache_creation_input_tokens,
            )

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Build a Bedrock context from context/messages/image frames and run
        inference on it; apply settings updates; push any other frame on.
        """
        await super().process_frame(frame, direction)

        context = None
        if isinstance(frame, OpenAILLMContextFrame):
            context = AWSBedrockLLMContext.upgrade_to_bedrock(frame.context)
        elif isinstance(frame, LLMMessagesFrame):
            context = AWSBedrockLLMContext.from_messages(frame.messages)
        elif isinstance(frame, VisionImageRawFrame):
            # This is only useful in very simple pipelines because it creates
            # a new context. Generally we want a context manager to catch
            # UserImageRawFrames coming through the pipeline and add them
            # to the context.
            context = AWSBedrockLLMContext.from_image_frame(frame)
        elif isinstance(frame, LLMUpdateSettingsFrame):
            await self._update_settings(frame.settings)
        else:
            await self.push_frame(frame, direction)

        if context:
            await self._process_context(context)

    def _estimate_tokens(self, text: str) -> int:
        """Estimate a token count as 1.3 times the number of pieces obtained
        by splitting ``text`` on runs of non-word characters.
        """
        return int(len(re.split(r"[^\w]+", text)) * 1.3)

    async def _report_usage_metrics(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        cache_read_input_tokens: int,
        cache_creation_input_tokens: int,
    ):
        """Report token usage, unless both prompt and completion counts are zero."""
        if prompt_tokens or completion_tokens:
            tokens = LLMTokenUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
                cache_read_input_tokens=cache_read_input_tokens,
                cache_creation_input_tokens=cache_creation_input_tokens,
            )
            await self.start_llm_usage_metrics(tokens)
|
||||
329
src/pipecat/services/aws/stt.py
Normal file
329
src/pipecat/services/aws/stt.py
Normal file
@@ -0,0 +1,329 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.aws.utils import build_event_message, decode_event, get_presigned_url
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class AWSTranscribeSTTService(STTService):
    """Speech-to-text service that streams audio to AWS Transcribe over a
    WebSocket and pushes interim and final transcription frames.
    """

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        region: Optional[str] = "us-east-1",
        sample_rate: int = 16000,
        language: Language = Language.EN,
        **kwargs,
    ):
        """Store settings and credentials; no connection is opened here.

        Args:
            api_key: AWS secret access key; falls back to the
                ``AWS_SECRET_ACCESS_KEY`` environment variable.
            aws_access_key_id: AWS access key id; falls back to
                ``AWS_ACCESS_KEY_ID``.
            aws_session_token: AWS session token; falls back to
                ``AWS_SESSION_TOKEN``.
            region: AWS region; when falsy, ``AWS_REGION`` is used, then
                ``us-east-1``.
            sample_rate: 8000 or 16000; any other value is replaced by 16000
                in the settings, with a warning.
            language: Transcription language.
            **kwargs: Forwarded to ``STTService``.
        """
        super().__init__(**kwargs)

        self._settings = {
            "sample_rate": sample_rate,
            "language": language,
            "media_encoding": "linear16",  # AWS expects raw PCM
            "number_of_channels": 1,
            "show_speaker_label": False,
            "enable_channel_identification": False,
        }

        # Validate sample rate - AWS Transcribe only supports 8000 Hz or 16000 Hz
        if sample_rate not in [8000, 16000]:
            logger.warning(
                f"AWS Transcribe only supports 8000 Hz or 16000 Hz sample rates. Converting from {sample_rate} Hz to 16000 Hz."
            )
            self._settings["sample_rate"] = 16000

        self._credentials = {
            "aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
            "aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
            "aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
            "region": region or os.getenv("AWS_REGION", "us-east-1"),
        }

        self._ws_client = None
        self._connection_lock = asyncio.Lock()
        self._connecting = False
        self._receive_task = None

    def get_service_encoding(self, encoding: str) -> str:
        """Convert internal encoding format to AWS Transcribe format.

        Unknown encodings are returned unchanged.
        """
        encoding_map = {
            "linear16": "pcm",  # AWS expects "pcm" for 16-bit linear PCM
        }
        return encoding_map.get(encoding, encoding)

    async def start(self, frame: StartFrame):
        """Initialize the connection when the service starts.

        Tries to connect up to three times, one second apart.

        Raises:
            RuntimeError: If no open connection exists after all attempts.
        """
        await super().start(frame)
        logger.info("Starting AWS Transcribe service...")
        retry_count = 0
        max_retries = 3

        while retry_count < max_retries:
            try:
                await self._connect()
                if self._ws_client and self._ws_client.open:
                    logger.info("Successfully established WebSocket connection")
                    return
                logger.warning("WebSocket connection not established after connect")
            except Exception as e:
                logger.error(f"Failed to connect (attempt {retry_count + 1}/{max_retries}): {e}")
                retry_count += 1
                if retry_count < max_retries:
                    await asyncio.sleep(1)  # Wait before retrying

        raise RuntimeError("Failed to establish WebSocket connection after multiple attempts")

    async def stop(self, frame: EndFrame):
        """Stop the service and close the connection."""
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the service and close the connection."""
        await super().cancel(frame)
        await self._disconnect()

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Process audio data and send to AWS Transcribe.

        Reconnects first if the socket is not open. Only ``ErrorFrame``s are
        yielded from here; transcription frames are pushed by the receive loop.
        """
        try:
            # Ensure WebSocket is connected
            if not self._ws_client or not self._ws_client.open:
                logger.debug("WebSocket not connected, attempting to reconnect...")
                try:
                    await self._connect()
                except Exception as e:
                    logger.error(f"Failed to reconnect: {e}")
                    yield ErrorFrame("Failed to reconnect to AWS Transcribe", fatal=False)
                    return

            # Format the audio data according to AWS event stream format
            event_message = build_event_message(audio)

            # Send the formatted event message
            try:
                await self._ws_client.send(event_message)
                # Start metrics after first chunk sent
                await self.start_processing_metrics()
                await self.start_ttfb_metrics()
            except websockets.exceptions.ConnectionClosed as e:
                logger.warning(f"Connection closed while sending: {e}")
                await self._disconnect()
                # Don't yield error here - we'll retry on next frame
            except Exception as e:
                logger.error(f"Error sending audio: {e}")
                yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
                await self._disconnect()

        except Exception as e:
            logger.error(f"Error in run_stt: {e}")
            yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
            await self._disconnect()

    async def _connect(self):
        """Connect to AWS Transcribe with connection state management.

        Returns immediately if an open socket and a receive task already
        exist. Otherwise builds a presigned URL, opens the socket and starts
        the receive task. On failure the connection is torn down and the
        exception re-raised.
        """
        if self._ws_client and self._ws_client.open and self._receive_task:
            logger.debug(f"{self} Already connected")
            return

        async with self._connection_lock:
            # Re-check after acquiring the lock: another caller may have
            # finished connecting while this one was waiting, and that fresh
            # connection must not be torn down and rebuilt.
            if self._ws_client and self._ws_client.open and self._receive_task:
                logger.debug(f"{self} Already connected")
                return

            if self._connecting:
                logger.debug(f"{self} Connection already in progress")
                return

            try:
                self._connecting = True
                logger.debug(f"{self} Starting connection process...")

                if self._ws_client:
                    await self._disconnect()

                language_code = self.language_to_service_language(
                    Language(self._settings["language"])
                )
                if not language_code:
                    raise ValueError(f"Unsupported language: {self._settings['language']}")

                # Generate random websocket key
                websocket_key = "".join(
                    random.choices(
                        string.ascii_uppercase + string.ascii_lowercase + string.digits, k=20
                    )
                )

                # Add required headers
                extra_headers = {
                    "Origin": "https://localhost",
                    "Sec-WebSocket-Key": websocket_key,
                    "Sec-WebSocket-Version": "13",
                    "Connection": "keep-alive",
                }

                # Get presigned URL
                presigned_url = get_presigned_url(
                    region=self._credentials["region"],
                    credentials={
                        "access_key": self._credentials["aws_access_key_id"],
                        "secret_key": self._credentials["aws_secret_access_key"],
                        "session_token": self._credentials["aws_session_token"],
                    },
                    language_code=language_code,
                    media_encoding=self.get_service_encoding(
                        self._settings["media_encoding"]
                    ),  # Convert to AWS format
                    sample_rate=self._settings["sample_rate"],
                    number_of_channels=self._settings["number_of_channels"],
                    enable_partial_results_stabilization=True,
                    partial_results_stability="high",
                    show_speaker_label=self._settings["show_speaker_label"],
                    enable_channel_identification=self._settings["enable_channel_identification"],
                )

                logger.debug(f"{self} Connecting to WebSocket with URL: {presigned_url[:100]}...")

                # Connect with the required headers and settings
                self._ws_client = await websockets.connect(
                    presigned_url,
                    extra_headers=extra_headers,
                    subprotocols=["mqtt"],
                    ping_interval=None,
                    ping_timeout=None,
                    compression=None,
                )

                logger.debug(f"{self} WebSocket connected, starting receive task...")

                # Start receive task
                self._receive_task = self.create_task(self._receive_loop())

                logger.info(f"{self} Successfully connected to AWS Transcribe")

            except Exception as e:
                logger.error(f"{self} Failed to connect to AWS Transcribe: {e}")
                await self._disconnect()
                raise

            finally:
                self._connecting = False

    async def _disconnect(self):
        """Disconnect from AWS Transcribe.

        Cancels the receive task, sends an end-stream message if the socket
        is still open, closes it and clears the client reference. Errors
        while closing are logged, not raised.
        """
        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None

        try:
            if self._ws_client and self._ws_client.open:
                # Send end-stream message
                end_stream = {"message-type": "event", "event": "end"}
                await self._ws_client.send(json.dumps(end_stream))
            await self._ws_client.close()
        except Exception as e:
            logger.warning(f"{self} Error closing WebSocket connection: {e}")
        finally:
            self._ws_client = None

    def language_to_service_language(self, language: Language) -> Optional[str]:
        """Convert internal language enum to AWS Transcribe language code.

        Returns:
            The language code, or None if the language is not in the map.
        """
        language_map = {
            Language.EN: "en-US",
            Language.ES: "es-US",
            Language.FR: "fr-FR",
            Language.DE: "de-DE",
            Language.IT: "it-IT",
            Language.PT: "pt-BR",
            Language.JA: "ja-JP",
            Language.KO: "ko-KR",
            Language.ZH: "zh-CN",
        }
        return language_map.get(language)

    async def _receive_loop(self):
        """Background task to receive and process messages from AWS Transcribe.

        Pushes a ``TranscriptionFrame`` for final results, an
        ``InterimTranscriptionFrame`` for partial ones and an ``ErrorFrame``
        for exception messages. The loop ends when the socket closes or any
        unexpected error occurs.
        """
        # NOTE(review): on an unexpected error the loop exits without closing
        # the socket, so the socket may stay open with no reader - confirm.
        while True:
            if not self._ws_client or not self._ws_client.open:
                logger.warning(f"{self} WebSocket closed in receive loop")
                break

            try:
                response = await self._ws_client.recv()
                headers, payload = decode_event(response)

                if headers.get(":message-type") == "event":
                    # Process transcription results (first result, first alternative)
                    results = payload.get("Transcript", {}).get("Results", [])
                    if results:
                        result = results[0]
                        alternatives = result.get("Alternatives", [])
                        if alternatives:
                            transcript = alternatives[0].get("Transcript", "")
                            # A result without an "IsPartial" flag is treated as partial.
                            is_final = not result.get("IsPartial", True)

                            if transcript:
                                await self.stop_ttfb_metrics()
                                if is_final:
                                    await self.push_frame(
                                        TranscriptionFrame(
                                            transcript,
                                            "",
                                            time_now_iso8601(),
                                            self._settings["language"],
                                        )
                                    )
                                    await self.stop_processing_metrics()
                                else:
                                    await self.push_frame(
                                        InterimTranscriptionFrame(
                                            transcript,
                                            "",
                                            time_now_iso8601(),
                                            self._settings["language"],
                                        )
                                    )
                elif headers.get(":message-type") == "exception":
                    error_msg = payload.get("Message", "Unknown error")
                    logger.error(f"{self} Exception from AWS: {error_msg}")
                    await self.push_frame(
                        ErrorFrame(f"AWS Transcribe error: {error_msg}", fatal=False)
                    )
                else:
                    logger.debug(f"{self} Other message type received: {headers}")
                    logger.debug(f"{self} Payload: {payload}")
            except websockets.exceptions.ConnectionClosed as e:
                logger.error(
                    f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}"
                )
                break
            except Exception as e:
                logger.error(f"{self} Unexpected error in receive loop: {e}")
                break
|
||||
@@ -5,6 +5,7 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
@@ -26,9 +27,7 @@ try:
|
||||
from botocore.exceptions import BotoCoreError, ClientError
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Deepgram, you need to `pip install pipecat-ai[aws]`. Also, set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable."
|
||||
)
|
||||
logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@@ -108,7 +107,7 @@ def language_to_aws_language(language: Language) -> Optional[str]:
|
||||
return language_map.get(language)
|
||||
|
||||
|
||||
class PollyTTSService(TTSService):
|
||||
class AWSPollyTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
engine: Optional[str] = None
|
||||
language: Optional[Language] = Language.EN
|
||||
@@ -151,6 +150,24 @@ class PollyTTSService(TTSService):
|
||||
|
||||
self.set_voice(voice_id)
|
||||
|
||||
# Get credentials from environment variables if not provided
|
||||
self._credentials = {
|
||||
"aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
"aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
"aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
|
||||
"region": region or os.getenv("AWS_REGION", "us-east-1"),
|
||||
}
|
||||
|
||||
# Validate that we have the required credentials
|
||||
if (
|
||||
not self._credentials["aws_access_key_id"]
|
||||
or not self._credentials["aws_secret_access_key"]
|
||||
):
|
||||
raise ValueError(
|
||||
"AWS credentials not found. Please provide them either through constructor parameters "
|
||||
"or set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."
|
||||
)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -165,18 +182,17 @@ class PollyTTSService(TTSService):
|
||||
|
||||
prosody_attrs = []
|
||||
# Prosody tags are only supported for standard and neural engines
|
||||
if self._settings["engine"] != "generative":
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["engine"] == "standard":
|
||||
if self._settings["pitch"]:
|
||||
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
if prosody_attrs:
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
else:
|
||||
logger.warning("Prosody tags are not supported for generative engine. Ignoring.")
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
if prosody_attrs:
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
|
||||
ssml += text
|
||||
|
||||
@@ -187,6 +203,8 @@ class PollyTTSService(TTSService):
|
||||
|
||||
ssml += "</speak>"
|
||||
|
||||
logger.trace(f"{self} SSML: {ssml}")
|
||||
|
||||
return ssml
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
@@ -248,3 +266,17 @@ class PollyTTSService(TTSService):
|
||||
|
||||
finally:
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
|
||||
class PollyTTSService(AWSPollyTTSService):
    """Deprecated alias of ``AWSPollyTTSService``.

    Kept so that code still instantiating ``PollyTTSService`` keeps working.
    Every instantiation emits a ``DeprecationWarning``.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``AWSPollyTTSService`` and warn.

        Args:
            **kwargs: Passed unchanged to ``AWSPollyTTSService.__init__``.
        """
        super().__init__(**kwargs)

        import warnings

        with warnings.catch_warnings():
            # Temporarily override the active filters so the warning is always
            # emitted, regardless of how the application configured them.
            warnings.simplefilter("always")
            warnings.warn(
                "'PollyTTSService' is deprecated, use 'AWSPollyTTSService' instead.",
                DeprecationWarning,
                # Attribute the warning to the code that instantiated the
                # class, not to this constructor.
                stacklevel=2,
            )
|
||||
|
||||
261
src/pipecat/services/aws/utils.py
Normal file
261
src/pipecat/services/aws/utils.py
Normal file
@@ -0,0 +1,261 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import binascii
|
||||
import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import struct
|
||||
import urllib.parse
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
def get_presigned_url(
    *,
    region: str,
    credentials: Dict[str, Optional[str]],
    language_code: str,
    media_encoding: str = "pcm",
    sample_rate: int = 16000,
    number_of_channels: int = 1,
    enable_partial_results_stabilization: bool = True,
    partial_results_stability: str = "high",
    vocabulary_name: Optional[str] = None,
    vocabulary_filter_name: Optional[str] = None,
    show_speaker_label: bool = False,
    enable_channel_identification: bool = False,
) -> str:
    """Return a presigned WebSocket URL for AWS Transcribe streaming.

    Args:
        region: AWS region used for the endpoint and the signature.
        credentials: Mapping read with the keys ``access_key``, ``secret_key``
            and (optionally) ``session_token``.
        language_code: Value of the ``language-code`` query parameter.
        media_encoding: Value of the ``media-encoding`` query parameter.
        sample_rate: Value of the ``sample-rate`` query parameter.
        number_of_channels: Forwarded to the URL generator.
        enable_partial_results_stabilization: Forwarded to the URL generator.
        partial_results_stability: Forwarded to the URL generator.
        vocabulary_name: Optional custom vocabulary name.
        vocabulary_filter_name: Optional vocabulary filter name.
        show_speaker_label: Forwarded to the URL generator.
        enable_channel_identification: Forwarded to the URL generator.

    Returns:
        The URL produced by ``AWSTranscribePresignedURL.get_request_url``.

    Raises:
        ValueError: If the access key or the secret key is missing or empty.
    """
    key_id = credentials.get("access_key")
    secret = credentials.get("secret_key")
    token = credentials.get("session_token")

    # Both halves of the key pair are mandatory; the session token is not.
    if not (key_id and secret):
        raise ValueError("AWS credentials are required")

    signer = AWSTranscribePresignedURL(
        access_key=key_id,
        secret_key=secret,
        session_token=token,
        region=region,
    )

    request_options = {
        "sample_rate": sample_rate,
        "language_code": language_code,
        "media_encoding": media_encoding,
        "vocabulary_name": vocabulary_name,
        "vocabulary_filter_name": vocabulary_filter_name,
        "show_speaker_label": show_speaker_label,
        "enable_channel_identification": enable_channel_identification,
        "number_of_channels": number_of_channels,
        "enable_partial_results_stabilization": enable_partial_results_stabilization,
        "partial_results_stability": partial_results_stability,
    }
    return signer.get_request_url(**request_options)
|
||||
|
||||
|
||||
class AWSTranscribePresignedURL:
    """Build SigV4 presigned WebSocket URLs for AWS Transcribe streaming.

    The intermediate signing artifacts (canonical query string, canonical
    request, string to sign, signature, final URL) are stored as instance
    attributes and overwritten on every ``get_request_url`` call.
    """

    def __init__(
        self,
        access_key: str,
        secret_key: str,
        session_token: Optional[str] = None,
        region: str = "us-east-1",
    ):
        """Store the credentials and reset all signing state.

        Args:
            access_key: AWS access key id.
            secret_key: AWS secret access key.
            session_token: Optional session token; when falsy no
                ``X-Amz-Security-Token`` parameter is added to the URL.
            region: AWS region used for the endpoint and credential scope.
        """
        self.access_key = access_key
        self.secret_key = secret_key
        self.session_token = session_token
        self.method = "GET"
        self.service = "transcribe"
        self.region = region
        self.endpoint = ""
        self.host = ""
        self.amz_date = ""
        self.datestamp = ""
        self.canonical_uri = "/stream-transcription-websocket"
        self.canonical_headers = ""
        self.signed_headers = "host"
        self.algorithm = "AWS4-HMAC-SHA256"
        self.credential_scope = ""
        self.canonical_querystring = ""
        self.payload_hash = ""
        self.canonical_request = ""
        self.string_to_sign = ""
        self.signature = ""
        self.request_url = ""

    @staticmethod
    def _hmac_sha256(key: bytes, message: str) -> bytes:
        """Return the raw HMAC-SHA256 digest of ``message`` under ``key``."""
        return hmac.new(key, message.encode("utf-8"), hashlib.sha256).digest()

    @staticmethod
    def _encode(value) -> str:
        """URI-encode a query value, leaving only unreserved characters as-is."""
        return urllib.parse.quote(str(value), safe="")

    def get_request_url(
        self,
        sample_rate: int,
        language_code: str = "",
        media_encoding: str = "pcm",
        vocabulary_name: Optional[str] = "",
        vocabulary_filter_name: Optional[str] = "",
        show_speaker_label: bool = False,
        enable_channel_identification: bool = False,
        number_of_channels: int = 1,
        enable_partial_results_stabilization: bool = False,
        partial_results_stability: str = "",
    ) -> str:
        """Create and sign the streaming request URL.

        Optional parameters are only added to the query string when truthy
        (``number_of_channels`` only when greater than 1). The URL is signed
        with ``X-Amz-Expires=300``.

        Args:
            sample_rate: Value of the ``sample-rate`` parameter.
            language_code: Value of the ``language-code`` parameter.
            media_encoding: Value of the ``media-encoding`` parameter.
            vocabulary_name: Value of the ``vocabulary-name`` parameter.
            vocabulary_filter_name: Value of ``vocabulary-filter-name``.
            show_speaker_label: Adds ``show-speaker-label=true`` when set.
            enable_channel_identification: Adds
                ``enable-channel-identification=true`` when set.
            number_of_channels: Value of ``number-of-channels``.
            enable_partial_results_stabilization: Adds
                ``enable-partial-results-stabilization=true`` when set.
            partial_results_stability: Value of ``partial-results-stability``.

        Returns:
            The complete ``wss://`` URL including the ``X-Amz-Signature``.
        """
        self.endpoint = f"wss://transcribestreaming.{self.region}.amazonaws.com:8443"
        self.host = f"transcribestreaming.{self.region}.amazonaws.com:8443"

        # Timezone-aware replacement for the deprecated datetime.utcnow();
        # the formatted strings are identical.
        now = datetime.datetime.now(datetime.timezone.utc)
        self.amz_date = now.strftime("%Y%m%dT%H%M%SZ")
        self.datestamp = now.strftime("%Y%m%d")
        self.canonical_headers = f"host:{self.host}\n"
        # URL-encoded form ("/" -> "%2F") because it is embedded in the query.
        self.credential_scope = f"{self.datestamp}%2F{self.region}%2F{self.service}%2Faws4_request"

        # Canonical query string. The parameters are appended in an order that
        # is already sorted by name, as the canonical request requires; keep
        # that order when adding new ones.
        params = [
            ("X-Amz-Algorithm", self.algorithm),
            ("X-Amz-Credential", self.access_key + "%2F" + self.credential_scope),
            ("X-Amz-Date", self.amz_date),
            ("X-Amz-Expires", "300"),
        ]
        if self.session_token:
            params.append(("X-Amz-Security-Token", self._encode(self.session_token)))
        params.append(("X-Amz-SignedHeaders", self.signed_headers))

        if enable_channel_identification:
            params.append(("enable-channel-identification", "true"))
        if enable_partial_results_stabilization:
            params.append(("enable-partial-results-stabilization", "true"))
        if language_code:
            params.append(("language-code", self._encode(language_code)))
        if media_encoding:
            params.append(("media-encoding", self._encode(media_encoding)))
        if number_of_channels > 1:
            params.append(("number-of-channels", self._encode(number_of_channels)))
        if partial_results_stability:
            params.append(("partial-results-stability", self._encode(partial_results_stability)))
        if sample_rate:
            params.append(("sample-rate", self._encode(sample_rate)))
        if show_speaker_label:
            params.append(("show-speaker-label", "true"))
        if vocabulary_filter_name:
            params.append(("vocabulary-filter-name", self._encode(vocabulary_filter_name)))
        if vocabulary_name:
            params.append(("vocabulary-name", self._encode(vocabulary_name)))

        self.canonical_querystring = "&".join(f"{name}={value}" for name, value in params)

        # The request has no body, so the payload hash is that of "".
        self.payload_hash = hashlib.sha256("".encode("utf-8")).hexdigest()

        self.canonical_request = "\n".join(
            (
                self.method,
                self.canonical_uri,
                self.canonical_querystring,
                self.canonical_headers,  # ends with "\n" -> yields the blank line
                self.signed_headers,
                self.payload_hash,
            )
        )

        # The string to sign uses the un-encoded ("/"-separated) scope.
        plain_scope = f"{self.datestamp}/{self.region}/{self.service}/aws4_request"
        self.string_to_sign = (
            f"{self.algorithm}\n{self.amz_date}\n{plain_scope}\n"
            + hashlib.sha256(self.canonical_request.encode("utf-8")).hexdigest()
        )

        # Derive the signing key: date -> region -> service -> "aws4_request".
        signing_key = f"AWS4{self.secret_key}".encode("utf-8")
        for scope_part in (self.datestamp, self.region, self.service, "aws4_request"):
            signing_key = self._hmac_sha256(signing_key, scope_part)

        self.signature = hmac.new(
            signing_key, self.string_to_sign.encode("utf-8"), hashlib.sha256
        ).hexdigest()

        # The signature is appended last and is not part of what was signed.
        self.canonical_querystring += "&X-Amz-Signature=" + self.signature

        self.request_url = self.endpoint + self.canonical_uri + "?" + self.canonical_querystring
        return self.request_url
|
||||
|
||||
|
||||
def get_headers(header_name: str, header_value: str) -> bytearray:
    """Serialize one string-valued header in AWS event stream format.

    Layout of the returned bytes: 1-byte name length, UTF-8 name, 1-byte value
    type (7 = string), 2-byte big-endian value length, UTF-8 value.

    Args:
        header_name: Header name, e.g. ``":message-type"``.
        header_value: Header value, always encoded as a string.

    Returns:
        The encoded header as a ``bytearray``.
    """
    encoded_name = header_name.encode("utf-8")
    # Name length occupies a single byte.
    header = bytearray(bytes([len(encoded_name)]))
    header += encoded_name

    # Value type marker: 7 represents a string.
    header += bytes([7])

    encoded_value = header_value.encode("utf-8")
    header += struct.pack(">H", len(encoded_value))
    header += encoded_value
    return header
|
||||
|
||||
|
||||
def build_event_message(payload: bytes) -> bytes:
    """Wrap an audio chunk in an AWS event stream ``AudioEvent`` message.

    Matches AWS sample: https://github.com/aws-samples/amazon-transcribe-streaming-python-websockets/blob/main/eventstream.py

    Message layout: 8-byte prelude (total length, headers length), 4-byte
    prelude CRC32, headers, payload, 4-byte CRC32 of everything before it.

    Args:
        payload: Raw audio bytes to send.

    Returns:
        The fully framed message.
    """
    # The three headers every audio event carries, in wire order.
    headers = bytearray()
    for header_name, header_value in (
        (":content-type", "application/octet-stream"),
        (":event-type", "AudioEvent"),
        (":message-type", "event"),
    ):
        headers.extend(get_headers(header_name, header_value))

    # 16 = 8-byte prelude + 4-byte prelude CRC + 4-byte message CRC.
    frame = bytearray(struct.pack(">II", len(headers) + len(payload) + 16, len(headers)))

    # At this point `frame` holds only the prelude, so this is the prelude CRC.
    frame.extend(struct.pack(">I", binascii.crc32(frame) & 0xFFFFFFFF))

    frame.extend(headers)
    frame.extend(payload)

    # The trailing CRC covers prelude, prelude CRC, headers and payload.
    frame.extend(struct.pack(">I", binascii.crc32(frame) & 0xFFFFFFFF))

    return bytes(frame)
|
||||
|
||||
|
||||
def decode_event(message):
    """Decode one AWS event stream message.

    Message layout: 8-byte prelude (total length, headers length), 4-byte
    prelude CRC32, headers, payload, 4-byte CRC32 of everything before it.

    Args:
        message: The raw bytes of a single event stream message.

    Returns:
        A ``(headers, payload)`` tuple: ``headers`` maps header names to their
        decoded string values and ``payload`` is the JSON-decoded body.

    Raises:
        ValueError: If the message is too short to be framed correctly or one
            of the CRC checks fails. ``json.JSONDecodeError`` (a ``ValueError``
            subclass) is raised when the payload is not valid JSON.
    """
    # 8-byte prelude + 4-byte prelude CRC + 4-byte message CRC.
    if len(message) < 16:
        raise ValueError("Message too short to be a valid event stream message")

    prelude = message[:8]
    _total_length, headers_length = struct.unpack(">II", prelude)
    (prelude_crc,) = struct.unpack(">I", message[8:12])
    (message_crc,) = struct.unpack(">I", message[-4:])

    # Explicit raises instead of `assert`, so the integrity checks are not
    # stripped when Python runs with -O.
    if prelude_crc != binascii.crc32(prelude) & 0xFFFFFFFF:
        raise ValueError("Prelude CRC check failed")
    if message_crc != binascii.crc32(message[:-4]) & 0xFFFFFFFF:
        raise ValueError("Message CRC check failed")

    headers = message[12 : 12 + headers_length]
    payload = message[12 + headers_length : -4]

    # Each header: 1-byte name length, name, 1-byte value type, 2-byte
    # big-endian value length, value.
    # NOTE(review): the value type byte is skipped and every value is parsed
    # as a length-prefixed UTF-8 string - confirm the service never sends
    # other header types.
    headers_dict = {}
    offset = 0
    while offset < len(headers):
        name_len = headers[offset]
        name_end = offset + 1 + name_len
        name = headers[offset + 1 : name_end].decode("utf-8")
        (value_len,) = struct.unpack_from(">H", headers, name_end + 1)
        value_start = name_end + 3
        headers_dict[name] = headers[value_start : value_start + value_len].decode("utf-8")
        offset = value_start + value_len

    return headers_dict, json.loads(payload)
|
||||
@@ -1 +1 @@
|
||||
-e ".[anthropic,google,langchain]"
|
||||
-e ".[anthropic,aws,google,langchain]"
|
||||
|
||||
@@ -40,6 +40,11 @@ from pipecat.services.anthropic.llm import (
|
||||
AnthropicLLMContext,
|
||||
AnthropicUserContextAggregator,
|
||||
)
|
||||
from pipecat.services.aws.llm import (
|
||||
AWSBedrockAssistantContextAggregator,
|
||||
AWSBedrockLLMContext,
|
||||
AWSBedrockUserContextAggregator,
|
||||
)
|
||||
from pipecat.services.google.llm import (
|
||||
GoogleAssistantContextAggregator,
|
||||
GoogleLLMContext,
|
||||
@@ -669,26 +674,6 @@ class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.Isola
|
||||
AGGREGATOR_CLASS = LLMUserContextAggregator
|
||||
|
||||
|
||||
#
|
||||
# OpenAI
|
||||
#
|
||||
|
||||
|
||||
class TestOpenAIUserContextAggregator(
|
||||
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = OpenAILLMContext
|
||||
AGGREGATOR_CLASS = OpenAIUserContextAggregator
|
||||
|
||||
|
||||
class TestOpenAIAssistantContextAggregator(
|
||||
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = OpenAILLMContext
|
||||
AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
|
||||
|
||||
#
|
||||
# Anthropic
|
||||
#
|
||||
@@ -724,6 +709,43 @@ class TestAnthropicAssistantContextAggregator(
|
||||
assert context.messages[index]["content"][0]["content"] == json.dumps(content)
|
||||
|
||||
|
||||
#
|
||||
# AWS (Bedrock)
|
||||
#
|
||||
|
||||
|
||||
class TestAWSBedrockUserContextAggregator(
    BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
    """Binds the AWS Bedrock context and user aggregator to the base test case."""

    CONTEXT_CLASS = AWSBedrockLLMContext
    AGGREGATOR_CLASS = AWSBedrockUserContextAggregator

    def check_message_multi_content(
        self, context: OpenAILLMContext, content_index: int, index: int, content: str
    ):
        """Assert that content part ``index`` of message ``content_index`` has text ``content``."""
        content_part = context.messages[content_index]["content"][index]
        assert content_part["text"] == content
|
||||
|
||||
|
||||
class TestAWSBedrockAssistantContextAggregator(
    BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
):
    """Binds the AWS Bedrock context and assistant aggregator to the base test case."""

    CONTEXT_CLASS = AWSBedrockLLMContext
    AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator
    EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]

    def check_message_multi_content(
        self, context: OpenAILLMContext, content_index: int, index: int, content: str
    ):
        """Assert that content part ``index`` of message ``content_index`` has text ``content``."""
        content_part = context.messages[content_index]["content"][index]
        assert content_part["text"] == content

    def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
        """Assert that message ``index`` carries ``content`` as a JSON-encoded tool result."""
        tool_result = context.messages[index]["content"][0]["toolResult"]
        assert tool_result["content"][0]["text"] == json.dumps(content)
|
||||
|
||||
|
||||
#
|
||||
# Google
|
||||
#
|
||||
@@ -766,3 +788,23 @@ class TestGoogleAssistantContextAggregator(
|
||||
def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any):
|
||||
obj = glm.Content.to_dict(context.messages[index])
|
||||
assert obj["parts"][0]["function_response"]["response"]["value"] == json.dumps(content)
|
||||
|
||||
|
||||
#
|
||||
# OpenAI
|
||||
#
|
||||
|
||||
|
||||
class TestOpenAIUserContextAggregator(
    BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
    """Binds the OpenAI context and user aggregator to the base test case."""

    # Classes the inherited test methods are expected to instantiate.
    CONTEXT_CLASS = OpenAILLMContext
    AGGREGATOR_CLASS = OpenAIUserContextAggregator
|
||||
|
||||
|
||||
class TestOpenAIAssistantContextAggregator(
    BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
):
    """Binds the OpenAI context and assistant aggregator to the base test case."""

    # Classes the inherited test methods are expected to instantiate.
    CONTEXT_CLASS = OpenAILLMContext
    AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
    # Frame types listed for the base test case to compare against.
    EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
|
||||
@@ -11,6 +11,7 @@ from openai.types.chat import ChatCompletionToolParam
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
|
||||
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
|
||||
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
|
||||
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
|
||||
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
|
||||
@@ -174,3 +175,32 @@ class TestFunctionAdapters(unittest.TestCase):
|
||||
tools_def = self.tools_def
|
||||
tools_def.custom_tools = {AdapterType.GEMINI: [search_tool]}
|
||||
assert GeminiLLMAdapter().to_provider_tools_format(tools_def) == expected
|
||||
|
||||
def test_bedrock_adapter(self):
    """Test AWS Bedrock adapter format transformation."""
    # JSON schema of the tool's input, built separately for readability.
    input_properties = {
        "format": {
            "type": "string",
            "enum": ["celsius", "fahrenheit"],
            "description": "The temperature unit to use.",
        },
        "location": {
            "type": "string",
            "description": "The city, e.g. San Francisco",
        },
    }
    input_schema = {
        "type": "object",
        "properties": input_properties,
        "required": ["location", "format"],
    }

    # Expected output: one entry wrapping the schema under "toolSpec".
    expected = [
        {
            "toolSpec": {
                "name": "get_weather",
                "description": "Get the weather in a given location",
                "inputSchema": {"json": input_schema},
            }
        }
    ]

    actual = AWSBedrockLLMAdapter().to_provider_tools_format(self.tools_def)
    assert actual == expected
|
||||
|
||||
Reference in New Issue
Block a user