Merge pull request #4083 from pipecat-ai/filipi/deepgram_sagemaker_tts_improvements
Improvements to DeepgramSageMakerTTSService
This commit is contained in:
1
changelog/4083.changed.md
Normal file
1
changelog/4083.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- `DeepgramSageMakerTTSService` now correctly routes audio through the base `TTSService` audio context queue. Audio frames are delivered via `append_to_audio_context()` instead of being pushed directly, enabling proper ordering, interruption handling, and start/stop frame lifecycle management. Interruptions now trigger a `Clear` message to Deepgram (flushing its text buffer) at the right time via `on_audio_context_interrupted`.
|
||||
@@ -20,18 +20,13 @@ from typing import Any, AsyncGenerator, Optional
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
|
||||
from pipecat.services.settings import TTSSettings
|
||||
from pipecat.services.tts_service import TTSService
|
||||
@@ -115,6 +110,7 @@ class DeepgramSageMakerTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
append_trailing_space=True,
|
||||
@@ -128,8 +124,6 @@ class DeepgramSageMakerTTSService(TTSService):
|
||||
|
||||
self._client: Optional[SageMakerBidiClient] = None
|
||||
self._response_task: Optional[asyncio.Task] = None
|
||||
self._context_id: Optional[str] = None
|
||||
self._ttfb_started: bool = False
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -166,20 +160,6 @@ class DeepgramSageMakerTTSService(TTSService):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with special handling for LLM response end.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
await self.flush_audio()
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._ttfb_started = False
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to the SageMaker endpoint and start the BiDi session.
|
||||
|
||||
@@ -301,13 +281,14 @@ class DeepgramSageMakerTTSService(TTSService):
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
# Not JSON — treat as raw audio bytes
|
||||
await self.stop_ttfb_metrics()
|
||||
context_id = self.get_active_audio_context_id()
|
||||
frame = TTSAudioRawFrame(
|
||||
payload,
|
||||
self.sample_rate,
|
||||
1,
|
||||
context_id=self._context_id,
|
||||
context_id=context_id,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
await self.append_to_audio_context(context_id, frame)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("TTS response processor cancelled")
|
||||
@@ -316,15 +297,13 @@ class DeepgramSageMakerTTSService(TTSService):
|
||||
finally:
|
||||
logger.debug("TTS response processor stopped")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by sending Clear message to Deepgram.
|
||||
async def on_audio_context_interrupted(self, context_id: str):
|
||||
"""Called when an audio context is cancelled due to an interruption.
|
||||
|
||||
The Clear message will clear Deepgram's internal text buffer and stop
|
||||
sending audio, allowing for a new response to be generated.
|
||||
Args:
|
||||
context_id: The ID of the audio context that was interrupted, or
|
||||
``None`` if no context was active at the time.
|
||||
"""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
self._ttfb_started = False
|
||||
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Clear"})
|
||||
@@ -356,19 +335,8 @@ class DeepgramSageMakerTTSService(TTSService):
|
||||
the response processor).
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
if not self.audio_context_available(context_id):
|
||||
await self.create_audio_context(context_id)
|
||||
if not self._ttfb_started:
|
||||
await self.start_ttfb_metrics()
|
||||
self._ttfb_started = True
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._context_id = context_id
|
||||
|
||||
await self._client.send_json({"type": "Speak", "text": text})
|
||||
|
||||
yield None
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
|
||||
Reference in New Issue
Block a user