Merge pull request #4477 from pipecat-ai/filipi/nvidia_sagemaker_follow_up
NVidia TTS Sagemaker: Buffering audio to avoid glitches.
This commit is contained in:
@@ -280,6 +280,8 @@ class NvidiaSageMakerTTSService(InterruptibleTTSService):
|
||||
self._client: SageMakerBidiClient | None = None
|
||||
self._receive_task = None
|
||||
self._speech_completed_event = asyncio.Event()
|
||||
self._audio_buffer = b""
|
||||
self._playback_started = False
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -377,7 +379,12 @@ class NvidiaSageMakerTTSService(InterruptibleTTSService):
|
||||
logger.info(f"{self}: verifying if websocket connection is active {active}")
|
||||
return active
|
||||
|
||||
def _reset_audio_buffer(self):
|
||||
self._audio_buffer = b""
|
||||
self._playback_started = False
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
self._reset_audio_buffer()
|
||||
if self._bot_speaking and self._client:
|
||||
logger.debug(
|
||||
f"{self}: interruption detected, sending input_text.done and waiting for speech.completed"
|
||||
@@ -391,6 +398,30 @@ class NvidiaSageMakerTTSService(InterruptibleTTSService):
|
||||
logger.warning(f"{self}: timed out waiting for conversation.item.speech.completed")
|
||||
await super()._handle_interruption(frame, direction)
|
||||
|
||||
async def _handle_audio_chunk(self, audio: bytes, context_id: str | None = None):
|
||||
"""Buffer audio and emit frames using a jitter-buffer approach.
|
||||
|
||||
Holds back audio until chunk_size bytes have been accumulated (to avoid
|
||||
glitches at the start of playback), then emits each subsequent chunk
|
||||
immediately as it arrives.
|
||||
"""
|
||||
self._audio_buffer += audio
|
||||
|
||||
if not self._playback_started:
|
||||
if len(self._audio_buffer) < self.chunk_size:
|
||||
return
|
||||
self._playback_started = True
|
||||
|
||||
await self.push_frame(
|
||||
TTSAudioRawFrame(
|
||||
audio=self._audio_buffer,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=context_id,
|
||||
)
|
||||
)
|
||||
self._audio_buffer = b""
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Receive NIM JSON events and push audio frames."""
|
||||
while self._client and self._client.is_active and not self._disconnecting:
|
||||
@@ -415,14 +446,7 @@ class NvidiaSageMakerTTSService(InterruptibleTTSService):
|
||||
msg = json.loads(payload.decode("utf-8"))
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
# Unexpected binary frame — treat as raw PCM
|
||||
await self.push_frame(
|
||||
TTSAudioRawFrame(
|
||||
audio=payload,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=context_id,
|
||||
)
|
||||
)
|
||||
await self._handle_audio_chunk(payload, context_id)
|
||||
continue
|
||||
|
||||
event_type = msg.get("type", "")
|
||||
@@ -434,14 +458,7 @@ class NvidiaSageMakerTTSService(InterruptibleTTSService):
|
||||
chunk_b64 = msg.get("audio", "")
|
||||
if chunk_b64:
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.push_frame(
|
||||
TTSAudioRawFrame(
|
||||
audio=base64.b64decode(chunk_b64),
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=context_id,
|
||||
)
|
||||
)
|
||||
await self._handle_audio_chunk(base64.b64decode(chunk_b64), context_id)
|
||||
elif event_type == "error":
|
||||
await self.push_error(error_msg=f"NIM error: {msg.get('message', msg)}")
|
||||
# In case of error we need to reconnect, otherwise we are not going to receive audio from the TTS service anymore
|
||||
|
||||
Reference in New Issue
Block a user