Recording audio in the nvidia sagemaker example.

This commit is contained in:
filipi87
2026-05-13 07:55:52 -03:00
parent 9148e307cc
commit 4f034d4d4e

View File

@@ -6,8 +6,13 @@
# For a full example of how to deploy to SageMaker, see:
# https://github.com/pipecat-ai/pipecat-examples/tree/main/nvidia_sagemaker_example/deployment/aws-sagemaker-nvidia
import os
import datetime
import io
import os
import wave
import aiofiles
from dotenv import load_dotenv
from loguru import logger
@@ -21,6 +26,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
LLMUserAggregatorParams,
)
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.nvidia.llm import NvidiaLLMService
@@ -32,6 +38,21 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
async def save_audio_file(audio: bytes, filename: str, sample_rate: int, num_channels: int):
"""Save audio data to a WAV file."""
if len(audio) > 0:
with io.BytesIO() as buffer:
with wave.open(buffer, "wb") as wf:
wf.setsampwidth(2)
wf.setnchannels(num_channels)
wf.setframerate(sample_rate)
wf.writeframes(audio)
async with aiofiles.open(filename, "wb") as file:
await file.write(buffer.getvalue())
logger.info(f"Audio saved to {filename}")
# We use lambdas to defer transport parameter creation until the transport
# type is selected at runtime.
transport_params = {
@@ -70,6 +91,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
endpoint_name=os.environ["SAGEMAKER_MAGPIE_ENDPOINT_NAME"],
region=os.getenv("AWS_REGION", "us-west-2"),
)
audiobuffer = AudioBufferProcessor()
context = LLMContext()
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
@@ -85,6 +107,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
audiobuffer, # Audio buffer for recording
assistant_aggregator, # Assistant spoken responses
]
)
@@ -101,6 +124,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Start recording audio
await audiobuffer.start_recording()
# Kick off the conversation.
context.add_message(
{"role": "developer", "content": "Please introduce yourself to the user."}
@@ -112,6 +137,26 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Client disconnected")
await task.cancel()
# Handler for merged audio
@audiobuffer.event_handler("on_audio_data")
async def on_audio_data(buffer, audio, sample_rate, num_channels):
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"recordings/merged_{timestamp}.wav"
os.makedirs("recordings", exist_ok=True)
await save_audio_file(audio, filename, sample_rate, num_channels)
# Handler for separate tracks
@audiobuffer.event_handler("on_track_audio_data")
async def on_track_audio_data(buffer, user_audio, bot_audio, sample_rate, num_channels):
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
os.makedirs("recordings", exist_ok=True)
user_filename = f"recordings/user_{timestamp}.wav"
await save_audio_file(user_audio, user_filename, sample_rate, 1)
bot_filename = f"recordings/bot_{timestamp}.wav"
await save_audio_file(bot_audio, bot_filename, sample_rate, 1)
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
await runner.run(task)