Merge pull request #4083 from pipecat-ai/filipi/deepgram_sagemaker_tts_improvements

Improvements to DeepgramSageMakerTTSService
This commit is contained in:
Filipi da Silva Fuchter
2026-03-20 10:30:48 -04:00
committed by GitHub
2 changed files with 10 additions and 41 deletions

View File

@@ -0,0 +1 @@
- `DeepgramSageMakerTTSService` now correctly routes audio through the base `TTSService` audio context queue. Audio frames are delivered via `append_to_audio_context()` instead of being pushed directly, enabling proper ordering, interruption handling, and start/stop frame lifecycle management. Interruptions now trigger a `Clear` message to Deepgram (flushing its text buffer) at the right time via `on_audio_context_interrupted`.

View File

@@ -20,18 +20,13 @@ from typing import Any, AsyncGenerator, Optional
from loguru import logger
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
LLMFullResponseEndFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
from pipecat.services.settings import TTSSettings
from pipecat.services.tts_service import TTSService
@@ -115,6 +110,7 @@ class DeepgramSageMakerTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
pause_frame_processing=True,
append_trailing_space=True,
@@ -128,8 +124,6 @@ class DeepgramSageMakerTTSService(TTSService):
self._client: Optional[SageMakerBidiClient] = None
self._response_task: Optional[asyncio.Task] = None
self._context_id: Optional[str] = None
self._ttfb_started: bool = False
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -166,20 +160,6 @@ class DeepgramSageMakerTTSService(TTSService):
await super().cancel(frame)
await self._disconnect()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames with special handling for LLM response end.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
await self.flush_audio()
elif isinstance(frame, BotStoppedSpeakingFrame):
self._ttfb_started = False
async def _connect(self):
"""Connect to the SageMaker endpoint and start the BiDi session.
@@ -301,13 +281,14 @@ class DeepgramSageMakerTTSService(TTSService):
except (UnicodeDecodeError, json.JSONDecodeError):
# Not JSON — treat as raw audio bytes
await self.stop_ttfb_metrics()
context_id = self.get_active_audio_context_id()
frame = TTSAudioRawFrame(
payload,
self.sample_rate,
1,
context_id=self._context_id,
context_id=context_id,
)
await self.push_frame(frame)
await self.append_to_audio_context(context_id, frame)
except asyncio.CancelledError:
logger.debug("TTS response processor cancelled")
@@ -316,15 +297,13 @@ class DeepgramSageMakerTTSService(TTSService):
finally:
logger.debug("TTS response processor stopped")
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
"""Handle interruption by sending Clear message to Deepgram.
async def on_audio_context_interrupted(self, context_id: str):
"""Called when an audio context is cancelled due to an interruption.
The Clear message will clear Deepgram's internal text buffer and stop
sending audio, allowing for a new response to be generated.
Args:
context_id: The ID of the audio context that was interrupted, or
``None`` if no context was active at the time.
"""
await super()._handle_interruption(frame, direction)
self._ttfb_started = False
if self._client and self._client.is_active:
try:
await self._client.send_json({"type": "Clear"})
@@ -356,19 +335,8 @@ class DeepgramSageMakerTTSService(TTSService):
the response processor).
"""
logger.debug(f"{self}: Generating TTS [{text}]")
try:
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
if not self._ttfb_started:
await self.start_ttfb_metrics()
self._ttfb_started = True
yield TTSStartedFrame(context_id=context_id)
self._context_id = context_id
await self._client.send_json({"type": "Speak", "text": text})
yield None
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")