diff --git a/examples/foundational/07m-interruptible-aws.py b/examples/foundational/07m-interruptible-aws.py index ddb8b222e..c88439c62 100644 --- a/examples/foundational/07m-interruptible-aws.py +++ b/examples/foundational/07m-interruptible-aws.py @@ -13,7 +13,8 @@ from pipecat.audio.vad.silero import SileroVADAnalyzer from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask -from pipecat.services.aws.llm import BedrockLLMContext, BedrockLLMService +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.aws.llm import BedrockLLMService from pipecat.services.aws.stt import TranscribeSTTService from pipecat.services.aws.tts import PollyTTSService from pipecat.transcriptions.language import Language @@ -55,15 +56,11 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac messages = [ { "role": "system", - "content": [ - { - "text": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way." - } - ], + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.", }, ] - context = BedrockLLMContext(messages) + context = OpenAILLMContext(messages) context_aggregator = llm.create_context_aggregator(context) pipeline = Pipeline( @@ -92,14 +89,16 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac async def on_client_connected(transport, client): logger.info(f"Client connected") # Kick off the conversation. - messages.append( - {"role": "user", "content": [{"text": "Please introduce yourself to the user."}]} - ) + messages.append({"role": "user", "content": "Please introduce yourself to the user."}) await task.queue_frames([context_aggregator.user().get_context_frame()]) @transport.event_handler("on_client_disconnected") async def on_client_disconnected(transport, client): logger.info(f"Client disconnected") + + @transport.event_handler("on_client_closed") + async def on_client_closed(transport, client): + logger.info(f"Client closed connection") await task.cancel() runner = PipelineRunner(handle_sigint=False) diff --git a/pyproject.toml b/pyproject.toml index 910c8d066..13305933b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ Website = "https://pipecat.ai" [project.optional-dependencies] anthropic = [ "anthropic~=0.49.0" ] assemblyai = [ "assemblyai~=0.37.0" ] -aws = [ "boto3~=1.37.16" ] +aws = [ "boto3~=1.37.16", "websockets~=13.1" ] azure = [ "azure-cognitiveservices-speech~=1.42.0"] cartesia = [ "cartesia~=1.4.0", "websockets~=13.1" ] cerebras = [] diff --git a/src/pipecat/adapters/services/bedrock_adapter.py b/src/pipecat/adapters/services/bedrock_adapter.py index b877f01fc..cfb2a5f27 100644 --- a/src/pipecat/adapters/services/bedrock_adapter.py +++ b/src/pipecat/adapters/services/bedrock_adapter.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD 2-Clause License # -from typing import Any, Dict, List, Union +from typing import Any, Dict, List from pipecat.adapters.base_llm_adapter import BaseLLMAdapter from pipecat.adapters.schemas.function_schema import FunctionSchema diff --git a/src/pipecat/services/aws/__init__.py b/src/pipecat/services/aws/__init__.py index b36c88499..b1f157bd3 100644 --- a/src/pipecat/services/aws/__init__.py +++ b/src/pipecat/services/aws/__init__.py @@ -8,6 +8,8 @@ import sys from pipecat.services import DeprecatedModuleProxy +from .llm import * +from .stt import * from .tts import * -sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.tts") +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.[llm,stt,tts]") diff --git a/src/pipecat/services/aws/llm.py b/src/pipecat/services/aws/llm.py index 3b9c1fedd..63b0964c2 100644 --- a/src/pipecat/services/aws/llm.py +++ b/src/pipecat/services/aws/llm.py @@ -11,16 +11,12 @@ import io import json import re from dataclasses import dataclass -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, List, Optional -import boto3 -from botocore.config import Config -import httpx from loguru import logger from PIL import Image from pydantic import BaseModel, Field -from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter from pipecat.frames.frames import ( Frame, FunctionCallCancelFrame, @@ -36,7 +32,9 @@ from pipecat.frames.frames import ( ) from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_response import ( + LLMAssistantAggregatorParams, LLMAssistantContextAggregator, + LLMUserAggregatorParams, LLMUserContextAggregator, ) from pipecat.processors.aggregators.openai_llm_context import ( @@ -44,7 +42,18 @@ from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContextFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.ai_services import LLMService +from pipecat.services.llm_service import LLMService + +try: + import boto3 + import httpx + from botocore.config import Config +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use AWS services, you need to `pip install pipecat-ai[aws]`. Also, remember to set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable." + ) + raise Exception(f"Missing module: {e}") @dataclass @@ -564,10 +573,10 @@ class BedrockLLMService(LLMService): def create_context_aggregator( self, - context: BedrockLLMContext, + context: OpenAILLMContext, *, - user_kwargs: Mapping[str, Any] = {}, - assistant_kwargs: Mapping[str, Any] = {}, + user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), + assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), ) -> BedrockContextAggregatorPair: """Create an instance of BedrockContextAggregatorPair from an OpenAILLMContext. Constructor keyword arguments for both the user and @@ -575,12 +584,10 @@ class BedrockLLMService(LLMService): Args: context (OpenAILLMContext): The LLM context. - user_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the user context aggregator constructor. Defaults - to an empty mapping. - assistant_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the assistant context aggregator - constructor. Defaults to an empty mapping. + user_params (LLMUserAggregatorParams, optional): User aggregator + parameters. + assistant_params (LLMAssistantAggregatorParams, optional): User + aggregator parameters. Returns: BedrockContextAggregatorPair: A pair of context aggregators, one @@ -589,11 +596,11 @@ class BedrockLLMService(LLMService): """ context.set_llm_adapter(self.get_llm_adapter()) - if isinstance(context, OpenAILLMContext) and not isinstance(context, BedrockLLMContext): + if isinstance(context, OpenAILLMContext): context = BedrockLLMContext.from_openai_context(context) - user = BedrockUserContextAggregator(context, **user_kwargs) - assistant = BedrockAssistantContextAggregator(context, **assistant_kwargs) + user = BedrockUserContextAggregator(context, params=user_params) + assistant = BedrockAssistantContextAggregator(context, params=assistant_params) return BedrockContextAggregatorPair(_user=user, _assistant=assistant) async def _process_context(self, context: BedrockLLMContext): diff --git a/src/pipecat/services/aws/stt.py b/src/pipecat/services/aws/stt.py index d749eff0c..0468ab31b 100644 --- a/src/pipecat/services/aws/stt.py +++ b/src/pipecat/services/aws/stt.py @@ -1,289 +1,40 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + import asyncio -from typing import AsyncGenerator, Optional, Dict -import os -import datetime -from urllib.parse import urlencode import json -import struct -import urllib.parse -import hashlib -import hmac +import os import random import string -import binascii +from typing import AsyncGenerator, Optional from loguru import logger from pipecat.frames.frames import ( + CancelFrame, + EndFrame, ErrorFrame, Frame, - TranscriptionFrame, InterimTranscriptionFrame, StartFrame, + TranscriptionFrame, ) -from pipecat.services.ai_services import STTService +from pipecat.services.aws.utils import build_event_message, decode_event, get_presigned_url +from pipecat.services.stt_service import STTService from pipecat.transcriptions.language import Language from pipecat.utils.time import time_now_iso8601 try: - import boto3 - from botocore.exceptions import BotoCoreError, ClientError import websockets except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use AWS services, you need to `pip install pipecat-ai[aws]`. Also, remember to set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable." - ) + logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.") raise Exception(f"Missing module: {e}") -def get_presigned_url( - *, - region: str, - credentials: Dict[str, Optional[str]], - language_code: str, - media_encoding: str = "pcm", - sample_rate: int = 16000, - number_of_channels: int = 1, - enable_partial_results_stabilization: bool = True, - partial_results_stability: str = "high", - vocabulary_name: Optional[str] = None, - vocabulary_filter_name: Optional[str] = None, - show_speaker_label: bool = False, - enable_channel_identification: bool = False, -) -> str: - """Create a presigned URL for AWS Transcribe streaming.""" - access_key = credentials.get("access_key") - secret_key = credentials.get("secret_key") - session_token = credentials.get("session_token") - - if not access_key or not secret_key: - raise ValueError("AWS credentials are required") - - # Initialize the URL generator - url_generator = AWSTranscribePresignedURL( - access_key=access_key, secret_key=secret_key, session_token=session_token, region=region - ) - - # Get the presigned URL - return url_generator.get_request_url( - sample_rate=sample_rate, - language_code=language_code, - media_encoding=media_encoding, - vocabulary_name=vocabulary_name, - vocabulary_filter_name=vocabulary_filter_name, - show_speaker_label=show_speaker_label, - enable_channel_identification=enable_channel_identification, - number_of_channels=number_of_channels, - enable_partial_results_stabilization=enable_partial_results_stabilization, - partial_results_stability=partial_results_stability, - ) - - -class AWSTranscribePresignedURL: - def __init__( - self, access_key: str, secret_key: str, session_token: str, region: str = "us-east-1" - ): - self.access_key = access_key - self.secret_key = secret_key - self.session_token = session_token - self.method = "GET" - self.service = "transcribe" - self.region = region - self.endpoint = "" - self.host = "" - self.amz_date = "" - self.datestamp = "" - self.canonical_uri = "/stream-transcription-websocket" - self.canonical_headers = "" - self.signed_headers = "host" - self.algorithm = "AWS4-HMAC-SHA256" - self.credential_scope = "" - self.canonical_querystring = "" - self.payload_hash = "" - self.canonical_request = "" - self.string_to_sign = "" - self.signature = "" - self.request_url = "" - - def get_request_url( - self, - sample_rate: int, - language_code: str = "", - media_encoding: str = "pcm", - vocabulary_name: str = "", - vocabulary_filter_name: str = "", - show_speaker_label: bool = False, - enable_channel_identification: bool = False, - number_of_channels: int = 1, - enable_partial_results_stabilization: bool = False, - partial_results_stability: str = "", - ) -> str: - self.endpoint = f"wss://transcribestreaming.{self.region}.amazonaws.com:8443" - self.host = f"transcribestreaming.{self.region}.amazonaws.com:8443" - - now = datetime.datetime.utcnow() - self.amz_date = now.strftime("%Y%m%dT%H%M%SZ") - self.datestamp = now.strftime("%Y%m%d") - self.canonical_headers = f"host:{self.host}\n" - self.credential_scope = f"{self.datestamp}%2F{self.region}%2F{self.service}%2Faws4_request" - - # Create canonical querystring - self.canonical_querystring = "X-Amz-Algorithm=" + self.algorithm - self.canonical_querystring += ( - "&X-Amz-Credential=" + self.access_key + "%2F" + self.credential_scope - ) - self.canonical_querystring += "&X-Amz-Date=" + self.amz_date - self.canonical_querystring += "&X-Amz-Expires=300" - if self.session_token: - self.canonical_querystring += "&X-Amz-Security-Token=" + urllib.parse.quote( - self.session_token, safe="" - ) - self.canonical_querystring += "&X-Amz-SignedHeaders=" + self.signed_headers - - if enable_channel_identification: - self.canonical_querystring += "&enable-channel-identification=true" - if enable_partial_results_stabilization: - self.canonical_querystring += "&enable-partial-results-stabilization=true" - if language_code: - self.canonical_querystring += "&language-code=" + language_code - if media_encoding: - self.canonical_querystring += "&media-encoding=" + media_encoding - if number_of_channels > 1: - self.canonical_querystring += "&number-of-channels=" + str(number_of_channels) - if partial_results_stability: - self.canonical_querystring += "&partial-results-stability=" + partial_results_stability - if sample_rate: - self.canonical_querystring += "&sample-rate=" + str(sample_rate) - if show_speaker_label: - self.canonical_querystring += "&show-speaker-label=true" - if vocabulary_filter_name: - self.canonical_querystring += "&vocabulary-filter-name=" + vocabulary_filter_name - if vocabulary_name: - self.canonical_querystring += "&vocabulary-name=" + vocabulary_name - - # Create payload hash - self.payload_hash = hashlib.sha256("".encode("utf-8")).hexdigest() - - # Create canonical request - self.canonical_request = f"{self.method}\n{self.canonical_uri}\n{self.canonical_querystring}\n{self.canonical_headers}\n{self.signed_headers}\n{self.payload_hash}" - - # Create string to sign - credential_scope = f"{self.datestamp}/{self.region}/{self.service}/aws4_request" - string_to_sign = ( - f"{self.algorithm}\n{self.amz_date}\n{credential_scope}\n" - + hashlib.sha256(self.canonical_request.encode("utf-8")).hexdigest() - ) - - # Calculate signature - k_date = hmac.new( - f"AWS4{self.secret_key}".encode("utf-8"), self.datestamp.encode("utf-8"), hashlib.sha256 - ).digest() - k_region = hmac.new(k_date, self.region.encode("utf-8"), hashlib.sha256).digest() - k_service = hmac.new(k_region, self.service.encode("utf-8"), hashlib.sha256).digest() - k_signing = hmac.new(k_service, b"aws4_request", hashlib.sha256).digest() - self.signature = hmac.new( - k_signing, string_to_sign.encode("utf-8"), hashlib.sha256 - ).hexdigest() - - # Add signature to query string - self.canonical_querystring += "&X-Amz-Signature=" + self.signature - - # Create request URL - self.request_url = self.endpoint + self.canonical_uri + "?" + self.canonical_querystring - return self.request_url - - -def get_headers(header_name: str, header_value: str) -> bytearray: - """Build a header following AWS event stream format.""" - name = header_name.encode("utf-8") - name_byte_length = bytes([len(name)]) - value_type = bytes([7]) # 7 represents a string - value = header_value.encode("utf-8") - value_byte_length = struct.pack(">H", len(value)) - - # Construct the header - header_list = bytearray() - header_list.extend(name_byte_length) - header_list.extend(name) - header_list.extend(value_type) - header_list.extend(value_byte_length) - header_list.extend(value) - return header_list - - -def build_event_message(payload: bytes) -> bytes: - """ - Build an event message for AWS Transcribe streaming. - Matches AWS sample: https://github.com/aws-samples/amazon-transcribe-streaming-python-websockets/blob/main/eventstream.py - """ - # Build headers - content_type_header = get_headers(":content-type", "application/octet-stream") - event_type_header = get_headers(":event-type", "AudioEvent") - message_type_header = get_headers(":message-type", "event") - - headers = bytearray() - headers.extend(content_type_header) - headers.extend(event_type_header) - headers.extend(message_type_header) - - # Calculate total byte length and headers byte length - # 16 accounts for 8 byte prelude, 2x 4 byte CRCs - total_byte_length = struct.pack(">I", len(headers) + len(payload) + 16) - headers_byte_length = struct.pack(">I", len(headers)) - - # Build the prelude - prelude = bytearray([0] * 8) - prelude[:4] = total_byte_length - prelude[4:] = headers_byte_length - - # Calculate checksum for prelude - prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF) - - # Construct the message - message_as_list = bytearray() - message_as_list.extend(prelude) - message_as_list.extend(prelude_crc) - message_as_list.extend(headers) - message_as_list.extend(payload) - - # Calculate checksum for message - message = bytes(message_as_list) - message_crc = struct.pack(">I", binascii.crc32(message) & 0xFFFFFFFF) - - # Add message checksum - message_as_list.extend(message_crc) - - return bytes(message_as_list) - - -def decode_event(message): - # Extract the prelude, headers, payload and CRC - prelude = message[:8] - total_length, headers_length = struct.unpack(">II", prelude) - prelude_crc = struct.unpack(">I", message[8:12])[0] - headers = message[12 : 12 + headers_length] - payload = message[12 + headers_length : -4] - message_crc = struct.unpack(">I", message[-4:])[0] - - # Check the CRCs - assert prelude_crc == binascii.crc32(prelude) & 0xFFFFFFFF, "Prelude CRC check failed" - assert message_crc == binascii.crc32(message[:-4]) & 0xFFFFFFFF, "Message CRC check failed" - - # Parse the headers - headers_dict = {} - while headers: - name_len = headers[0] - name = headers[1 : 1 + name_len].decode("utf-8") - value_type = headers[1 + name_len] - value_len = struct.unpack(">H", headers[2 + name_len : 4 + name_len])[0] - value = headers[4 + name_len : 4 + name_len + value_len].decode("utf-8") - headers_dict[name] = value - headers = headers[4 + name_len + value_len :] - - return headers_dict, json.loads(payload) - - class TranscribeSTTService(STTService): def __init__( self, @@ -355,17 +106,20 @@ class TranscribeSTTService(STTService): raise RuntimeError("Failed to establish WebSocket connection after multiple attempts") - async def run_stt(self, frame: Frame) -> AsyncGenerator[Frame, None]: + async def stop(self, frame: EndFrame): + await super().stop(frame) + await self._disconnect() + + async def cancel(self, frame: CancelFrame): + await super().cancel(frame) + await self._disconnect() + + async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: """Process audio data and send to AWS Transcribe""" try: - # Skip if no speech detected - if hasattr(frame, "is_speech") and not frame.is_speech: - logger.debug("Skipping non-speech frame") - return - # Ensure WebSocket is connected if not self._ws_client or not self._ws_client.open: - logger.info("WebSocket not connected, attempting to reconnect...") + logger.debug("WebSocket not connected, attempting to reconnect...") try: await self._connect() except Exception as e: @@ -373,12 +127,8 @@ class TranscribeSTTService(STTService): yield ErrorFrame("Failed to reconnect to AWS Transcribe", fatal=False) return - # Get the audio data - if frame is bytes, use directly, otherwise get audio attribute - audio_data = frame if isinstance(frame, bytes) else frame.audio - # Format the audio data according to AWS event stream format - event_message = build_event_message(audio_data) - # logger.debug(f"Sending audio chunk of size {len(audio_data)} bytes") + event_message = build_event_message(audio) # Send the formatted event message try: @@ -402,23 +152,18 @@ class TranscribeSTTService(STTService): async def _connect(self): """Connect to AWS Transcribe with connection state management.""" - if ( - self._ws_client - and self._ws_client.open - and self._receive_task - and not self._receive_task.done() - ): - logger.debug("Already connected") + if self._ws_client and self._ws_client.open and self._receive_task: + logger.debug(f"{self} Already connected") return async with self._connection_lock: if self._connecting: - logger.debug("Connection already in progress") + logger.debug(f"{self} Connection already in progress") return try: self._connecting = True - logger.debug("Starting connection process...") + logger.debug(f"{self} Starting connection process...") if self._ws_client: await self._disconnect() @@ -464,7 +209,7 @@ class TranscribeSTTService(STTService): enable_channel_identification=self._settings["enable_channel_identification"], ) - logger.debug(f"Connecting to WebSocket with URL: {presigned_url[:100]}...") + logger.debug(f"{self} Connecting to WebSocket with URL: {presigned_url[:100]}...") # Connect with the required headers and settings self._ws_client = await websockets.connect( @@ -475,15 +220,16 @@ class TranscribeSTTService(STTService): ping_timeout=None, compression=None, ) - logger.debug("WebSocket connected, starting receive task...") + + logger.debug(f"{self} WebSocket connected, starting receive task...") # Start receive task - self._receive_task = asyncio.create_task(self._receive_loop()) + self._receive_task = self.create_task(self._receive_loop()) - logger.info("Successfully connected to AWS Transcribe") + logger.info(f"{self} Successfully connected to AWS Transcribe") except Exception as e: - logger.error(f"Failed to connect to AWS Transcribe: {e}") + logger.error(f"{self} Failed to connect to AWS Transcribe: {e}") await self._disconnect() raise @@ -493,24 +239,19 @@ class TranscribeSTTService(STTService): async def _disconnect(self): """Disconnect from AWS Transcribe.""" if self._receive_task: - self._receive_task.cancel() - try: - await self._receive_task - except asyncio.CancelledError: - pass + await self.cancel_task(self._receive_task) self._receive_task = None - if self._ws_client: - try: - if self._ws_client.open: - # Send end-stream message - end_stream = {"message-type": "event", "event": "end"} - await self._ws_client.send(json.dumps(end_stream)) - await self._ws_client.close() - except Exception as e: - logger.warning(f"Error closing WebSocket connection: {e}") - finally: - self._ws_client = None + try: + if self._ws_client and self._ws_client.open: + # Send end-stream message + end_stream = {"message-type": "event", "event": "end"} + await self._ws_client.send(json.dumps(end_stream)) + await self._ws_client.close() + except Exception as e: + logger.warning(f"{self} Error closing WebSocket connection: {e}") + finally: + self._ws_client = None def language_to_service_language(self, language: Language) -> str | None: """Convert internal language enum to AWS Transcribe language code.""" @@ -529,72 +270,60 @@ class TranscribeSTTService(STTService): async def _receive_loop(self): """Background task to receive and process messages from AWS Transcribe.""" - try: - logger.debug("Receive loop started") - while True: - if not self._ws_client or not self._ws_client.open: - logger.warning("WebSocket closed in receive loop") - break + while True: + if not self._ws_client or not self._ws_client.open: + logger.warning(f"{self} WebSocket closed in receive loop") + break - try: - response = await self._ws_client.recv() - headers, payload = decode_event(response) + try: + response = await self._ws_client.recv() + headers, payload = decode_event(response) - # logger.debug(f"Received message type: {headers.get(':message-type')}") + if headers.get(":message-type") == "event": + # Process transcription results + results = payload.get("Transcript", {}).get("Results", []) + if results: + result = results[0] + alternatives = result.get("Alternatives", []) + if alternatives: + transcript = alternatives[0].get("Transcript", "") + is_final = not result.get("IsPartial", True) - if headers.get(":message-type") == "event": - # Process transcription results - results = payload.get("Transcript", {}).get("Results", []) - if results: - result = results[0] - alternatives = result.get("Alternatives", []) - if alternatives: - transcript = alternatives[0].get("Transcript", "") - is_final = not result.get("IsPartial", True) - - if transcript: - await self.stop_ttfb_metrics() - if is_final: - await self.push_frame( - TranscriptionFrame( - transcript, - "", - time_now_iso8601(), - self._settings["language"], - ) + if transcript: + await self.stop_ttfb_metrics() + if is_final: + await self.push_frame( + TranscriptionFrame( + transcript, + "", + time_now_iso8601(), + self._settings["language"], ) - await self.stop_processing_metrics() - else: - await self.push_frame( - InterimTranscriptionFrame( - transcript, - "", - time_now_iso8601(), - self._settings["language"], - ) + ) + await self.stop_processing_metrics() + else: + await self.push_frame( + InterimTranscriptionFrame( + transcript, + "", + time_now_iso8601(), + self._settings["language"], ) - elif headers.get(":message-type") == "exception": - error_msg = payload.get("Message", "Unknown error") - logger.error(f"Exception from AWS: {error_msg}") - await self.push_frame( - ErrorFrame(f"AWS Transcribe error: {error_msg}", fatal=False) - ) - else: - logger.debug(f"Other message type received: {headers}") - logger.debug(f"Payload: {payload}") - - except websockets.exceptions.ConnectionClosed as e: - logger.error( - f"WebSocket connection closed in receive loop with code {e.code}: {e.reason}" + ) + elif headers.get(":message-type") == "exception": + error_msg = payload.get("Message", "Unknown error") + logger.error(f"{self} Exception from AWS: {error_msg}") + await self.push_frame( + ErrorFrame(f"AWS Transcribe error: {error_msg}", fatal=False) ) - break - except Exception as e: - logger.error(f"Error in receive loop: {e}") - break - - except asyncio.CancelledError: - logger.debug("Receive loop cancelled") - except Exception as e: - logger.error(f"Unexpected error in receive loop: {e}") - finally: - logger.debug("Receive loop ended") + else: + logger.debug(f"{self} Other message type received: {headers}") + logger.debug(f"{self} Payload: {payload}") + except websockets.exceptions.ConnectionClosed as e: + logger.error( + f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}" + ) + break + except Exception as e: + logger.error(f"{self} Unexpected error in receive loop: {e}") + break diff --git a/src/pipecat/services/aws/tts.py b/src/pipecat/services/aws/tts.py index d61f74ab2..0fdbb8273 100644 --- a/src/pipecat/services/aws/tts.py +++ b/src/pipecat/services/aws/tts.py @@ -5,8 +5,8 @@ # import asyncio -from typing import AsyncGenerator, Optional import os +from typing import AsyncGenerator, Optional from loguru import logger from pydantic import BaseModel @@ -19,7 +19,7 @@ from pipecat.frames.frames import ( TTSStartedFrame, TTSStoppedFrame, ) -from pipecat.services.ai_services import TTSService +from pipecat.services.tts_service import TTSService from pipecat.transcriptions.language import Language try: @@ -27,9 +27,7 @@ try: from botocore.exceptions import BotoCoreError, ClientError except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use AWS services, you need to `pip install pipecat-ai[aws]`. Also, remember to set `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variable." - ) + logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.") raise Exception(f"Missing module: {e}") @@ -206,7 +204,7 @@ class PollyTTSService(TTSService): ssml += "" - logger.debug(f"SSML: {ssml}") + logger.trace(f"{self} SSML: {ssml}") return ssml