diff --git a/changelog/4083.changed.md b/changelog/4083.changed.md new file mode 100644 index 000000000..d9d46957a --- /dev/null +++ b/changelog/4083.changed.md @@ -0,0 +1 @@ +- `DeepgramSageMakerTTSService` now correctly routes audio through the base `TTSService` audio context queue. Audio frames are delivered via `append_to_audio_context()` instead of being pushed directly, enabling proper ordering, interruption handling, and start/stop frame lifecycle management. Interruptions now trigger a `Clear` message to Deepgram (flushing its text buffer) at the right time via `on_audio_context_interrupted`. diff --git a/src/pipecat/services/deepgram/sagemaker/tts.py b/src/pipecat/services/deepgram/sagemaker/tts.py index 70be4a00e..36541b4be 100644 --- a/src/pipecat/services/deepgram/sagemaker/tts.py +++ b/src/pipecat/services/deepgram/sagemaker/tts.py @@ -20,18 +20,13 @@ from typing import Any, AsyncGenerator, Optional from loguru import logger from pipecat.frames.frames import ( - BotStoppedSpeakingFrame, CancelFrame, EndFrame, ErrorFrame, Frame, - InterruptionFrame, - LLMFullResponseEndFrame, StartFrame, TTSAudioRawFrame, - TTSStartedFrame, ) -from pipecat.processors.frame_processor import FrameDirection from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient from pipecat.services.settings import TTSSettings from pipecat.services.tts_service import TTSService @@ -115,6 +110,7 @@ class DeepgramSageMakerTTSService(TTSService): super().__init__( sample_rate=sample_rate, + push_start_frame=True, push_stop_frames=True, pause_frame_processing=True, append_trailing_space=True, @@ -128,8 +124,6 @@ class DeepgramSageMakerTTSService(TTSService): self._client: Optional[SageMakerBidiClient] = None self._response_task: Optional[asyncio.Task] = None - self._context_id: Optional[str] = None - self._ttfb_started: bool = False def can_generate_metrics(self) -> bool: """Check if this service can generate processing metrics. @@ -166,20 +160,6 @@ class DeepgramSageMakerTTSService(TTSService): await super().cancel(frame) await self._disconnect() - async def process_frame(self, frame: Frame, direction: FrameDirection): - """Process frames with special handling for LLM response end. - - Args: - frame: The frame to process. - direction: The direction of frame processing. - """ - await super().process_frame(frame, direction) - - if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)): - await self.flush_audio() - elif isinstance(frame, BotStoppedSpeakingFrame): - self._ttfb_started = False - async def _connect(self): """Connect to the SageMaker endpoint and start the BiDi session. @@ -301,13 +281,14 @@ class DeepgramSageMakerTTSService(TTSService): except (UnicodeDecodeError, json.JSONDecodeError): # Not JSON — treat as raw audio bytes await self.stop_ttfb_metrics() + context_id = self.get_active_audio_context_id() frame = TTSAudioRawFrame( payload, self.sample_rate, 1, - context_id=self._context_id, + context_id=context_id, ) - await self.push_frame(frame) + await self.append_to_audio_context(context_id, frame) except asyncio.CancelledError: logger.debug("TTS response processor cancelled") @@ -316,15 +297,13 @@ class DeepgramSageMakerTTSService(TTSService): finally: logger.debug("TTS response processor stopped") - async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection): - """Handle interruption by sending Clear message to Deepgram. + async def on_audio_context_interrupted(self, context_id: str): + """Called when an audio context is cancelled due to an interruption. - The Clear message will clear Deepgram's internal text buffer and stop - sending audio, allowing for a new response to be generated. + Args: + context_id: The ID of the audio context that was interrupted, or + ``None`` if no context was active at the time. """ - await super()._handle_interruption(frame, direction) - self._ttfb_started = False - if self._client and self._client.is_active: try: await self._client.send_json({"type": "Clear"}) @@ -356,19 +335,8 @@ class DeepgramSageMakerTTSService(TTSService): the response processor). """ logger.debug(f"{self}: Generating TTS [{text}]") - try: - if not self.audio_context_available(context_id): - await self.create_audio_context(context_id) - if not self._ttfb_started: - await self.start_ttfb_metrics() - self._ttfb_started = True - yield TTSStartedFrame(context_id=context_id) - self._context_id = context_id - await self._client.send_json({"type": "Speak", "text": text}) - yield None - except Exception as e: yield ErrorFrame(error=f"Unknown error occurred: {e}")