Compare commits

...

13 Commits

Author SHA1 Message Date
Adithya Suresh
21f9054af3 Added cache related info to metrics 2025-05-06 11:25:36 -07:00
Adithya Suresh
40670287af System param to be a list 2025-05-06 11:25:36 -07:00
Adithya Suresh
fc3f7cec51 Bug fix in content format 2025-05-06 11:25:36 -07:00
Adithya Suresh
1bfc53dd39 Remove model restriction 2025-05-06 11:25:36 -07:00
Adithya Suresh
6d6bffcabd Revert "Handle cache token counts being none"
This reverts commit 87934be342e438cc270ba04150d0189a870164f1.
2025-05-06 11:25:36 -07:00
Adithya Suresh
efcbe51a85 Revert "Add debug log"
This reverts commit 832b7b970e48e6488e1901d64d2ccfcc4a78c138.
2025-05-06 11:25:36 -07:00
Adithya Suresh
3f0000c82b Restructured STT and enabled prosody tags for generative Polly 2025-05-06 11:25:31 -07:00
Adithya Suresh
67c70f5588 Removed OpenAI based context formatting 2025-05-06 11:20:39 -07:00
Adithya Suresh
81f7cbfd09 Updated tools parsing logic 2025-05-06 11:20:39 -07:00
Adithya Suresh
3d1424d3cf Initial implementation 2025-05-06 11:20:37 -07:00
Adithya Suresh
5dbbb9021a Add debug log 2025-05-06 11:20:22 -07:00
Tico Ballagas
7cd8aaa46a Change example to use generative voices 2025-05-06 11:20:16 -07:00
Tico Ballagas
0b8f5ca6a5 Initial implementation of AWS Transcribe TTS 2025-05-06 11:19:25 -07:00
6 changed files with 1516 additions and 53 deletions

View File

