Compare commits
3 Commits
v1.2.1
...
filipi/sma
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3954f96a6b | ||
|
|
6887ac394a | ||
|
|
4f034d4d4e |
@@ -6,8 +6,13 @@
|
||||
# For a full example of how to deploy to SageMaker, see:
|
||||
# https://github.com/pipecat-ai/pipecat-examples/tree/main/nvidia_sagemaker_example/deployment/aws-sagemaker-nvidia
|
||||
|
||||
import os
|
||||
|
||||
import datetime
|
||||
import io
|
||||
import os
|
||||
import wave
|
||||
|
||||
import aiofiles
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
@@ -21,6 +26,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.nvidia.llm import NvidiaLLMService
|
||||
@@ -32,6 +38,21 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def save_audio_file(audio: bytes, filename: str, sample_rate: int, num_channels: int):
|
||||
"""Save audio data to a WAV file."""
|
||||
if len(audio) > 0:
|
||||
with io.BytesIO() as buffer:
|
||||
with wave.open(buffer, "wb") as wf:
|
||||
wf.setsampwidth(2)
|
||||
wf.setnchannels(num_channels)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(audio)
|
||||
async with aiofiles.open(filename, "wb") as file:
|
||||
await file.write(buffer.getvalue())
|
||||
logger.info(f"Audio saved to {filename}")
|
||||
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
@@ -70,6 +91,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
endpoint_name=os.environ["SAGEMAKER_MAGPIE_ENDPOINT_NAME"],
|
||||
region=os.getenv("AWS_REGION", "us-west-2"),
|
||||
)
|
||||
audiobuffer = AudioBufferProcessor()
|
||||
|
||||
context = LLMContext()
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
@@ -85,6 +107,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
audiobuffer, # Audio buffer for recording
|
||||
assistant_aggregator, # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
@@ -101,6 +124,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Start recording audio
|
||||
await audiobuffer.start_recording()
|
||||
# Kick off the conversation.
|
||||
context.add_message(
|
||||
{"role": "developer", "content": "Please introduce yourself to the user."}
|
||||
@@ -112,6 +137,26 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
# Handler for merged audio
|
||||
@audiobuffer.event_handler("on_audio_data")
|
||||
async def on_audio_data(buffer, audio, sample_rate, num_channels):
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"recordings/merged_{timestamp}.wav"
|
||||
os.makedirs("recordings", exist_ok=True)
|
||||
await save_audio_file(audio, filename, sample_rate, num_channels)
|
||||
|
||||
# Handler for separate tracks
|
||||
@audiobuffer.event_handler("on_track_audio_data")
|
||||
async def on_track_audio_data(buffer, user_audio, bot_audio, sample_rate, num_channels):
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
os.makedirs("recordings", exist_ok=True)
|
||||
|
||||
user_filename = f"recordings/user_{timestamp}.wav"
|
||||
await save_audio_file(user_audio, user_filename, sample_rate, 1)
|
||||
|
||||
bot_filename = f"recordings/bot_{timestamp}.wav"
|
||||
await save_audio_file(bot_audio, bot_filename, sample_rate, 1)
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
@@ -280,6 +280,8 @@ class NvidiaSageMakerTTSService(InterruptibleTTSService):
|
||||
self._client: SageMakerBidiClient | None = None
|
||||
self._receive_task = None
|
||||
self._speech_completed_event = asyncio.Event()
|
||||
self._audio_buffer = b""
|
||||
self._playback_started = False
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -377,7 +379,12 @@ class NvidiaSageMakerTTSService(InterruptibleTTSService):
|
||||
logger.info(f"{self}: verifying if websocket connection is active {active}")
|
||||
return active
|
||||
|
||||
def _reset_audio_buffer(self):
|
||||
self._audio_buffer = b""
|
||||
self._playback_started = False
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
self._reset_audio_buffer()
|
||||
if self._bot_speaking and self._client:
|
||||
logger.debug(
|
||||
f"{self}: interruption detected, sending input_text.done and waiting for speech.completed"
|
||||
@@ -391,6 +398,30 @@ class NvidiaSageMakerTTSService(InterruptibleTTSService):
|
||||
logger.warning(f"{self}: timed out waiting for conversation.item.speech.completed")
|
||||
await super()._handle_interruption(frame, direction)
|
||||
|
||||
async def _handle_audio_chunk(self, audio: bytes, context_id: str | None = None):
|
||||
"""Buffer audio and emit frames using a jitter-buffer approach.
|
||||
|
||||
Holds back audio until chunk_size bytes have been accumulated (to avoid
|
||||
glitches at the start of playback), then emits each subsequent chunk
|
||||
immediately as it arrives.
|
||||
"""
|
||||
self._audio_buffer += audio
|
||||
|
||||
if not self._playback_started:
|
||||
if len(self._audio_buffer) < self.chunk_size:
|
||||
return
|
||||
self._playback_started = True
|
||||
|
||||
await self.push_frame(
|
||||
TTSAudioRawFrame(
|
||||
audio=self._audio_buffer,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=context_id,
|
||||
)
|
||||
)
|
||||
self._audio_buffer = b""
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Receive NIM JSON events and push audio frames."""
|
||||
while self._client and self._client.is_active and not self._disconnecting:
|
||||
@@ -415,14 +446,7 @@ class NvidiaSageMakerTTSService(InterruptibleTTSService):
|
||||
msg = json.loads(payload.decode("utf-8"))
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
# Unexpected binary frame — treat as raw PCM
|
||||
await self.push_frame(
|
||||
TTSAudioRawFrame(
|
||||
audio=payload,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=context_id,
|
||||
)
|
||||
)
|
||||
await self._handle_audio_chunk(payload, context_id)
|
||||
continue
|
||||
|
||||
event_type = msg.get("type", "")
|
||||
@@ -434,14 +458,7 @@ class NvidiaSageMakerTTSService(InterruptibleTTSService):
|
||||
chunk_b64 = msg.get("audio", "")
|
||||
if chunk_b64:
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.push_frame(
|
||||
TTSAudioRawFrame(
|
||||
audio=base64.b64decode(chunk_b64),
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=context_id,
|
||||
)
|
||||
)
|
||||
await self._handle_audio_chunk(base64.b64decode(chunk_b64), context_id)
|
||||
elif event_type == "error":
|
||||
await self.push_error(error_msg=f"NIM error: {msg.get('message', msg)}")
|
||||
# In case of error we need to reconnect, otherwise we are not going to receive audio from the TTS service anymore
|
||||
|
||||
@@ -35,6 +35,8 @@ from pipecat.frames.frames import (
|
||||
OutputTransportMessageUrgentFrame,
|
||||
SpriteFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStoppedFrame,
|
||||
UserImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
@@ -97,6 +99,7 @@ class RawAudioTrack(AudioStreamTrack):
|
||||
self._start = time.time()
|
||||
# Queue of (bytes, future), broken into 10ms sub chunks as needed
|
||||
self._chunk_queue = deque()
|
||||
self._is_bot_speaking = False
|
||||
|
||||
def add_audio_bytes(self, audio_bytes: bytes):
|
||||
"""Add audio bytes to the buffer for transmission.
|
||||
@@ -123,6 +126,14 @@ class RawAudioTrack(AudioStreamTrack):
|
||||
|
||||
return future
|
||||
|
||||
def set_is_bot_speaking(self, value: bool):
|
||||
"""Set whether the bot is currently speaking.
|
||||
|
||||
Args:
|
||||
value: True if the bot has started speaking, False when it has stopped.
|
||||
"""
|
||||
self._is_bot_speaking = value
|
||||
|
||||
async def recv(self):
|
||||
"""Return the next audio frame for WebRTC transmission.
|
||||
|
||||
@@ -137,7 +148,12 @@ class RawAudioTrack(AudioStreamTrack):
|
||||
await asyncio.sleep(wait)
|
||||
|
||||
if not self._chunk_queue:
|
||||
if self._auto_silence:
|
||||
# Injecting silence while the bot is speaking can cause audible glitches:
|
||||
# TTS audio arrives in bursts, and a silence frame inserted between two
|
||||
# consecutive TTS chunks will produce a brief gap or pop in the output.
|
||||
if self._auto_silence and not self._is_bot_speaking:
|
||||
#if self._is_bot_speaking:
|
||||
# logger.warning("Injecting silence while bot is speaking can cause glitches in the audio.")
|
||||
chunk = bytes(self._bytes_per_10ms)
|
||||
else:
|
||||
while not self._chunk_queue:
|
||||
@@ -426,6 +442,15 @@ class SmallWebRTCClient:
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_is_bot_speaking(self, value: bool):
|
||||
"""Propagate bot speaking state to the audio output track.
|
||||
|
||||
Args:
|
||||
value: True if the bot has started speaking, False when it has stopped.
|
||||
"""
|
||||
if self._audio_output_track:
|
||||
self._audio_output_track.set_is_bot_speaking(value)
|
||||
|
||||
async def write_video_frame(self, frame: OutputImageRawFrame) -> bool:
|
||||
"""Write a video frame to the WebRTC connection.
|
||||
|
||||
@@ -861,6 +886,13 @@ class SmallWebRTCOutputTransport(BaseOutputTransport):
|
||||
Returns:
|
||||
True if the audio frame was written successfully, False otherwise.
|
||||
"""
|
||||
# Track when the bot is speaking so the audio track can avoid injecting
|
||||
# silence between TTS chunks, which would cause audible glitches.
|
||||
# Using the TTSAudioRawFrame as reference since we can receive
|
||||
# TTSStartedFrame a few hundred milliseconds before actually start
|
||||
# receiving the audio
|
||||
if isinstance(frame, TTSAudioRawFrame):
|
||||
self._client.set_is_bot_speaking(True)
|
||||
return await self._client.write_audio_frame(frame)
|
||||
|
||||
async def write_video_frame(self, frame: OutputImageRawFrame) -> bool:
|
||||
@@ -874,6 +906,20 @@ class SmallWebRTCOutputTransport(BaseOutputTransport):
|
||||
"""
|
||||
return await self._client.write_video_frame(frame)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process incoming frames and handle transport-specific logic.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Track when the bot is speaking so the audio track can avoid injecting
|
||||
# silence between TTS chunks, which would cause audible glitches.
|
||||
if isinstance(frame, TTSStoppedFrame):
|
||||
self._client.set_is_bot_speaking(False)
|
||||
|
||||
|
||||
class SmallWebRTCTransport(BaseTransport):
|
||||
"""WebRTC transport implementation for real-time communication.
|
||||
|
||||
Reference in New Issue
Block a user