AWS: various cleanups (logs, imports...)
This commit is contained in:
@@ -13,7 +13,8 @@ from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.services.aws.llm import BedrockLLMContext, BedrockLLMService
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.aws.llm import BedrockLLMService
|
||||
from pipecat.services.aws.stt import TranscribeSTTService
|
||||
from pipecat.services.aws.tts import PollyTTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
@@ -55,15 +56,11 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"text": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way."
|
||||
}
|
||||
],
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = BedrockLLMContext(messages)
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
@@ -92,14 +89,16 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append(
|
||||
{"role": "user", "content": [{"text": "Please introduce yourself to the user."}]}
|
||||
)
|
||||
messages.append({"role": "user", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
|
||||
@transport.event_handler("on_client_closed")
|
||||
async def on_client_closed(transport, client):
|
||||
logger.info(f"Client closed connection")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=False)
|
||||
|
||||
@@ -41,7 +41,7 @@ Website = "https://pipecat.ai"
|
||||
[project.optional-dependencies]
|
||||
anthropic = [ "anthropic~=0.49.0" ]
|
||||
assemblyai = [ "assemblyai~=0.37.0" ]
|
||||
aws = [ "boto3~=1.37.16" ]
|
||||
aws = [ "boto3~=1.37.16", "websockets~=13.1" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.42.0"]
|
||||
cartesia = [ "cartesia~=1.4.0", "websockets~=13.1" ]
|
||||
cerebras = []
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
|
||||
@@ -8,6 +8,8 @@ import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.tts")
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.[llm,stt,tts]")
|
||||
|
||||
@@ -11,16 +11,12 @@ import io
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallCancelFrame,
|
||||
@@ -36,7 +32,9 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
@@ -44,7 +42,18 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import LLMService
|
||||
from pipecat.services.llm_service import LLMService
|
||||
|
||||
try:
|
||||
import boto3
|
||||
import httpx
|
||||
from botocore.config import Config
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use AWS services, you need to `pip install pipecat-ai[aws]`. Also, remember to set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -564,10 +573,10 @@ class BedrockLLMService(LLMService):
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: BedrockLLMContext,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_kwargs: Mapping[str, Any] = {},
|
||||
assistant_kwargs: Mapping[str, Any] = {},
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> BedrockContextAggregatorPair:
|
||||
"""Create an instance of BedrockContextAggregatorPair from an
|
||||
OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||
@@ -575,12 +584,10 @@ class BedrockLLMService(LLMService):
|
||||
|
||||
Args:
|
||||
context (OpenAILLMContext): The LLM context.
|
||||
user_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the user context aggregator constructor. Defaults
|
||||
to an empty mapping.
|
||||
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the assistant context aggregator
|
||||
constructor. Defaults to an empty mapping.
|
||||
user_params (LLMUserAggregatorParams, optional): User aggregator
|
||||
parameters.
|
||||
assistant_params (LLMAssistantAggregatorParams, optional): Assistant
|
||||
aggregator parameters.
|
||||
|
||||
Returns:
|
||||
BedrockContextAggregatorPair: A pair of context aggregators, one
|
||||
@@ -589,11 +596,11 @@ class BedrockLLMService(LLMService):
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
if isinstance(context, OpenAILLMContext) and not isinstance(context, BedrockLLMContext):
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
context = BedrockLLMContext.from_openai_context(context)
|
||||
|
||||
user = BedrockUserContextAggregator(context, **user_kwargs)
|
||||
assistant = BedrockAssistantContextAggregator(context, **assistant_kwargs)
|
||||
user = BedrockUserContextAggregator(context, params=user_params)
|
||||
assistant = BedrockAssistantContextAggregator(context, params=assistant_params)
|
||||
return BedrockContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
async def _process_context(self, context: BedrockLLMContext):
|
||||
|
||||
@@ -1,289 +1,40 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Optional, Dict
|
||||
import os
|
||||
import datetime
|
||||
from urllib.parse import urlencode
|
||||
import json
|
||||
import struct
|
||||
import urllib.parse
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import binascii
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TranscriptionFrame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import STTService
|
||||
from pipecat.services.aws.utils import build_event_message, decode_event, get_presigned_url
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
try:
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError, ClientError
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use AWS services, you need to `pip install pipecat-ai[aws]`. Also, remember to set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable."
|
||||
)
|
||||
logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def get_presigned_url(
    *,
    region: str,
    credentials: Dict[str, Optional[str]],
    language_code: str,
    media_encoding: str = "pcm",
    sample_rate: int = 16000,
    number_of_channels: int = 1,
    enable_partial_results_stabilization: bool = True,
    partial_results_stability: str = "high",
    vocabulary_name: Optional[str] = None,
    vocabulary_filter_name: Optional[str] = None,
    show_speaker_label: bool = False,
    enable_channel_identification: bool = False,
) -> str:
    """Return a presigned WebSocket URL for an AWS Transcribe streaming session.

    Args:
        region: AWS region used for the endpoint and the signature scope.
        credentials: Mapping read for the keys ``access_key``, ``secret_key``
            and ``session_token``.
        language_code: Value of the ``language-code`` query parameter.
        media_encoding: Value of the ``media-encoding`` query parameter.
        sample_rate: Value of the ``sample-rate`` query parameter.
        number_of_channels: Forwarded to the URL generator.
        enable_partial_results_stabilization: Forwarded to the URL generator.
        partial_results_stability: Forwarded to the URL generator.
        vocabulary_name: Forwarded to the URL generator.
        vocabulary_filter_name: Forwarded to the URL generator.
        show_speaker_label: Forwarded to the URL generator.
        enable_channel_identification: Forwarded to the URL generator.

    Returns:
        The signed request URL produced by ``AWSTranscribePresignedURL``.

    Raises:
        ValueError: If the access key or the secret key is missing or empty.
    """
    key_id = credentials.get("access_key")
    secret = credentials.get("secret_key")
    token = credentials.get("session_token")

    # Both halves of the key pair are mandatory; the session token is optional.
    if not (key_id and secret):
        raise ValueError("AWS credentials are required")

    signer = AWSTranscribePresignedURL(
        access_key=key_id,
        secret_key=secret,
        session_token=token,
        region=region,
    )

    request_options = dict(
        sample_rate=sample_rate,
        language_code=language_code,
        media_encoding=media_encoding,
        vocabulary_name=vocabulary_name,
        vocabulary_filter_name=vocabulary_filter_name,
        show_speaker_label=show_speaker_label,
        enable_channel_identification=enable_channel_identification,
        number_of_channels=number_of_channels,
        enable_partial_results_stabilization=enable_partial_results_stabilization,
        partial_results_stability=partial_results_stability,
    )
    return signer.get_request_url(**request_options)
|
||||
|
||||
|
||||
class AWSTranscribePresignedURL:
    """Build SigV4 query-string-signed WebSocket URLs for Amazon Transcribe streaming.

    The intermediate signing values (canonical request, string to sign,
    signature, ...) are kept as instance attributes and overwritten on every
    call to :meth:`get_request_url`.
    """

    def __init__(
        self, access_key: str, secret_key: str, session_token: str, region: str = "us-east-1"
    ):
        """Store the credentials and reset all signing state.

        Args:
            access_key: AWS access key id placed in ``X-Amz-Credential``.
            secret_key: AWS secret access key used to derive the signing key.
            session_token: Optional session token; when falsy no
                ``X-Amz-Security-Token`` parameter is emitted.
            region: AWS region used for the endpoint host and credential scope.
        """
        self.access_key = access_key
        self.secret_key = secret_key
        self.session_token = session_token
        self.method = "GET"
        self.service = "transcribe"
        self.region = region
        self.endpoint = ""
        self.host = ""
        self.amz_date = ""
        self.datestamp = ""
        self.canonical_uri = "/stream-transcription-websocket"
        self.canonical_headers = ""
        self.signed_headers = "host"
        self.algorithm = "AWS4-HMAC-SHA256"
        self.credential_scope = ""
        self.canonical_querystring = ""
        self.payload_hash = ""
        self.canonical_request = ""
        self.string_to_sign = ""
        self.signature = ""
        self.request_url = ""

    def get_request_url(
        self,
        sample_rate: int,
        language_code: str = "",
        media_encoding: str = "pcm",
        vocabulary_name: str = "",
        vocabulary_filter_name: str = "",
        show_speaker_label: bool = False,
        enable_channel_identification: bool = False,
        number_of_channels: int = 1,
        enable_partial_results_stabilization: bool = False,
        partial_results_stability: str = "",
    ) -> str:
        """Create and sign the streaming-transcription request URL.

        Falsy optional arguments are omitted from the query string, and
        ``number-of-channels`` is only sent when greater than one.

        Args:
            sample_rate: Value of the ``sample-rate`` parameter.
            language_code: Value of the ``language-code`` parameter.
            media_encoding: Value of the ``media-encoding`` parameter.
            vocabulary_name: Value of the ``vocabulary-name`` parameter.
            vocabulary_filter_name: Value of the ``vocabulary-filter-name``
                parameter.
            show_speaker_label: Adds ``show-speaker-label=true`` when set.
            enable_channel_identification: Adds
                ``enable-channel-identification=true`` when set.
            number_of_channels: Value of the ``number-of-channels`` parameter.
            enable_partial_results_stabilization: Adds
                ``enable-partial-results-stabilization=true`` when set.
            partial_results_stability: Value of the
                ``partial-results-stability`` parameter.

        Returns:
            The full ``wss://`` URL including the ``X-Amz-Signature`` parameter.
            The URL is requested with ``X-Amz-Expires=300``.
        """
        self.endpoint = f"wss://transcribestreaming.{self.region}.amazonaws.com:8443"
        self.host = f"transcribestreaming.{self.region}.amazonaws.com:8443"

        # datetime.utcnow() is deprecated since Python 3.12; an aware UTC
        # datetime yields exactly the same formatted strings.
        now = datetime.datetime.now(datetime.timezone.utc)
        self.amz_date = now.strftime("%Y%m%dT%H%M%SZ")
        self.datestamp = now.strftime("%Y%m%d")
        self.canonical_headers = f"host:{self.host}\n"

        # The attribute keeps its percent-encoded form; the plain form is what
        # goes into the string to sign.
        self.credential_scope = f"{self.datestamp}%2F{self.region}%2F{self.service}%2Faws4_request"
        credential_scope = f"{self.datestamp}/{self.region}/{self.service}/aws4_request"

        params = {
            "X-Amz-Algorithm": self.algorithm,
            "X-Amz-Credential": f"{self.access_key}/{credential_scope}",
            "X-Amz-Date": self.amz_date,
            "X-Amz-Expires": "300",
            "X-Amz-SignedHeaders": self.signed_headers,
        }
        if self.session_token:
            params["X-Amz-Security-Token"] = self.session_token
        if enable_channel_identification:
            params["enable-channel-identification"] = "true"
        if enable_partial_results_stabilization:
            params["enable-partial-results-stabilization"] = "true"
        if language_code:
            params["language-code"] = language_code
        if media_encoding:
            params["media-encoding"] = media_encoding
        if number_of_channels > 1:
            params["number-of-channels"] = str(number_of_channels)
        if partial_results_stability:
            params["partial-results-stability"] = partial_results_stability
        if sample_rate:
            params["sample-rate"] = str(sample_rate)
        if show_speaker_label:
            params["show-speaker-label"] = "true"
        if vocabulary_filter_name:
            params["vocabulary-filter-name"] = vocabulary_filter_name
        if vocabulary_name:
            params["vocabulary-name"] = vocabulary_name

        # SigV4 requires the canonical query string to be sorted by parameter
        # name and every value to be URI-encoded (unreserved characters only).
        encoded_pairs = []
        for key in sorted(params):
            encoded_value = urllib.parse.quote(str(params[key]), safe="")
            encoded_pairs.append(f"{key}={encoded_value}")
        self.canonical_querystring = "&".join(encoded_pairs)

        # A presigned GET carries no body, so the payload hash is that of "".
        self.payload_hash = hashlib.sha256(b"").hexdigest()

        self.canonical_request = "\n".join(
            (
                self.method,
                self.canonical_uri,
                self.canonical_querystring,
                self.canonical_headers,  # already ends with "\n"
                self.signed_headers,
                self.payload_hash,
            )
        )

        self.string_to_sign = (
            f"{self.algorithm}\n{self.amz_date}\n{credential_scope}\n"
            + hashlib.sha256(self.canonical_request.encode("utf-8")).hexdigest()
        )

        # Derive the signing key: date -> region -> service -> "aws4_request".
        signing_key = f"AWS4{self.secret_key}".encode("utf-8")
        for scope_part in (self.datestamp, self.region, self.service, "aws4_request"):
            signing_key = hmac.new(
                signing_key, scope_part.encode("utf-8"), hashlib.sha256
            ).digest()
        self.signature = hmac.new(
            signing_key, self.string_to_sign.encode("utf-8"), hashlib.sha256
        ).hexdigest()

        # The signature is appended last and is not part of the signed string.
        self.canonical_querystring += "&X-Amz-Signature=" + self.signature

        self.request_url = self.endpoint + self.canonical_uri + "?" + self.canonical_querystring
        return self.request_url
|
||||
|
||||
|
||||
def get_headers(header_name: str, header_value: str) -> bytearray:
    """Encode one string-valued header in AWS event stream format.

    The encoded header is laid out as: one byte holding the name length, the
    UTF-8 name, the value-type byte ``7`` (string), a big-endian unsigned
    16-bit value length, and finally the UTF-8 value.

    Args:
        header_name: Header name, e.g. ``":event-type"``.
        header_value: Header value, encoded as a string-typed header.

    Returns:
        The encoded header bytes.
    """
    encoded_name = header_name.encode("utf-8")
    header = bytearray([len(encoded_name)])
    header += encoded_name
    header.append(7)  # value type 7 marks a string value

    encoded_value = header_value.encode("utf-8")
    header += struct.pack(">H", len(encoded_value))
    header += encoded_value
    return header
|
||||
|
||||
|
||||
def build_event_message(payload: bytes) -> bytes:
    """Wrap an audio chunk in an AWS event stream ``AudioEvent`` message.

    The message is laid out as: an 8-byte prelude (total length, headers
    length), a CRC32 of the prelude, the encoded headers, the payload, and a
    CRC32 of everything before it.

    Matches AWS sample: https://github.com/aws-samples/amazon-transcribe-streaming-python-websockets/blob/main/eventstream.py

    Args:
        payload: Raw audio bytes to send.

    Returns:
        The complete binary event message.
    """
    header_block = bytearray()
    for header_name, header_value in (
        (":content-type", "application/octet-stream"),
        (":event-type", "AudioEvent"),
        (":message-type", "event"),
    ):
        header_block.extend(get_headers(header_name, header_value))

    # Total length = headers + payload + 16 bytes of framing
    # (8-byte prelude plus two 4-byte CRCs).
    prelude = struct.pack(">II", len(header_block) + len(payload) + 16, len(header_block))

    frame = bytearray(prelude)
    frame.extend(struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF))
    frame.extend(header_block)
    frame.extend(payload)

    # The trailing checksum covers the prelude, its CRC, the headers and the payload.
    frame.extend(struct.pack(">I", binascii.crc32(bytes(frame)) & 0xFFFFFFFF))
    return bytes(frame)
|
||||
|
||||
|
||||
def decode_event(message):
    """Decode one AWS event stream message into its headers and JSON payload.

    The message is read as: an 8-byte prelude (total length, headers length),
    a 4-byte prelude CRC32, the headers, the payload, and a 4-byte CRC32 over
    everything before it.

    Args:
        message: Raw bytes of a single event stream message.

    Returns:
        A tuple ``(headers, payload)`` where ``headers`` maps header names to
        their string values and ``payload`` is the JSON-decoded message body.

    Raises:
        ValueError: If the message is shorter than the fixed 16 bytes of
            framing, if either CRC check fails, or if the payload is not
            valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    # 8-byte prelude + 4-byte prelude CRC + 4-byte message CRC.
    if len(message) < 16:
        raise ValueError("Event message too short")

    prelude = message[:8]
    _total_length, headers_length = struct.unpack(">II", prelude)
    prelude_crc = struct.unpack(">I", message[8:12])[0]
    headers = message[12 : 12 + headers_length]
    payload = message[12 + headers_length : -4]
    message_crc = struct.unpack(">I", message[-4:])[0]

    # Explicit raises instead of `assert`, so the integrity checks are not
    # stripped when Python runs with -O.
    if prelude_crc != binascii.crc32(prelude) & 0xFFFFFFFF:
        raise ValueError("Prelude CRC check failed")
    if message_crc != binascii.crc32(message[:-4]) & 0xFFFFFFFF:
        raise ValueError("Message CRC check failed")

    headers_dict = {}
    while headers:
        name_len = headers[0]
        name = headers[1 : 1 + name_len].decode("utf-8")
        # headers[1 + name_len] is the value-type byte. Every header is parsed
        # as a string (type 7) with a 2-byte length prefix.
        # NOTE(review): other value types would be mis-parsed - confirm the
        # service only sends string headers.
        value_len = struct.unpack(">H", headers[2 + name_len : 4 + name_len])[0]
        value = headers[4 + name_len : 4 + name_len + value_len].decode("utf-8")
        headers_dict[name] = value
        headers = headers[4 + name_len + value_len :]

    return headers_dict, json.loads(payload)
|
||||
|
||||
|
||||
class TranscribeSTTService(STTService):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -355,17 +106,20 @@ class TranscribeSTTService(STTService):
|
||||
|
||||
raise RuntimeError("Failed to establish WebSocket connection after multiple attempts")
|
||||
|
||||
async def run_stt(self, frame: Frame) -> AsyncGenerator[Frame, None]:
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process audio data and send to AWS Transcribe"""
|
||||
try:
|
||||
# Skip if no speech detected
|
||||
if hasattr(frame, "is_speech") and not frame.is_speech:
|
||||
logger.debug("Skipping non-speech frame")
|
||||
return
|
||||
|
||||
# Ensure WebSocket is connected
|
||||
if not self._ws_client or not self._ws_client.open:
|
||||
logger.info("WebSocket not connected, attempting to reconnect...")
|
||||
logger.debug("WebSocket not connected, attempting to reconnect...")
|
||||
try:
|
||||
await self._connect()
|
||||
except Exception as e:
|
||||
@@ -373,12 +127,8 @@ class TranscribeSTTService(STTService):
|
||||
yield ErrorFrame("Failed to reconnect to AWS Transcribe", fatal=False)
|
||||
return
|
||||
|
||||
# Get the audio data - if frame is bytes, use directly, otherwise get audio attribute
|
||||
audio_data = frame if isinstance(frame, bytes) else frame.audio
|
||||
|
||||
# Format the audio data according to AWS event stream format
|
||||
event_message = build_event_message(audio_data)
|
||||
# logger.debug(f"Sending audio chunk of size {len(audio_data)} bytes")
|
||||
event_message = build_event_message(audio)
|
||||
|
||||
# Send the formatted event message
|
||||
try:
|
||||
@@ -402,23 +152,18 @@ class TranscribeSTTService(STTService):
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to AWS Transcribe with connection state management."""
|
||||
if (
|
||||
self._ws_client
|
||||
and self._ws_client.open
|
||||
and self._receive_task
|
||||
and not self._receive_task.done()
|
||||
):
|
||||
logger.debug("Already connected")
|
||||
if self._ws_client and self._ws_client.open and self._receive_task:
|
||||
logger.debug(f"{self} Already connected")
|
||||
return
|
||||
|
||||
async with self._connection_lock:
|
||||
if self._connecting:
|
||||
logger.debug("Connection already in progress")
|
||||
logger.debug(f"{self} Connection already in progress")
|
||||
return
|
||||
|
||||
try:
|
||||
self._connecting = True
|
||||
logger.debug("Starting connection process...")
|
||||
logger.debug(f"{self} Starting connection process...")
|
||||
|
||||
if self._ws_client:
|
||||
await self._disconnect()
|
||||
@@ -464,7 +209,7 @@ class TranscribeSTTService(STTService):
|
||||
enable_channel_identification=self._settings["enable_channel_identification"],
|
||||
)
|
||||
|
||||
logger.debug(f"Connecting to WebSocket with URL: {presigned_url[:100]}...")
|
||||
logger.debug(f"{self} Connecting to WebSocket with URL: {presigned_url[:100]}...")
|
||||
|
||||
# Connect with the required headers and settings
|
||||
self._ws_client = await websockets.connect(
|
||||
@@ -475,15 +220,16 @@ class TranscribeSTTService(STTService):
|
||||
ping_timeout=None,
|
||||
compression=None,
|
||||
)
|
||||
logger.debug("WebSocket connected, starting receive task...")
|
||||
|
||||
logger.debug(f"{self} WebSocket connected, starting receive task...")
|
||||
|
||||
# Start receive task
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
self._receive_task = self.create_task(self._receive_loop())
|
||||
|
||||
logger.info("Successfully connected to AWS Transcribe")
|
||||
logger.info(f"{self} Successfully connected to AWS Transcribe")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to AWS Transcribe: {e}")
|
||||
logger.error(f"{self} Failed to connect to AWS Transcribe: {e}")
|
||||
await self._disconnect()
|
||||
raise
|
||||
|
||||
@@ -493,24 +239,19 @@ class TranscribeSTTService(STTService):
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from AWS Transcribe."""
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
if self._ws_client:
|
||||
try:
|
||||
if self._ws_client.open:
|
||||
# Send end-stream message
|
||||
end_stream = {"message-type": "event", "event": "end"}
|
||||
await self._ws_client.send(json.dumps(end_stream))
|
||||
await self._ws_client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing WebSocket connection: {e}")
|
||||
finally:
|
||||
self._ws_client = None
|
||||
try:
|
||||
if self._ws_client and self._ws_client.open:
|
||||
# Send end-stream message
|
||||
end_stream = {"message-type": "event", "event": "end"}
|
||||
await self._ws_client.send(json.dumps(end_stream))
|
||||
await self._ws_client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"{self} Error closing WebSocket connection: {e}")
|
||||
finally:
|
||||
self._ws_client = None
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
"""Convert internal language enum to AWS Transcribe language code."""
|
||||
@@ -529,72 +270,60 @@ class TranscribeSTTService(STTService):
|
||||
|
||||
async def _receive_loop(self):
|
||||
"""Background task to receive and process messages from AWS Transcribe."""
|
||||
try:
|
||||
logger.debug("Receive loop started")
|
||||
while True:
|
||||
if not self._ws_client or not self._ws_client.open:
|
||||
logger.warning("WebSocket closed in receive loop")
|
||||
break
|
||||
while True:
|
||||
if not self._ws_client or not self._ws_client.open:
|
||||
logger.warning(f"{self} WebSocket closed in receive loop")
|
||||
break
|
||||
|
||||
try:
|
||||
response = await self._ws_client.recv()
|
||||
headers, payload = decode_event(response)
|
||||
try:
|
||||
response = await self._ws_client.recv()
|
||||
headers, payload = decode_event(response)
|
||||
|
||||
# logger.debug(f"Received message type: {headers.get(':message-type')}")
|
||||
if headers.get(":message-type") == "event":
|
||||
# Process transcription results
|
||||
results = payload.get("Transcript", {}).get("Results", [])
|
||||
if results:
|
||||
result = results[0]
|
||||
alternatives = result.get("Alternatives", [])
|
||||
if alternatives:
|
||||
transcript = alternatives[0].get("Transcript", "")
|
||||
is_final = not result.get("IsPartial", True)
|
||||
|
||||
if headers.get(":message-type") == "event":
|
||||
# Process transcription results
|
||||
results = payload.get("Transcript", {}).get("Results", [])
|
||||
if results:
|
||||
result = results[0]
|
||||
alternatives = result.get("Alternatives", [])
|
||||
if alternatives:
|
||||
transcript = alternatives[0].get("Transcript", "")
|
||||
is_final = not result.get("IsPartial", True)
|
||||
|
||||
if transcript:
|
||||
await self.stop_ttfb_metrics()
|
||||
if is_final:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
self._settings["language"],
|
||||
)
|
||||
if transcript:
|
||||
await self.stop_ttfb_metrics()
|
||||
if is_final:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
self._settings["language"],
|
||||
)
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
self._settings["language"],
|
||||
)
|
||||
)
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
"",
|
||||
time_now_iso8601(),
|
||||
self._settings["language"],
|
||||
)
|
||||
elif headers.get(":message-type") == "exception":
|
||||
error_msg = payload.get("Message", "Unknown error")
|
||||
logger.error(f"Exception from AWS: {error_msg}")
|
||||
await self.push_frame(
|
||||
ErrorFrame(f"AWS Transcribe error: {error_msg}", fatal=False)
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Other message type received: {headers}")
|
||||
logger.debug(f"Payload: {payload}")
|
||||
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.error(
|
||||
f"WebSocket connection closed in receive loop with code {e.code}: {e.reason}"
|
||||
)
|
||||
elif headers.get(":message-type") == "exception":
|
||||
error_msg = payload.get("Message", "Unknown error")
|
||||
logger.error(f"{self} Exception from AWS: {error_msg}")
|
||||
await self.push_frame(
|
||||
ErrorFrame(f"AWS Transcribe error: {error_msg}", fatal=False)
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in receive loop: {e}")
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Receive loop cancelled")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in receive loop: {e}")
|
||||
finally:
|
||||
logger.debug("Receive loop ended")
|
||||
else:
|
||||
logger.debug(f"{self} Other message type received: {headers}")
|
||||
logger.debug(f"{self} Payload: {payload}")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.error(
|
||||
f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Unexpected error in receive loop: {e}")
|
||||
break
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Optional
|
||||
import os
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
@@ -19,7 +19,7 @@ from pipecat.frames.frames import (
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
@@ -27,9 +27,7 @@ try:
|
||||
from botocore.exceptions import BotoCoreError, ClientError
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use AWS services, you need to `pip install pipecat-ai[aws]`. Also, remember to set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable."
|
||||
)
|
||||
logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@@ -206,7 +204,7 @@ class PollyTTSService(TTSService):
|
||||
|
||||
ssml += "</speak>"
|
||||
|
||||
logger.debug(f"SSML: {ssml}")
|
||||
logger.trace(f"{self} SSML: {ssml}")
|
||||
|
||||
return ssml
|
||||
|
||||
|
||||
Reference in New Issue
Block a user