@@ -5,7 +5,6 @@
#
import argparse
import os
from dotenv import load_dotenv
from loguru import logger
@@ -14,13 +13,13 @@ from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.aws.tts import PollyTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transcriptions.language import Language
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
from pipecat.services.aws.llm import BedrockLLMService, BedrockLLMContext
from pipecat.services.aws.stt import TranscribeSTTService
from pipecat.services.aws.tts import PollyTTSService
load_dotenv(override=True)
@@ -37,26 +36,36 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
),
)
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
stt = TranscribeSTTService()
tts = PollyTTSService(
api_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
region=os.getenv("AWS_REGION"),
voice_id="Amy",
params=PollyTTSService.InputParams(engine="neural", language="en-GB", rate="1.05"),
region="us-west-2", # only specific regions support generative TTS
voice_id="Joanna",
params=PollyTTSService.InputParams(
engine="generative",
language=Language.EN_US,
rate="1.1"
),
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
llm = BedrockLLMService(
aws_region="us-west-2",
model="us.anthropic.claude-3-5-haiku-20241022-v1:0",
params=BedrockLLMService.InputParams(
temperature=0.8,
latency="optimized"
)
)
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
{
"role": "system",
"content": [{"text": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way."}],
},
]
)
context = OpenAILLMContext(messages)
context = BedrockLLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
@@ -68,8 +77,8 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)
]
)
task = PipelineTask(
pipeline,
@@ -85,16 +94,12 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
messages.append({"role": "user", "content": [{"text": "Please introduce yourself to the user."}]})
await task.queue_frames([context_aggregator.user().get_context_frame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
@transport.event_handler("on_client_closed")
async def on_client_closed(transport, client):
logger.info(f"Client closed connection")
await task.cancel()
runner = PipelineRunner(handle_sigint=False)

View File

@@ -0,0 +1,38 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import Any, Dict, List, Union
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
class BedrockLLMAdapter(BaseLLMAdapter):
@staticmethod
def _to_bedrock_function_format(function: FunctionSchema) -> Dict[str, Any]:
return {
"toolSpec": {
"name": function.name,
"description": function.description,
"inputSchema": {
"json": {
"type": "object",
"properties": function.properties,
"required": function.required,
},
}
}
}
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]:
"""Converts function schemas to Bedrock's function-calling format.
:return: Bedrock formatted function call definition.
"""
functions_schema = tools_schema.standard_tools
return [self._to_bedrock_function_format(func) for func in functions_schema]

View File

@@ -250,24 +250,14 @@ class AnthropicLLMService(LLMService):
if hasattr(event.message.usage, "output_tokens")
else 0
)
cache_creation_input_tokens += (
event.message.usage.cache_creation_input_tokens
if (
hasattr(event.message.usage, "cache_creation_input_tokens")
and event.message.usage.cache_creation_input_tokens is not None
if hasattr(event.message.usage, "cache_creation_input_tokens"):
cache_creation_input_tokens += (
event.message.usage.cache_creation_input_tokens
)
else 0
)
logger.debug(f"Cache creation input tokens: {cache_creation_input_tokens}")
cache_read_input_tokens += (
event.message.usage.cache_read_input_tokens
if (
hasattr(event.message.usage, "cache_read_input_tokens")
and event.message.usage.cache_read_input_tokens is not None
)
else 0
)
logger.debug(f"Cache read input tokens: {cache_read_input_tokens}")
logger.debug(f"Cache creation input tokens: {cache_creation_input_tokens}")
if hasattr(event.message.usage, "cache_read_input_tokens"):
cache_read_input_tokens += event.message.usage.cache_read_input_tokens
logger.debug(f"Cache read input tokens: {cache_read_input_tokens}")
total_input_tokens = (
prompt_tokens + cache_creation_input_tokens + cache_read_input_tokens
)

View File

@@ -0,0 +1,796 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import base64
import copy
import io
import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional, Union
import boto3
from botocore.config import Config
import httpx
from loguru import logger
from PIL import Image
from pydantic import BaseModel, Field
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
from pipecat.frames.frames import (
Frame,
FunctionCallCancelFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMTextFrame,
LLMUpdateSettingsFrame,
UserImageRawFrame,
VisionImageRawFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_response import (
LLMAssistantContextAggregator,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService
@dataclass
class BedrockContextAggregatorPair:
_user: "BedrockUserContextAggregator"
_assistant: "BedrockAssistantContextAggregator"
def user(self) -> "BedrockUserContextAggregator":
return self._user
def assistant(self) -> "BedrockAssistantContextAggregator":
return self._assistant
class BedrockLLMContext(OpenAILLMContext):
def __init__(
self,
messages: Optional[List[dict]] = None,
tools: Optional[List[dict]] = None,
tool_choice: Optional[dict] = None,
*,
system: Optional[str] = None,
):
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
self.system = system
@staticmethod
def upgrade_to_bedrock(obj: OpenAILLMContext) -> "BedrockLLMContext":
logger.debug(f"Upgrading to Bedrock: {obj}")
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, BedrockLLMContext):
obj.__class__ = BedrockLLMContext
obj._restructure_from_openai_messages()
else:
obj._restructure_from_bedrock_messages()
return obj
@classmethod
def from_openai_context(cls, openai_context: OpenAILLMContext):
logger.debug("from_openai_context called")
self = cls(
messages=openai_context.messages,
tools=openai_context.tools,
tool_choice=openai_context.tool_choice,
)
self.set_llm_adapter(openai_context.get_llm_adapter())
self._restructure_from_openai_messages()
return self
@classmethod
def from_messages(cls, messages: List[dict]) -> "BedrockLLMContext":
self = cls(messages=messages)
# self._restructure_from_openai_messages()
return self
@classmethod
def from_image_frame(cls, frame: VisionImageRawFrame) -> "BedrockLLMContext":
context = cls()
context.add_image_frame_message(
format=frame.format, size=frame.size, image=frame.image, text=frame.text
)
return context
def set_messages(self, messages: List):
self._messages[:] = messages
# self._restructure_from_openai_messages()
# convert a message in Bedrock format into one or more messages in OpenAI format
def to_standard_messages(self, obj):
"""Convert Bedrock message format to standard structured format.
Handles text content and function calls for both user and assistant messages.
Args:
obj: Message in Bedrock format:
{
"role": "user/assistant",
"content": [{"text": str} | {"toolUse": {...}} | {"toolResult": {...}}]
}
Returns:
List of messages in standard format:
[
{
"role": "user/assistant/tool",
"content": [{"type": "text", "text": str}]
}
]
"""
role = obj.get("role")
content = obj.get("content")
if role == "assistant":
if isinstance(content, str):
return [{"role": role, "content": [{"type": "text", "text": content}]}]
elif isinstance(content, list):
text_items = []
tool_items = []
for item in content:
if "text" in item:
text_items.append({"type": "text", "text": item["text"]})
elif "toolUse" in item:
tool_use = item["toolUse"]
tool_items.append(
{
"type": "function",
"id": tool_use["toolUseId"],
"function": {
"name": tool_use["name"],
"arguments": json.dumps(tool_use["input"]),
},
}
)
messages = []
if text_items:
messages.append({"role": role, "content": text_items})
if tool_items:
messages.append({"role": role, "tool_calls": tool_items})
return messages
elif role == "user":
if isinstance(content, str):
return [{"role": role, "content": [{"type": "text", "text": content}]}]
elif isinstance(content, list):
text_items = []
tool_items = []
for item in content:
if "text" in item:
text_items.append({"type": "text", "text": item["text"]})
elif "toolResult" in item:
tool_result = item["toolResult"]
# Extract content from toolResult
result_content = ""
if isinstance(tool_result["content"], list):
for content_item in tool_result["content"]:
if "text" in content_item:
result_content = content_item["text"]
elif "json" in content_item:
result_content = json.dumps(content_item["json"])
else:
result_content = tool_result["content"]
tool_items.append(
{
"role": "tool",
"tool_call_id": tool_result["toolUseId"],
"content": result_content,
}
)
messages = []
if text_items:
messages.append({"role": role, "content": text_items})
messages.extend(tool_items)
return messages
def from_standard_message(self, message):
"""Convert standard format message to Bedrock format.
Handles conversion of text content, tool calls, and tool results.
Empty text content is converted to "(empty)".
Args:
message: Message in standard format:
{
"role": "user/assistant/tool",
"content": str | [{"type": "text", ...}],
"tool_calls": [{"id": str, "function": {"name": str, "arguments": str}}]
}
Returns:
Message in Bedrock format:
{
"role": "user/assistant",
"content": [
{"text": str} |
{"toolUse": {"toolUseId": str, "name": str, "input": dict}} |
{"toolResult": {"toolUseId": str, "content": [...], "status": str}}
]
}
"""
if message["role"] == "tool":
# Try to parse the content as JSON if it looks like JSON
try:
if message["content"].strip().startswith('{') and message["content"].strip().endswith('}'):
content_json = json.loads(message["content"])
tool_result_content = [{"json": content_json}]
else:
tool_result_content = [{"text": message["content"]}]
except:
tool_result_content = [{"text": message["content"]}]
return {
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": message["tool_call_id"],
"content": tool_result_content
},
},
],
}
if message.get("tool_calls"):
tc = message["tool_calls"]
ret = {"role": "assistant", "content": []}
for tool_call in tc:
function = tool_call["function"]
arguments = json.loads(function["arguments"])
new_tool_use = {
"toolUse": {
"toolUseId": tool_call["id"],
"name": function["name"],
"input": arguments,
}
}
ret["content"].append(new_tool_use)
return ret
# Handle text content
content = message.get("content")
if isinstance(content, str):
if content == "":
return {"role": message["role"], "content": [{"text": "(empty)"}]}
else:
return {"role": message["role"], "content": [{"text": content}]}
elif isinstance(content, list):
new_content = []
for item in content:
if item.get("type", "") == "text":
text_content = item["text"] if item["text"] != "" else "(empty)"
new_content.append({"text": text_content})
return {"role": message["role"], "content": new_content}
return message
def add_image_frame_message(
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
):
buffer = io.BytesIO()
Image.frombytes(format, size, image).save(buffer, format="JPEG")
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
# Image should be the first content block in the message
content = [
{
"type": "image",
"format": "jpeg",
"source": {
"bytes": encoded_image
}
}
]
if text:
content.append({"text": text})
self.add_message({"role": "user", "content": content})
def add_message(self, message):
try:
if self.messages:
# Bedrock requires that roles alternate. If this message's role is the same as the
# last message, we should add this message's content to the last message.
if self.messages[-1]["role"] == message["role"]:
# if the last message has just a content string, convert it to a list
# in the proper format
if isinstance(self.messages[-1]["content"], str):
self.messages[-1]["content"] = [
{"text": self.messages[-1]["content"]}
]
# if this message has just a content string, convert it to a list
# in the proper format
if isinstance(message["content"], str):
message["content"] = [{"text": message["content"]}]
# append the content of this message to the last message
self.messages[-1]["content"].extend(message["content"])
else:
self.messages.append(message)
else:
self.messages.append(message)
except Exception as e:
logger.error(f"Error adding message: {e}")
def _restructure_from_bedrock_messages(self):
"""Restructure messages in Bedrock format by handling system messages,
merging consecutive messages with the same role, and ensuring proper content formatting.
"""
# Handle system message if present at the beginning
logger.debug(f"_restructure_from_bedrock_messages: {self.messages}")
if self.messages and self.messages[0]["role"] == "system":
if len(self.messages) == 1:
self.messages[0]["role"] = "user"
else:
system_content = self.messages.pop(0)["content"]
if isinstance(system_content, str):
system_content = [{"text": system_content}]
if self.system:
if isinstance(self.system, str):
self.system = [{"text": self.system}]
self.system.extend(system_content)
else:
self.system = system_content
# Ensure content is properly formatted
for msg in self.messages:
if isinstance(msg["content"], str):
msg["content"] = [{"text": msg["content"]}]
elif not msg["content"]:
msg["content"] = [{"text": "(empty)"}]
elif isinstance(msg["content"], list):
for idx, item in enumerate(msg["content"]):
if isinstance(item, dict) and "text" in item and item["text"] == "":
item["text"] = "(empty)"
elif isinstance(item, str) and item == "":
msg["content"][idx] = {"text": "(empty)"}
# Merge consecutive messages with the same role
merged_messages = []
for msg in self.messages:
if merged_messages and merged_messages[-1]["role"] == msg["role"]:
merged_messages[-1]["content"].extend(msg["content"])
else:
merged_messages.append(msg)
self.messages.clear()
self.messages.extend(merged_messages)
def _restructure_from_openai_messages(self):
logger.debug(f"_restructure_from_openai_messages: {self.messages}")
# first, map across self._messages calling self.from_standard_message(m) to modify messages in place
try:
self._messages[:] = [self.from_standard_message(m) for m in self._messages]
except Exception as e:
logger.error(f"Error mapping messages: {e}")
# See if we should pull the system message out of our context.messages list. (For
# compatibility with Open AI messages format.)
if self.messages and self.messages[0]["role"] == "system":
if len(self.messages) == 1:
# If we have only have a system message in the list, all we can really do
# without introducing too much magic is change the role to "user".
self.messages[0]["role"] = "user"
else:
# If we have more than one message, we'll pull the system message out of the
# list.
self.system = self.messages[0]["content"]
self.messages.pop(0)
# Merge consecutive messages with the same role.
i = 0
while i < len(self.messages) - 1:
current_message = self.messages[i]
next_message = self.messages[i + 1]
if current_message["role"] == next_message["role"]:
# Convert content to list of dictionaries if it's a string
if isinstance(current_message["content"], str):
current_message["content"] = [
{"type": "text", "text": current_message["content"]}
]
if isinstance(next_message["content"], str):
next_message["content"] = [{"type": "text", "text": next_message["content"]}]
# Concatenate the content
current_message["content"].extend(next_message["content"])
# Remove the next message from the list
self.messages.pop(i + 1)
else:
i += 1
# Avoid empty content in messages
for message in self.messages:
if isinstance(message["content"], str) and message["content"] == "":
message["content"] = "(empty)"
elif isinstance(message["content"], list) and len(message["content"]) == 0:
message["content"] = [{"type": "text", "text": "(empty)"}]
def get_messages_for_persistent_storage(self):
messages = super().get_messages_for_persistent_storage()
if self.system:
messages.insert(0, {"role": "system", "content": self.system})
return messages
def get_messages_for_logging(self) -> str:
msgs = []
for message in self.messages:
msg = copy.deepcopy(message)
if "content" in msg:
if isinstance(msg["content"], list):
for item in msg["content"]:
if item.get("image"):
item["source"]["bytes"] = "..."
msgs.append(msg)
return json.dumps(msgs)
class BedrockUserContextAggregator(LLMUserContextAggregator):
pass
class BedrockAssistantContextAggregator(LLMAssistantContextAggregator):
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
# Format tool use according to Bedrock API
self._context.add_message(
{
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": frame.tool_call_id,
"name": frame.function_name,
"input": frame.arguments if frame.arguments else {}
}
}
],
}
)
self._context.add_message(
{
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": frame.tool_call_id,
"content": [
{
"text": "IN_PROGRESS"
}
],
}
}
],
}
)
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
if frame.result:
result = json.dumps(frame.result)
await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
else:
await self._update_function_call_result(
frame.function_name, frame.tool_call_id, "COMPLETED"
)
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
await self._update_function_call_result(
frame.function_name, frame.tool_call_id, "CANCELLED"
)
async def _update_function_call_result(
self, function_name: str, tool_call_id: str, result: Any
):
for message in self._context.messages:
if message["role"] == "user":
for content in message["content"]:
if (
isinstance(content, dict)
and content.get("toolResult")
and content["toolResult"]["toolUseId"] == tool_call_id
):
content["toolResult"]["content"] = [{"text": result}]
async def handle_user_image_frame(self, frame: UserImageRawFrame):
await self._update_function_call_result(
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
)
self._context.add_image_frame_message(
format=frame.format,
size=frame.size,
image=frame.image,
text=frame.request.context,
)
class BedrockLLMService(LLMService):
"""This class implements inference with AWS Bedrock models including Amazon Nova and Anthropic Claude.
Requires AWS credentials to be configured in the environment or through boto3 configuration.
"""
class InputParams(BaseModel):
max_tokens: Optional[int] = Field(default_factory=lambda: 4096, ge=1)
temperature: Optional[float] = Field(default_factory=lambda: 0.7, ge=0.0, le=1.0)
top_p: Optional[float] = Field(default_factory=lambda: 0.999, ge=0.0, le=1.0)
stop_sequences: Optional[List[str]] = Field(default_factory=lambda: [])
latency: Optional[str] = Field(default_factory=lambda: "standard")
additional_model_request_fields: Optional[Dict[str, Any]] = Field(default_factory=dict)
def __init__(
self,
*,
aws_access_key: Optional[str] = None,
aws_secret_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region: str = "us-east-1",
model: str,
params: InputParams = InputParams(),
client_config: Optional[Config] = None,
**kwargs,
):
super().__init__(**kwargs)
# Initialize the Bedrock client
if not client_config:
client_config = Config(
connect_timeout=300, # 5 minutes
read_timeout=300, # 5 minutes
retries={'max_attempts': 3}
)
session = boto3.Session(
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key,
aws_session_token=aws_session_token,
region_name=aws_region
)
self._client = session.client(
service_name='bedrock-runtime',
config=client_config
)
self.set_model_name(model)
self._settings = {
"max_tokens": params.max_tokens,
"temperature": params.temperature,
"top_p": params.top_p,
"latency": params.latency,
"additional_model_request_fields": params.additional_model_request_fields if isinstance(params.additional_model_request_fields, dict) else {},
}
logger.info(f"Using AWS Bedrock model: {model}")
def can_generate_metrics(self) -> bool:
return True
def create_context_aggregator(
self,
context: BedrockLLMContext,
*,
user_kwargs: Mapping[str, Any] = {},
assistant_kwargs: Mapping[str, Any] = {},
) -> BedrockContextAggregatorPair:
"""Create an instance of BedrockContextAggregatorPair from an
OpenAILLMContext. Constructor keyword arguments for both the user and
assistant aggregators can be provided.
Args:
context (OpenAILLMContext): The LLM context.
user_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the user context aggregator constructor. Defaults
to an empty mapping.
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
arguments for the assistant context aggregator
constructor. Defaults to an empty mapping.
Returns:
BedrockContextAggregatorPair: A pair of context aggregators, one
for the user and one for the assistant, encapsulated in an
BedrockContextAggregatorPair.
"""
context.set_llm_adapter(self.get_llm_adapter())
if isinstance(context, OpenAILLMContext) and not isinstance(context, BedrockLLMContext):
context = BedrockLLMContext.from_openai_context(context)
user = BedrockUserContextAggregator(context, **user_kwargs)
assistant = BedrockAssistantContextAggregator(context, **assistant_kwargs)
return BedrockContextAggregatorPair(_user=user, _assistant=assistant)
async def _process_context(self, context: BedrockLLMContext):
# Usage tracking
prompt_tokens = 0
completion_tokens = 0
completion_tokens_estimate = 0
cache_read_input_tokens = 0
cache_creation_input_tokens = 0
use_completion_tokens_estimate = False
try:
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()
# logger.debug(
# f"{self}: Generating chat with Bedrock model {self.model_name} | [{context.get_messages_for_logging()}]"
# )
await self.start_ttfb_metrics()
# Set up inference config
inference_config = {
"maxTokens": self._settings["max_tokens"],
"temperature": self._settings["temperature"],
"topP": self._settings["top_p"],
}
# Prepare request parameters
request_params = {
"modelId": self.model_name,
"messages": context.messages,
"inferenceConfig": inference_config,
"additionalModelRequestFields": self._settings["additional_model_request_fields"]
}
# Add system message
request_params["system"] = context.system
# Add tools if present
if context.tools:
tool_config = {
"tools": context.tools
}
# Add tool_choice if specified
if context.tool_choice:
if context.tool_choice == "auto":
tool_config["toolChoice"] = {"auto": {}}
elif context.tool_choice == "none":
# Skip adding toolChoice for "none"
pass
elif isinstance(context.tool_choice, dict) and "function" in context.tool_choice:
tool_config["toolChoice"] = {
"tool": {
"name": context.tool_choice["function"]["name"]
}
}
request_params["toolConfig"] = tool_config
# Add performance config if latency is specified
if self._settings["latency"] in ["standard", "optimized"]:
request_params["performanceConfig"] = {
"latency": self._settings["latency"]
}
logger.debug(f"Calling Bedrock model with: {request_params}")
# Call Bedrock with streaming
response = self._client.converse_stream(**request_params)
await self.stop_ttfb_metrics()
# Process the streaming response
tool_use_block = None
json_accumulator = ""
for event in response["stream"]:
# Handle text content
if "contentBlockDelta" in event:
delta = event["contentBlockDelta"]["delta"]
if "text" in delta:
await self.push_frame(LLMTextFrame(delta["text"]))
completion_tokens_estimate += self._estimate_tokens(delta["text"])
elif "toolUse" in delta and "input" in delta["toolUse"]:
# Handle partial JSON for tool use
json_accumulator += delta["toolUse"]["input"]
completion_tokens_estimate += self._estimate_tokens(delta["toolUse"]["input"])
# Handle tool use start
elif "contentBlockStart" in event:
content_block_start = event["contentBlockStart"]['start']
if "toolUse" in content_block_start:
tool_use_block = {
"id": content_block_start["toolUse"].get("toolUseId", ""),
"name": content_block_start["toolUse"].get("name", "")
}
json_accumulator = ""
# Handle message completion with tool use
elif "messageStop" in event and "stopReason" in event["messageStop"]:
if event["messageStop"]["stopReason"] == "tool_use" and tool_use_block:
try:
arguments = json.loads(json_accumulator) if json_accumulator else {}
await self.call_function(
context=context,
tool_call_id=tool_use_block["id"],
function_name=tool_use_block["name"],
arguments=arguments,
)
except json.JSONDecodeError:
logger.error(f"Failed to parse tool arguments: {json_accumulator}")
# Handle usage metrics if available
if "metadata" in event and "usage" in event["metadata"]:
usage = event["metadata"]["usage"]
prompt_tokens += usage.get("inputTokens", 0)
completion_tokens += usage.get("outputTokens", 0)
cache_read_input_tokens += usage.get("cacheReadInputTokens", 0)
cache_creation_input_tokens += usage.get("cacheWriteInputTokens", 0)
except asyncio.CancelledError:
# If we're interrupted, we won't get a complete usage report. So set our flag to use the
# token estimate. The reraise the exception so all the processors running in this task
# also get cancelled.
use_completion_tokens_estimate = True
raise
except httpx.TimeoutException:
await self._call_event_handler("on_completion_timeout")
except Exception as e:
logger.exception(f"{self} exception: {e}")
finally:
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
comp_tokens = (
completion_tokens
if not use_completion_tokens_estimate
else completion_tokens_estimate
)
await self._report_usage_metrics(
prompt_tokens=prompt_tokens,
completion_tokens=comp_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_creation_input_tokens=cache_creation_input_tokens
)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
context = None
if isinstance(frame, OpenAILLMContextFrame):
context = BedrockLLMContext.upgrade_to_bedrock(frame.context)
elif isinstance(frame, LLMMessagesFrame):
context = BedrockLLMContext.from_messages(frame.messages)
elif isinstance(frame, VisionImageRawFrame):
# This is only useful in very simple pipelines because it creates
# a new context. Generally we want a context manager to catch
# UserImageRawFrames coming through the pipeline and add them
# to the context.
context = BedrockLLMContext.from_image_frame(frame)
elif isinstance(frame, LLMUpdateSettingsFrame):
await self._update_settings(frame.settings)
else:
await self.push_frame(frame, direction)
if context:
await self._process_context(context)
def _estimate_tokens(self, text: str) -> int:
return int(len(re.split(r"[^\w]+", text)) * 1.3)
async def _report_usage_metrics(
self,
prompt_tokens: int,
completion_tokens: int,
cache_read_input_tokens: int,
cache_creation_input_tokens: int
):
if prompt_tokens or completion_tokens:
tokens = LLMTokenUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_creation_input_tokens=cache_creation_input_tokens
)
await self.start_llm_usage_metrics(tokens)

View File

@@ -0,0 +1,600 @@
import asyncio
from typing import AsyncGenerator, Optional, Dict
import os
import datetime
from urllib.parse import urlencode
import json
import struct
import urllib.parse
import hashlib
import hmac
import random
import string
import binascii
from loguru import logger
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TranscriptionFrame,
InterimTranscriptionFrame,
StartFrame
)
from pipecat.services.ai_services import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
try:
import boto3
from botocore.exceptions import BotoCoreError, ClientError
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use AWS services, you need to `pip install pipecat-ai[aws]`. Also, remember to set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable."
)
raise Exception(f"Missing module: {e}")
def get_presigned_url(
*,
region: str,
credentials: Dict[str, Optional[str]],
language_code: str,
media_encoding: str = "pcm",
sample_rate: int = 16000,
number_of_channels: int = 1,
enable_partial_results_stabilization: bool = True,
partial_results_stability: str = "high",
vocabulary_name: Optional[str] = None,
vocabulary_filter_name: Optional[str] = None,
show_speaker_label: bool = False,
enable_channel_identification: bool = False,
) -> str:
"""Create a presigned URL for AWS Transcribe streaming."""
access_key = credentials.get("access_key")
secret_key = credentials.get("secret_key")
session_token = credentials.get("session_token")
if not access_key or not secret_key:
raise ValueError("AWS credentials are required")
# Initialize the URL generator
url_generator = AWSTranscribePresignedURL(
access_key=access_key, secret_key=secret_key, session_token=session_token, region=region
)
# Get the presigned URL
return url_generator.get_request_url(
sample_rate=sample_rate,
language_code=language_code,
media_encoding=media_encoding,
vocabulary_name=vocabulary_name,
vocabulary_filter_name=vocabulary_filter_name,
show_speaker_label=show_speaker_label,
enable_channel_identification=enable_channel_identification,
number_of_channels=number_of_channels,
enable_partial_results_stabilization=enable_partial_results_stabilization,
partial_results_stability=partial_results_stability,
)
class AWSTranscribePresignedURL:
def __init__(
self, access_key: str, secret_key: str, session_token: str, region: str = "us-east-1"
):
self.access_key = access_key
self.secret_key = secret_key
self.session_token = session_token
self.method = "GET"
self.service = "transcribe"
self.region = region
self.endpoint = ""
self.host = ""
self.amz_date = ""
self.datestamp = ""
self.canonical_uri = "/stream-transcription-websocket"
self.canonical_headers = ""
self.signed_headers = "host"
self.algorithm = "AWS4-HMAC-SHA256"
self.credential_scope = ""
self.canonical_querystring = ""
self.payload_hash = ""
self.canonical_request = ""
self.string_to_sign = ""
self.signature = ""
self.request_url = ""
def get_request_url(
self,
sample_rate: int,
language_code: str = "",
media_encoding: str = "pcm",
vocabulary_name: str = "",
vocabulary_filter_name: str = "",
show_speaker_label: bool = False,
enable_channel_identification: bool = False,
number_of_channels: int = 1,
enable_partial_results_stabilization: bool = False,
partial_results_stability: str = "",
) -> str:
self.endpoint = f"wss://transcribestreaming.{self.region}.amazonaws.com:8443"
self.host = f"transcribestreaming.{self.region}.amazonaws.com:8443"
now = datetime.datetime.utcnow()
self.amz_date = now.strftime("%Y%m%dT%H%M%SZ")
self.datestamp = now.strftime("%Y%m%d")
self.canonical_headers = f"host:{self.host}\n"
self.credential_scope = f"{self.datestamp}%2F{self.region}%2F{self.service}%2Faws4_request"
# Create canonical querystring
self.canonical_querystring = "X-Amz-Algorithm=" + self.algorithm
self.canonical_querystring += (
"&X-Amz-Credential=" + self.access_key + "%2F" + self.credential_scope
)
self.canonical_querystring += "&X-Amz-Date=" + self.amz_date
self.canonical_querystring += "&X-Amz-Expires=300"
if self.session_token:
self.canonical_querystring += "&X-Amz-Security-Token=" + urllib.parse.quote(
self.session_token, safe=""
)
self.canonical_querystring += "&X-Amz-SignedHeaders=" + self.signed_headers
if enable_channel_identification:
self.canonical_querystring += "&enable-channel-identification=true"
if enable_partial_results_stabilization:
self.canonical_querystring += "&enable-partial-results-stabilization=true"
if language_code:
self.canonical_querystring += "&language-code=" + language_code
if media_encoding:
self.canonical_querystring += "&media-encoding=" + media_encoding
if number_of_channels > 1:
self.canonical_querystring += "&number-of-channels=" + str(number_of_channels)
if partial_results_stability:
self.canonical_querystring += "&partial-results-stability=" + partial_results_stability
if sample_rate:
self.canonical_querystring += "&sample-rate=" + str(sample_rate)
if show_speaker_label:
self.canonical_querystring += "&show-speaker-label=true"
if vocabulary_filter_name:
self.canonical_querystring += "&vocabulary-filter-name=" + vocabulary_filter_name
if vocabulary_name:
self.canonical_querystring += "&vocabulary-name=" + vocabulary_name
# Create payload hash
self.payload_hash = hashlib.sha256("".encode("utf-8")).hexdigest()
# Create canonical request
self.canonical_request = f"{self.method}\n{self.canonical_uri}\n{self.canonical_querystring}\n{self.canonical_headers}\n{self.signed_headers}\n{self.payload_hash}"
# Create string to sign
credential_scope = f"{self.datestamp}/{self.region}/{self.service}/aws4_request"
string_to_sign = (
f"{self.algorithm}\n{self.amz_date}\n{credential_scope}\n"
+ hashlib.sha256(self.canonical_request.encode("utf-8")).hexdigest()
)
# Calculate signature
k_date = hmac.new(
f"AWS4{self.secret_key}".encode("utf-8"), self.datestamp.encode("utf-8"), hashlib.sha256
).digest()
k_region = hmac.new(k_date, self.region.encode("utf-8"), hashlib.sha256).digest()
k_service = hmac.new(k_region, self.service.encode("utf-8"), hashlib.sha256).digest()
k_signing = hmac.new(k_service, b"aws4_request", hashlib.sha256).digest()
self.signature = hmac.new(
k_signing, string_to_sign.encode("utf-8"), hashlib.sha256
).hexdigest()
# Add signature to query string
self.canonical_querystring += "&X-Amz-Signature=" + self.signature
# Create request URL
self.request_url = self.endpoint + self.canonical_uri + "?" + self.canonical_querystring
return self.request_url
def get_headers(header_name: str, header_value: str) -> bytearray:
"""Build a header following AWS event stream format."""
name = header_name.encode("utf-8")
name_byte_length = bytes([len(name)])
value_type = bytes([7]) # 7 represents a string
value = header_value.encode("utf-8")
value_byte_length = struct.pack(">H", len(value))
# Construct the header
header_list = bytearray()
header_list.extend(name_byte_length)
header_list.extend(name)
header_list.extend(value_type)
header_list.extend(value_byte_length)
header_list.extend(value)
return header_list
def build_event_message(payload: bytes) -> bytes:
"""
Build an event message for AWS Transcribe streaming.
Matches AWS sample: https://github.com/aws-samples/amazon-transcribe-streaming-python-websockets/blob/main/eventstream.py
"""
# Build headers
content_type_header = get_headers(":content-type", "application/octet-stream")
event_type_header = get_headers(":event-type", "AudioEvent")
message_type_header = get_headers(":message-type", "event")
headers = bytearray()
headers.extend(content_type_header)
headers.extend(event_type_header)
headers.extend(message_type_header)
# Calculate total byte length and headers byte length
# 16 accounts for 8 byte prelude, 2x 4 byte CRCs
total_byte_length = struct.pack(">I", len(headers) + len(payload) + 16)
headers_byte_length = struct.pack(">I", len(headers))
# Build the prelude
prelude = bytearray([0] * 8)
prelude[:4] = total_byte_length
prelude[4:] = headers_byte_length
# Calculate checksum for prelude
prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF)
# Construct the message
message_as_list = bytearray()
message_as_list.extend(prelude)
message_as_list.extend(prelude_crc)
message_as_list.extend(headers)
message_as_list.extend(payload)
# Calculate checksum for message
message = bytes(message_as_list)
message_crc = struct.pack(">I", binascii.crc32(message) & 0xFFFFFFFF)
# Add message checksum
message_as_list.extend(message_crc)
return bytes(message_as_list)
def decode_event(message):
# Extract the prelude, headers, payload and CRC
prelude = message[:8]
total_length, headers_length = struct.unpack(">II", prelude)
prelude_crc = struct.unpack(">I", message[8:12])[0]
headers = message[12 : 12 + headers_length]
payload = message[12 + headers_length : -4]
message_crc = struct.unpack(">I", message[-4:])[0]
# Check the CRCs
assert prelude_crc == binascii.crc32(prelude) & 0xFFFFFFFF, "Prelude CRC check failed"
assert message_crc == binascii.crc32(message[:-4]) & 0xFFFFFFFF, "Message CRC check failed"
# Parse the headers
headers_dict = {}
while headers:
name_len = headers[0]
name = headers[1 : 1 + name_len].decode("utf-8")
value_type = headers[1 + name_len]
value_len = struct.unpack(">H", headers[2 + name_len : 4 + name_len])[0]
value = headers[4 + name_len : 4 + name_len + value_len].decode("utf-8")
headers_dict[name] = value
headers = headers[4 + name_len + value_len :]
return headers_dict, json.loads(payload)
class TranscribeSTTService(STTService):
def __init__(
self,
*,
api_key: Optional[str] = None,
aws_access_key_id: Optional[str] = None,
aws_session_token: Optional[str] = None,
region: Optional[str] = "us-east-1",
sample_rate: int = 16000,
language: Language = Language.EN,
**kwargs,
):
super().__init__(**kwargs)
self._settings = {
"sample_rate": sample_rate,
"language": language,
"media_encoding": "linear16", # AWS expects raw PCM
"number_of_channels": 1,
"show_speaker_label": False,
"enable_channel_identification": False,
}
# Validate sample rate - AWS Transcribe only supports 8000 Hz or 16000 Hz
if sample_rate not in [8000, 16000]:
logger.warning(
f"AWS Transcribe only supports 8000 Hz or 16000 Hz sample rates. Converting from {sample_rate} Hz to 16000 Hz."
)
self._settings["sample_rate"] = 16000
self._credentials = {
"aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
"aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
"region": region or os.getenv("AWS_REGION", "us-east-1"),
}
self._ws_client = None
self._connection_lock = asyncio.Lock()
self._connecting = False
self._receive_task = None
def get_service_encoding(self, encoding: str) -> str:
"""Convert internal encoding format to AWS Transcribe format."""
encoding_map = {
"linear16": "pcm", # AWS expects "pcm" for 16-bit linear PCM
}
return encoding_map.get(encoding, encoding)
async def start(self, frame: StartFrame):
"""Initialize the connection when the service starts."""
await super().start(frame)
logger.info("Starting AWS Transcribe service...")
retry_count = 0
max_retries = 3
while retry_count < max_retries:
try:
await self._connect()
if self._ws_client and self._ws_client.open:
logger.info("Successfully established WebSocket connection")
return
logger.warning("WebSocket connection not established after connect")
except Exception as e:
logger.error(f"Failed to connect (attempt {retry_count + 1}/{max_retries}): {e}")
retry_count += 1
if retry_count < max_retries:
await asyncio.sleep(1) # Wait before retrying
raise RuntimeError("Failed to establish WebSocket connection after multiple attempts")
async def run_stt(self, frame: Frame) -> AsyncGenerator[Frame, None]:
"""Process audio data and send to AWS Transcribe"""
try:
# Skip if no speech detected
if hasattr(frame, "is_speech") and not frame.is_speech:
logger.debug("Skipping non-speech frame")
return
# Ensure WebSocket is connected
if not self._ws_client or not self._ws_client.open:
logger.info("WebSocket not connected, attempting to reconnect...")
try:
await self._connect()
except Exception as e:
logger.error(f"Failed to reconnect: {e}")
yield ErrorFrame("Failed to reconnect to AWS Transcribe", fatal=False)
return
# Get the audio data - if frame is bytes, use directly, otherwise get audio attribute
audio_data = frame if isinstance(frame, bytes) else frame.audio
# Format the audio data according to AWS event stream format
event_message = build_event_message(audio_data)
# logger.debug(f"Sending audio chunk of size {len(audio_data)} bytes")
# Send the formatted event message
try:
await self._ws_client.send(event_message)
# Start metrics after first chunk sent
await self.start_processing_metrics()
await self.start_ttfb_metrics()
except websockets.exceptions.ConnectionClosed as e:
logger.warning(f"Connection closed while sending: {e}")
await self._disconnect()
# Don't yield error here - we'll retry on next frame
except Exception as e:
logger.error(f"Error sending audio: {e}")
yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
await self._disconnect()
except Exception as e:
logger.error(f"Error in run_stt: {e}")
yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
await self._disconnect()
async def _connect(self):
"""Connect to AWS Transcribe with connection state management."""
if (
self._ws_client
and self._ws_client.open
and self._receive_task
and not self._receive_task.done()
):
logger.debug("Already connected")
return
async with self._connection_lock:
if self._connecting:
logger.debug("Connection already in progress")
return
try:
self._connecting = True
logger.debug("Starting connection process...")
if self._ws_client:
await self._disconnect()
language_code = self.language_to_service_language(
Language(self._settings["language"])
)
if not language_code:
raise ValueError(f"Unsupported language: {self._settings['language']}")
# Generate random websocket key
websocket_key = "".join(
random.choices(
string.ascii_uppercase + string.ascii_lowercase + string.digits, k=20
)
)
# Add required headers
extra_headers = {
"Origin": "https://localhost",
"Sec-WebSocket-Key": websocket_key,
"Sec-WebSocket-Version": "13",
"Connection": "keep-alive",
}
# Get presigned URL
presigned_url = get_presigned_url(
region=self._credentials["region"],
credentials={
"access_key": self._credentials["aws_access_key_id"],
"secret_key": self._credentials["aws_secret_access_key"],
"session_token": self._credentials["aws_session_token"],
},
language_code=language_code,
media_encoding=self.get_service_encoding(
self._settings["media_encoding"]
), # Convert to AWS format
sample_rate=self._settings["sample_rate"],
number_of_channels=self._settings["number_of_channels"],
enable_partial_results_stabilization=True,
partial_results_stability="high",
show_speaker_label=self._settings["show_speaker_label"],
enable_channel_identification=self._settings["enable_channel_identification"],
)
logger.debug(f"Connecting to WebSocket with URL: {presigned_url[:100]}...")
# Connect with the required headers and settings
self._ws_client = await websockets.connect(
presigned_url,
extra_headers=extra_headers,
subprotocols=["mqtt"],
ping_interval=None,
ping_timeout=None,
compression=None,
)
logger.debug("WebSocket connected, starting receive task...")
# Start receive task
self._receive_task = asyncio.create_task(self._receive_loop())
logger.info("Successfully connected to AWS Transcribe")
except Exception as e:
logger.error(f"Failed to connect to AWS Transcribe: {e}")
await self._disconnect()
raise
finally:
self._connecting = False
async def _disconnect(self):
"""Disconnect from AWS Transcribe."""
if self._receive_task:
self._receive_task.cancel()
try:
await self._receive_task
except asyncio.CancelledError:
pass
self._receive_task = None
if self._ws_client:
try:
if self._ws_client.open:
# Send end-stream message
end_stream = {"message-type": "event", "event": "end"}
await self._ws_client.send(json.dumps(end_stream))
await self._ws_client.close()
except Exception as e:
logger.warning(f"Error closing WebSocket connection: {e}")
finally:
self._ws_client = None
def language_to_service_language(self, language: Language) -> str | None:
"""Convert internal language enum to AWS Transcribe language code."""
language_map = {
Language.EN: "en-US",
Language.ES: "es-US",
Language.FR: "fr-FR",
Language.DE: "de-DE",
Language.IT: "it-IT",
Language.PT: "pt-BR",
Language.JA: "ja-JP",
Language.KO: "ko-KR",
Language.ZH: "zh-CN",
}
return language_map.get(language)
async def _receive_loop(self):
"""Background task to receive and process messages from AWS Transcribe."""
try:
logger.debug("Receive loop started")
while True:
if not self._ws_client or not self._ws_client.open:
logger.warning("WebSocket closed in receive loop")
break
try:
response = await self._ws_client.recv()
headers, payload = decode_event(response)
# logger.debug(f"Received message type: {headers.get(':message-type')}")
if headers.get(":message-type") == "event":
# Process transcription results
results = payload.get("Transcript", {}).get("Results", [])
if results:
result = results[0]
alternatives = result.get("Alternatives", [])
if alternatives:
transcript = alternatives[0].get("Transcript", "")
is_final = not result.get("IsPartial", True)
if transcript:
await self.stop_ttfb_metrics()
if is_final:
await self.push_frame(
TranscriptionFrame(
transcript,
"",
time_now_iso8601(),
self._settings["language"],
)
)
await self.stop_processing_metrics()
else:
await self.push_frame(
InterimTranscriptionFrame(
transcript,
"",
time_now_iso8601(),
self._settings["language"],
)
)
elif headers.get(":message-type") == "exception":
error_msg = payload.get("Message", "Unknown error")
logger.error(f"Exception from AWS: {error_msg}")
await self.push_frame(
ErrorFrame(f"AWS Transcribe error: {error_msg}", fatal=False)
)
else:
logger.debug(f"Other message type received: {headers}")
logger.debug(f"Payload: {payload}")
except websockets.exceptions.ConnectionClosed as e:
logger.error(
f"WebSocket connection closed in receive loop with code {e.code}: {e.reason}"
)
break
except Exception as e:
logger.error(f"Error in receive loop: {e}")
break
except asyncio.CancelledError:
logger.debug("Receive loop cancelled")
except Exception as e:
logger.error(f"Unexpected error in receive loop: {e}")
finally:
logger.debug("Receive loop ended")

View File

@@ -6,6 +6,7 @@
import asyncio
from typing import AsyncGenerator, Optional
import os
from loguru import logger
from pydantic import BaseModel
@@ -16,9 +17,9 @@ from pipecat.frames.frames import (
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
TTSStoppedFrame
)
from pipecat.services.tts_service import TTSService
from pipecat.services.ai_services import TTSService
from pipecat.transcriptions.language import Language
try:
@@ -27,7 +28,7 @@ try:
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Deepgram, you need to `pip install pipecat-ai[aws]`. Also, set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable."
"In order to use AWS services, you need to `pip install pipecat-ai[aws]`. Also, remember to set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable."
)
raise Exception(f"Missing module: {e}")
@@ -151,6 +152,24 @@ class PollyTTSService(TTSService):
self.set_voice(voice_id)
# Get credentials from environment variables if not provided
self._credentials = {
"aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
"aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
"region": region or os.getenv("AWS_REGION", "us-east-1"),
}
# Validate that we have the required credentials
if (
not self._credentials["aws_access_key_id"]
or not self._credentials["aws_secret_access_key"]
):
raise ValueError(
"AWS credentials not found. Please provide them either through constructor parameters "
"or set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."
)
def can_generate_metrics(self) -> bool:
return True
@@ -165,18 +184,18 @@ class PollyTTSService(TTSService):
prosody_attrs = []
# Prosody tags are only supported for standard and neural engines
if self._settings["engine"] != "generative":
if self._settings["rate"]:
prosody_attrs.append(f"rate='{self._settings['rate']}'")
if self._settings["engine"] == "standard":
if self._settings["pitch"]:
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
if self._settings["volume"]:
prosody_attrs.append(f"volume='{self._settings['volume']}'")
if self._settings["rate"]:
prosody_attrs.append(f"rate='{self._settings['rate']}'")
if self._settings["volume"]:
prosody_attrs.append(f"volume='{self._settings['volume']}'")
# logger.warning("Prosody tags are not supported for generative engine. Ignoring.")
if prosody_attrs:
if prosody_attrs:
ssml += f"<prosody {' '.join(prosody_attrs)}>"
else:
logger.warning("Prosody tags are not supported for generative engine. Ignoring.")
ssml += text
@@ -187,6 +206,8 @@ class PollyTTSService(TTSService):
ssml += "</speak>"
logger.debug(f"SSML: {ssml}")
return ssml
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
@@ -248,3 +269,16 @@ class PollyTTSService(TTSService):
finally:
yield TTSStoppedFrame()
class AWSTTSService(PollyTTSService):
def __init__(self, **kwargs):
super().__init__(**kwargs)
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"'AWSTTSService' is deprecated, use 'PollyTTSService' instead.", DeprecationWarning
)