Merge pull request #3902 from pipecat-ai/aleix/deepgram-sagemaker-move

Move Deepgram SageMaker modules to sagemaker/ subpackage
This commit is contained in:
Aleix Conchillo Flaqué
2026-03-02 15:25:17 -08:00
committed by GitHub
9 changed files with 831 additions and 793 deletions

View File

@@ -0,0 +1 @@
- Moved `pipecat.services.deepgram.stt_sagemaker` and `pipecat.services.deepgram.tts_sagemaker` to `pipecat.services.deepgram.sagemaker.stt` and `pipecat.services.deepgram.sagemaker.tts`. The old import paths still work but emit a `DeprecationWarning`.

View File

@@ -23,8 +23,8 @@ from pipecat.processors.aggregators.llm_response_universal import (
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.aws.llm import AWSBedrockLLMService
from pipecat.services.deepgram.stt_sagemaker import DeepgramSageMakerSTTService
from pipecat.services.deepgram.tts_sagemaker import DeepgramSageMakerTTSService
from pipecat.services.deepgram.sagemaker.stt import DeepgramSageMakerSTTService
from pipecat.services.deepgram.sagemaker.tts import DeepgramSageMakerTTSService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams

View File

@@ -24,7 +24,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt_sagemaker import (
from pipecat.services.deepgram.sagemaker.stt import (
DeepgramSageMakerSTTService,
DeepgramSageMakerSTTSettings,
)

View File

@@ -22,11 +22,11 @@ from pipecat.processors.aggregators.llm_response_universal import (
)
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.deepgram.tts_sagemaker import (
from pipecat.services.deepgram.sagemaker.tts import (
DeepgramSageMakerTTSService,
DeepgramSageMakerTTSSettings,
)
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams

View File

@@ -9,6 +9,7 @@ import sys
from pipecat.services import DeprecatedModuleProxy
from .flux import *
from .sagemaker import *
from .stt import *
from .tts import *

View File

@@ -0,0 +1,448 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Deepgram speech-to-text service for AWS SageMaker.
This module provides a Pipecat STT service that connects to Deepgram models
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
low-latency real-time transcription with support for interim results, multiple
languages, and various Deepgram features.
"""
import asyncio
import json
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Optional
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
VADUserStartedSpeakingFrame,
VADUserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
from pipecat.services.deepgram.stt import _DeepgramSTTSettingsBase
from pipecat.services.settings import STTSettings
from pipecat.services.stt_latency import DEEPGRAM_SAGEMAKER_TTFS_P99
from pipecat.services.stt_service import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
try:
from deepgram import LiveOptions
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use DeepgramSageMakerSTTService, you need to `pip install pipecat-ai[deepgram,sagemaker]`."
)
raise Exception(f"Missing module: {e}")
@dataclass
class DeepgramSageMakerSTTSettings(_DeepgramSTTSettingsBase):
"""Settings for the Deepgram SageMaker STT service.
See ``_DeepgramSTTSettingsBase`` for full documentation.
"""
pass
class DeepgramSageMakerSTTService(STTService):
"""Deepgram speech-to-text service for AWS SageMaker.
Provides real-time speech recognition using Deepgram models deployed on
AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
transcription with support for interim results, speaker diarization, and
multiple languages.
Requirements:
- AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
- A deployed SageMaker endpoint with Deepgram model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker
- Deepgram SDK for LiveOptions configuration
Example::
stt = DeepgramSageMakerSTTService(
endpoint_name="my-deepgram-endpoint",
region="us-east-2",
live_options=LiveOptions(
model="nova-3",
language="en",
interim_results=True,
punctuate=True,
),
)
"""
_settings: DeepgramSageMakerSTTSettings
def __init__(
self,
*,
endpoint_name: str,
region: str,
sample_rate: Optional[int] = None,
live_options: Optional[LiveOptions] = None,
ttfs_p99_latency: Optional[float] = DEEPGRAM_SAGEMAKER_TTFS_P99,
**kwargs,
):
"""Initialize the Deepgram SageMaker STT service.
Args:
endpoint_name: Name of the SageMaker endpoint with Deepgram model
deployed (e.g., "my-deepgram-nova-3-endpoint").
region: AWS region where the endpoint is deployed (e.g., "us-east-2").
sample_rate: Audio sample rate in Hz. If None, uses value from
live_options or defaults to the value from StartFrame.
live_options: Deepgram LiveOptions configuration. Treated as a
delta from a set of sensible defaults — only the fields you
set are overridden; all others keep their default values.
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
**kwargs: Additional arguments passed to the parent STTService.
"""
sample_rate = sample_rate or (live_options.sample_rate if live_options else None)
default_options = LiveOptions(
encoding="linear16",
language=Language.EN,
model="nova-3",
channels=1,
interim_results=True,
punctuate=True,
)
settings = DeepgramSageMakerSTTSettings(
model=default_options.model,
language=default_options.language,
live_options=default_options,
)
if live_options:
settings._merge_live_options_delta(live_options)
super().__init__(
sample_rate=sample_rate,
ttfs_p99_latency=ttfs_p99_latency,
settings=settings,
**kwargs,
)
self._endpoint_name = endpoint_name
self._region = region
self._client: Optional[SageMakerBidiClient] = None
self._response_task: Optional[asyncio.Task] = None
self._keepalive_task: Optional[asyncio.Task] = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
True, as Deepgram SageMaker service supports metrics generation.
"""
return True
async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
"""Apply a settings delta and warn about unhandled changes."""
changed = await super()._update_settings(delta)
if not changed:
return changed
# TODO: someday we could reconnect here to apply updated settings.
# Code might look something like the below:
# await self._disconnect()
# await self._connect()
self._warn_unhandled_updated_settings(changed)
return changed
async def start(self, frame: StartFrame):
"""Start the Deepgram SageMaker STT service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the Deepgram SageMaker STT service.
Args:
frame: The end frame.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the Deepgram SageMaker STT service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Send audio data to Deepgram for transcription.
Args:
audio: Raw audio bytes to transcribe.
Yields:
Frame: None (transcription results come via BiDi stream callbacks).
"""
if self._client and self._client.is_active:
try:
await self._client.send_audio_chunk(audio)
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield None
async def _connect(self):
"""Connect to the SageMaker endpoint and start the BiDi session.
Builds the Deepgram query string from settings, creates the BiDi client,
starts the streaming session, and launches background tasks for processing
responses and sending KeepAlive messages.
"""
logger.debug("Connecting to Deepgram on SageMaker...")
live_options = LiveOptions(
**{**self._settings.live_options.to_dict(), "sample_rate": self.sample_rate}
)
# Build query string from live_options, converting booleans to strings
query_params = {}
for key, value in live_options.to_dict().items():
if value is not None:
# Convert boolean values to lowercase strings for Deepgram API
if isinstance(value, bool):
query_params[key] = str(value).lower()
else:
query_params[key] = str(value)
query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
# Create BiDi client
self._client = SageMakerBidiClient(
endpoint_name=self._endpoint_name,
region=self._region,
model_invocation_path="v1/listen",
model_query_string=query_string,
)
try:
# Start the session
await self._client.start_session()
# Start processing responses in the background
self._response_task = self.create_task(self._process_responses())
# Start keepalive task to maintain connection
self._keepalive_task = self.create_task(self._send_keepalive())
logger.debug("Connected to Deepgram on SageMaker")
await self._call_event_handler("on_connected")
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
await self._call_event_handler("on_connection_error", str(e))
async def _disconnect(self):
"""Disconnect from the SageMaker endpoint.
Sends a CloseStream message to Deepgram, cancels background tasks
(KeepAlive and response processing), and closes the BiDi session.
Safe to call multiple times.
"""
if self._client and self._client.is_active:
logger.debug("Disconnecting from Deepgram on SageMaker...")
# Send CloseStream message to Deepgram
try:
await self._client.send_json({"type": "CloseStream"})
except Exception as e:
logger.warning(f"Failed to send CloseStream message: {e}")
# Cancel keepalive task
if self._keepalive_task and not self._keepalive_task.done():
await self.cancel_task(self._keepalive_task)
# Cancel response processing task
if self._response_task and not self._response_task.done():
await self.cancel_task(self._response_task)
# Close the BiDi session
await self._client.close_session()
logger.debug("Disconnected from Deepgram on SageMaker")
await self._call_event_handler("on_disconnected")
async def _send_keepalive(self):
"""Send periodic KeepAlive messages to maintain the connection.
Sends a KeepAlive JSON message to Deepgram every 5 seconds while the
connection is active. This prevents the connection from timing out during
periods of silence.
"""
while self._client and self._client.is_active:
await asyncio.sleep(5)
if self._client and self._client.is_active:
try:
await self._client.send_json({"type": "KeepAlive"})
except Exception as e:
logger.warning(f"Failed to send KeepAlive: {e}")
async def _process_responses(self):
"""Process streaming responses from Deepgram on SageMaker.
Continuously receives responses from the BiDi stream, decodes the payload,
parses JSON responses from Deepgram, and processes transcription results.
Runs as a background task until the connection is closed or cancelled.
"""
try:
while self._client and self._client.is_active:
result = await self._client.receive_response()
if result is None:
break
# Check if this is a PayloadPart with bytes
if hasattr(result, "value") and hasattr(result.value, "bytes_"):
if result.value.bytes_:
response_data = result.value.bytes_.decode("utf-8")
try:
# Parse JSON response from Deepgram
parsed = json.loads(response_data)
# Extract and process transcript if available
if "channel" in parsed:
await self._handle_transcript_response(parsed)
except json.JSONDecodeError:
logger.warning(f"Non-JSON response: {response_data}")
except asyncio.CancelledError:
logger.debug("Response processor cancelled")
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
logger.debug("Response processor stopped")
async def _handle_transcript_response(self, parsed: dict):
"""Handle a transcript response from Deepgram.
Extracts the transcript text, determines if it's final or interim, extracts
language information, and pushes the appropriate frame (TranscriptionFrame
or InterimTranscriptionFrame) downstream.
Args:
parsed: The parsed JSON response from Deepgram containing channel,
alternatives, transcript, and metadata.
"""
alternatives = parsed.get("channel", {}).get("alternatives", [])
if not alternatives or not alternatives[0].get("transcript"):
return
transcript = alternatives[0]["transcript"]
if not transcript.strip():
return
is_final = parsed.get("is_final", False)
# Extract language if available
language = None
if alternatives[0].get("languages"):
language = alternatives[0]["languages"][0]
language = Language(language)
if is_final:
# Check if this response is from a finalize() call.
# Only mark as finalized when both we requested it AND Deepgram confirms it.
from_finalize = parsed.get("from_finalize", False)
if from_finalize:
self.confirm_finalize()
await self.push_frame(
TranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=parsed,
)
)
await self._handle_transcription(transcript, is_final, language)
await self.stop_processing_metrics()
else:
# Interim transcription
await self.push_frame(
InterimTranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=parsed,
)
)
@traced_stt
async def _handle_transcription(
self, transcript: str, is_final: bool, language: Optional[Language] = None
):
"""Handle a transcription result with tracing.
This method is decorated with @traced_stt for observability and tracing
integration. The actual transcription processing is handled by the parent
class and observers.
Args:
transcript: The transcribed text.
is_final: Whether this is a final transcription result.
language: The detected language of the transcription, if available.
"""
pass
async def _start_metrics(self):
"""Start processing metrics collection."""
await self.start_processing_metrics()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames with Deepgram SageMaker-specific handling.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
# Start metrics when user starts speaking (if VAD is not provided by Deepgram)
if isinstance(frame, VADUserStartedSpeakingFrame):
await self._start_metrics()
elif isinstance(frame, VADUserStoppedSpeakingFrame):
# https://developers.deepgram.com/docs/finalize
# Mark that we're awaiting a from_finalize response
self.request_finalize()
if self._client and self._client.is_active:
try:
await self._client.send_json({"type": "Finalize"})
except Exception as e:
logger.warning(f"Error sending Finalize message: {e}")
logger.trace(f"Triggered finalize event on: {frame.name=}, {direction=}")

View File

@@ -0,0 +1,360 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Deepgram text-to-speech service for AWS SageMaker.
This module provides a Pipecat TTS service that connects to Deepgram models
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
low-latency real-time speech synthesis with support for interruptions and
streaming audio output.
"""
import asyncio
import json
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Optional
from loguru import logger
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
LLMFullResponseEndFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
from pipecat.services.tts_service import TTSService
from pipecat.utils.tracing.service_decorators import traced_tts
@dataclass
class DeepgramSageMakerTTSSettings(TTSSettings):
"""Settings for Deepgram SageMaker TTS service.
Parameters:
encoding: Audio encoding format (e.g. "linear16").
"""
encoding: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
class DeepgramSageMakerTTSService(TTSService):
"""Deepgram text-to-speech service for AWS SageMaker.
Provides real-time speech synthesis using Deepgram models deployed on
AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
audio generation with support for interruptions via the Clear message.
Requirements:
- AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
- A deployed SageMaker endpoint with Deepgram TTS model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker
- ``pipecat-ai[sagemaker]`` installed
Example::
tts = DeepgramSageMakerTTSService(
endpoint_name="my-deepgram-tts-endpoint",
region="us-east-2",
voice="aura-2-helena-en",
)
"""
_settings: DeepgramSageMakerTTSSettings
def __init__(
self,
*,
endpoint_name: str,
region: str,
voice: str = "aura-2-helena-en",
sample_rate: Optional[int] = None,
encoding: str = "linear16",
**kwargs,
):
"""Initialize the Deepgram SageMaker TTS service.
Args:
endpoint_name: Name of the SageMaker endpoint with Deepgram TTS model
deployed (e.g., "my-deepgram-tts-endpoint").
region: AWS region where the endpoint is deployed (e.g., "us-east-2").
voice: Voice model to use for synthesis. Defaults to "aura-2-helena-en".
sample_rate: Audio sample rate in Hz. If None, uses the value from StartFrame.
encoding: Audio encoding format. Defaults to "linear16".
**kwargs: Additional arguments passed to the parent TTSService.
"""
super().__init__(
sample_rate=sample_rate,
push_stop_frames=True,
pause_frame_processing=True,
append_trailing_space=True,
settings=DeepgramSageMakerTTSSettings(
model=voice,
voice=voice,
language=None,
encoding=encoding,
),
**kwargs,
)
self._endpoint_name = endpoint_name
self._region = region
self._client: Optional[SageMakerBidiClient] = None
self._response_task: Optional[asyncio.Task] = None
self._context_id: Optional[str] = None
self._ttfb_started: bool = False
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
True, as Deepgram SageMaker TTS service supports metrics generation.
"""
return True
async def start(self, frame: StartFrame):
"""Start the Deepgram SageMaker TTS service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the Deepgram SageMaker TTS service.
Args:
frame: The end frame.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the Deepgram SageMaker TTS service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
await self._disconnect()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames with special handling for LLM response end.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
await self.flush_audio()
elif isinstance(frame, BotStoppedSpeakingFrame):
self._ttfb_started = False
async def _connect(self):
"""Connect to the SageMaker endpoint and start the BiDi session.
Builds the Deepgram TTS query string, creates the BiDi client,
starts the streaming session, and launches a background task for processing
responses.
"""
logger.debug("Connecting to Deepgram TTS on SageMaker...")
query_string = (
f"model={self._settings.voice}&encoding={self._settings.encoding}"
f"&sample_rate={self.sample_rate}"
)
self._client = SageMakerBidiClient(
endpoint_name=self._endpoint_name,
region=self._region,
model_invocation_path="v1/speak",
model_query_string=query_string,
)
try:
await self._client.start_session()
self._response_task = self.create_task(self._process_responses())
logger.debug("Connected to Deepgram TTS on SageMaker")
await self._call_event_handler("on_connected")
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
await self._call_event_handler("on_connection_error", str(e))
async def _disconnect(self):
"""Disconnect from the SageMaker endpoint.
Sends a Close message to Deepgram, cancels the response processing task,
and closes the BiDi session. Safe to call multiple times.
"""
if self._client and self._client.is_active:
logger.debug("Disconnecting from Deepgram TTS on SageMaker...")
try:
await self._client.send_json({"type": "Close"})
except Exception as e:
logger.warning(f"Failed to send Close message: {e}")
if self._response_task and not self._response_task.done():
await self.cancel_task(self._response_task)
await self._client.close_session()
logger.debug("Disconnected from Deepgram TTS on SageMaker")
await self._call_event_handler("on_disconnected")
async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]:
"""Apply a settings delta and reconnect if necessary.
Since all settings are part of the SageMaker session query string,
any setting change requires reconnecting to apply the new values.
"""
changed = await super()._update_settings(delta)
if not changed:
return changed
# Deepgram uses voice as the model, so keep them in sync for metrics
if "voice" in changed:
self._settings.model = self._settings.voice
self._sync_model_name_to_metrics()
# TODO: someday we could reconnect here to apply updated settings.
# Code might look something like the below:
# await self._disconnect()
# await self._connect()
self._warn_unhandled_updated_settings(changed)
return changed
async def _process_responses(self):
"""Process streaming responses from Deepgram TTS on SageMaker.
Continuously receives responses from the BiDi stream. Attempts to decode
each payload as UTF-8 JSON for control messages (Flushed, Cleared, Metadata,
Warning). If decoding fails, treats the payload as raw audio bytes and pushes
a TTSAudioRawFrame downstream.
"""
try:
while self._client and self._client.is_active:
result = await self._client.receive_response()
if result is None:
break
if hasattr(result, "value") and hasattr(result.value, "bytes_"):
if result.value.bytes_:
payload = result.value.bytes_
# Try to decode as JSON control message first
try:
response_data = payload.decode("utf-8")
parsed = json.loads(response_data)
msg_type = parsed.get("type")
if msg_type == "Metadata":
logger.trace(f"Received metadata: {parsed}")
elif msg_type == "Flushed":
logger.trace(f"Received Flushed: {parsed}")
elif msg_type == "Cleared":
logger.trace(f"Received Cleared: {parsed}")
elif msg_type == "Warning":
logger.warning(
f"{self} warning: "
f"{parsed.get('description', 'Unknown warning')}"
)
else:
logger.debug(f"Received unknown message type: {parsed}")
except (UnicodeDecodeError, json.JSONDecodeError):
# Not JSON — treat as raw audio bytes
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
payload,
self.sample_rate,
1,
context_id=self._context_id,
)
await self.push_frame(frame)
except asyncio.CancelledError:
logger.debug("TTS response processor cancelled")
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
logger.debug("TTS response processor stopped")
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
"""Handle interruption by sending Clear message to Deepgram.
The Clear message will clear Deepgram's internal text buffer and stop
sending audio, allowing for a new response to be generated.
"""
await super()._handle_interruption(frame, direction)
self._ttfb_started = False
if self._client and self._client.is_active:
try:
await self._client.send_json({"type": "Clear"})
except Exception as e:
logger.error(f"{self} error sending Clear message: {e}")
async def flush_audio(self):
"""Flush any pending audio synthesis by sending Flush command.
This should be called when the LLM finishes a complete response to force
generation of audio from Deepgram's internal text buffer.
"""
if self._client and self._client.is_active:
try:
await self._client.send_json({"type": "Flush"})
except Exception as e:
logger.error(f"{self} error sending Flush message: {e}")
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Deepgram TTS on SageMaker.
Args:
text: The text to synthesize into speech.
context_id: The context ID for tracking audio frames.
Yields:
Frame: TTSStartedFrame, then None (audio comes asynchronously via
the response processor).
"""
logger.debug(f"{self}: Generating TTS [{text}]")
try:
if not self._ttfb_started:
await self.start_ttfb_metrics()
self._ttfb_started = True
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
self._context_id = context_id
await self._client.send_json({"type": "Speak", "text": text})
yield None
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")

View File

@@ -4,445 +4,15 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Deepgram speech-to-text service for AWS SageMaker.
"""Deprecated: use ``pipecat.services.deepgram.sagemaker.stt`` instead."""
This module provides a Pipecat STT service that connects to Deepgram models
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
low-latency real-time transcription with support for interim results, multiple
languages, and various Deepgram features.
"""
import warnings
import asyncio
import json
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Optional
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
VADUserStartedSpeakingFrame,
VADUserStoppedSpeakingFrame,
warnings.warn(
"Module `pipecat.services.deepgram.stt_sagemaker` is deprecated, "
"use `pipecat.services.deepgram.sagemaker.stt` instead.",
DeprecationWarning,
stacklevel=2,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
from pipecat.services.deepgram.stt import _DeepgramSTTSettingsBase
from pipecat.services.settings import STTSettings
from pipecat.services.stt_latency import DEEPGRAM_SAGEMAKER_TTFS_P99
from pipecat.services.stt_service import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
try:
from deepgram import LiveOptions
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use DeepgramSageMakerSTTService, you need to `pip install pipecat-ai[deepgram,sagemaker]`."
)
raise Exception(f"Missing module: {e}")
@dataclass
class DeepgramSageMakerSTTSettings(_DeepgramSTTSettingsBase):
"""Settings for the Deepgram SageMaker STT service.
See ``_DeepgramSTTSettingsBase`` for full documentation.
"""
pass
class DeepgramSageMakerSTTService(STTService):
"""Deepgram speech-to-text service for AWS SageMaker.
Provides real-time speech recognition using Deepgram models deployed on
AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
transcription with support for interim results, speaker diarization, and
multiple languages.
Requirements:
- AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
- A deployed SageMaker endpoint with Deepgram model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker
- Deepgram SDK for LiveOptions configuration
Example::
stt = DeepgramSageMakerSTTService(
endpoint_name="my-deepgram-endpoint",
region="us-east-2",
live_options=LiveOptions(
model="nova-3",
language="en",
interim_results=True,
punctuate=True,
),
)
"""
_settings: DeepgramSageMakerSTTSettings
def __init__(
self,
*,
endpoint_name: str,
region: str,
sample_rate: Optional[int] = None,
live_options: Optional[LiveOptions] = None,
ttfs_p99_latency: Optional[float] = DEEPGRAM_SAGEMAKER_TTFS_P99,
**kwargs,
):
"""Initialize the Deepgram SageMaker STT service.
Args:
endpoint_name: Name of the SageMaker endpoint with Deepgram model
deployed (e.g., "my-deepgram-nova-3-endpoint").
region: AWS region where the endpoint is deployed (e.g., "us-east-2").
sample_rate: Audio sample rate in Hz. If None, uses value from
live_options or defaults to the value from StartFrame.
live_options: Deepgram LiveOptions configuration. Treated as a
delta from a set of sensible defaults — only the fields you
set are overridden; all others keep their default values.
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
**kwargs: Additional arguments passed to the parent STTService.
"""
sample_rate = sample_rate or (live_options.sample_rate if live_options else None)
default_options = LiveOptions(
encoding="linear16",
language=Language.EN,
model="nova-3",
channels=1,
interim_results=True,
punctuate=True,
)
settings = DeepgramSageMakerSTTSettings(
model=default_options.model,
language=default_options.language,
live_options=default_options,
)
if live_options:
settings._merge_live_options_delta(live_options)
super().__init__(
sample_rate=sample_rate,
ttfs_p99_latency=ttfs_p99_latency,
settings=settings,
**kwargs,
)
self._endpoint_name = endpoint_name
self._region = region
self._client: Optional[SageMakerBidiClient] = None
self._response_task: Optional[asyncio.Task] = None
self._keepalive_task: Optional[asyncio.Task] = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
True, as Deepgram SageMaker service supports metrics generation.
"""
return True
async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
"""Apply a settings delta and warn about unhandled changes."""
changed = await super()._update_settings(delta)
if not changed:
return changed
# TODO: someday we could reconnect here to apply updated settings.
# Code might look something like the below:
# await self._disconnect()
# await self._connect()
self._warn_unhandled_updated_settings(changed)
return changed
async def start(self, frame: StartFrame):
"""Start the Deepgram SageMaker STT service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the Deepgram SageMaker STT service.
Args:
frame: The end frame.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the Deepgram SageMaker STT service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Send audio data to Deepgram for transcription.
Args:
audio: Raw audio bytes to transcribe.
Yields:
Frame: None (transcription results come via BiDi stream callbacks).
"""
if self._client and self._client.is_active:
try:
await self._client.send_audio_chunk(audio)
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield None
async def _connect(self):
"""Connect to the SageMaker endpoint and start the BiDi session.
Builds the Deepgram query string from settings, creates the BiDi client,
starts the streaming session, and launches background tasks for processing
responses and sending KeepAlive messages.
"""
logger.debug("Connecting to Deepgram on SageMaker...")
live_options = LiveOptions(
**{**self._settings.live_options.to_dict(), "sample_rate": self.sample_rate}
)
# Build query string from live_options, converting booleans to strings
query_params = {}
for key, value in live_options.to_dict().items():
if value is not None:
# Convert boolean values to lowercase strings for Deepgram API
if isinstance(value, bool):
query_params[key] = str(value).lower()
else:
query_params[key] = str(value)
query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
# Create BiDi client
self._client = SageMakerBidiClient(
endpoint_name=self._endpoint_name,
region=self._region,
model_invocation_path="v1/listen",
model_query_string=query_string,
)
try:
# Start the session
await self._client.start_session()
# Start processing responses in the background
self._response_task = self.create_task(self._process_responses())
# Start keepalive task to maintain connection
self._keepalive_task = self.create_task(self._send_keepalive())
logger.debug("Connected to Deepgram on SageMaker")
await self._call_event_handler("on_connected")
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
await self._call_event_handler("on_connection_error", str(e))
async def _disconnect(self):
"""Disconnect from the SageMaker endpoint.
Sends a CloseStream message to Deepgram, cancels background tasks
(KeepAlive and response processing), and closes the BiDi session.
Safe to call multiple times.
"""
if self._client and self._client.is_active:
logger.debug("Disconnecting from Deepgram on SageMaker...")
# Send CloseStream message to Deepgram
try:
await self._client.send_json({"type": "CloseStream"})
except Exception as e:
logger.warning(f"Failed to send CloseStream message: {e}")
# Cancel keepalive task
if self._keepalive_task and not self._keepalive_task.done():
await self.cancel_task(self._keepalive_task)
# Cancel response processing task
if self._response_task and not self._response_task.done():
await self.cancel_task(self._response_task)
# Close the BiDi session
await self._client.close_session()
logger.debug("Disconnected from Deepgram on SageMaker")
await self._call_event_handler("on_disconnected")
async def _send_keepalive(self):
"""Send periodic KeepAlive messages to maintain the connection.
Sends a KeepAlive JSON message to Deepgram every 5 seconds while the
connection is active. This prevents the connection from timing out during
periods of silence.
"""
while self._client and self._client.is_active:
await asyncio.sleep(5)
if self._client and self._client.is_active:
try:
await self._client.send_json({"type": "KeepAlive"})
except Exception as e:
logger.warning(f"Failed to send KeepAlive: {e}")
async def _process_responses(self):
"""Process streaming responses from Deepgram on SageMaker.
Continuously receives responses from the BiDi stream, decodes the payload,
parses JSON responses from Deepgram, and processes transcription results.
Runs as a background task until the connection is closed or cancelled.
"""
try:
while self._client and self._client.is_active:
result = await self._client.receive_response()
if result is None:
break
# Check if this is a PayloadPart with bytes
if hasattr(result, "value") and hasattr(result.value, "bytes_"):
if result.value.bytes_:
response_data = result.value.bytes_.decode("utf-8")
try:
# Parse JSON response from Deepgram
parsed = json.loads(response_data)
# Extract and process transcript if available
if "channel" in parsed:
await self._handle_transcript_response(parsed)
except json.JSONDecodeError:
logger.warning(f"Non-JSON response: {response_data}")
except asyncio.CancelledError:
logger.debug("Response processor cancelled")
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
logger.debug("Response processor stopped")
async def _handle_transcript_response(self, parsed: dict):
"""Handle a transcript response from Deepgram.
Extracts the transcript text, determines if it's final or interim, extracts
language information, and pushes the appropriate frame (TranscriptionFrame
or InterimTranscriptionFrame) downstream.
Args:
parsed: The parsed JSON response from Deepgram containing channel,
alternatives, transcript, and metadata.
"""
alternatives = parsed.get("channel", {}).get("alternatives", [])
if not alternatives or not alternatives[0].get("transcript"):
return
transcript = alternatives[0]["transcript"]
if not transcript.strip():
return
is_final = parsed.get("is_final", False)
# Extract language if available
language = None
if alternatives[0].get("languages"):
language = alternatives[0]["languages"][0]
language = Language(language)
if is_final:
# Check if this response is from a finalize() call.
# Only mark as finalized when both we requested it AND Deepgram confirms it.
from_finalize = parsed.get("from_finalize", False)
if from_finalize:
self.confirm_finalize()
await self.push_frame(
TranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=parsed,
)
)
await self._handle_transcription(transcript, is_final, language)
await self.stop_processing_metrics()
else:
# Interim transcription
await self.push_frame(
InterimTranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
language,
result=parsed,
)
)
@traced_stt
async def _handle_transcription(
self, transcript: str, is_final: bool, language: Optional[Language] = None
):
"""Handle a transcription result with tracing.
This method is decorated with @traced_stt for observability and tracing
integration. The actual transcription processing is handled by the parent
class and observers.
Args:
transcript: The transcribed text.
is_final: Whether this is a final transcription result.
language: The detected language of the transcription, if available.
"""
pass
async def _start_metrics(self):
"""Start processing metrics collection."""
await self.start_processing_metrics()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames with Deepgram SageMaker-specific handling.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
# Start metrics when user starts speaking (if VAD is not provided by Deepgram)
if isinstance(frame, VADUserStartedSpeakingFrame):
await self._start_metrics()
elif isinstance(frame, VADUserStoppedSpeakingFrame):
# https://developers.deepgram.com/docs/finalize
# Mark that we're awaiting a from_finalize response
self.request_finalize()
if self._client and self._client.is_active:
try:
await self._client.send_json({"type": "Finalize"})
except Exception as e:
logger.warning(f"Error sending Finalize message: {e}")
logger.trace(f"Triggered finalize event on: {frame.name=}, {direction=}")
from pipecat.services.deepgram.sagemaker.stt import * # noqa: E402, F401, F403

View File

@@ -4,357 +4,15 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Deepgram text-to-speech service for AWS SageMaker.
"""Deprecated: use ``pipecat.services.deepgram.sagemaker.tts`` instead."""
This module provides a Pipecat TTS service that connects to Deepgram models
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
low-latency real-time speech synthesis with support for interruptions and
streaming audio output.
"""
import warnings
import asyncio
import json
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Optional
from loguru import logger
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
LLMFullResponseEndFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
warnings.warn(
"Module `pipecat.services.deepgram.tts_sagemaker` is deprecated, "
"use `pipecat.services.deepgram.sagemaker.tts` instead.",
DeprecationWarning,
stacklevel=2,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
from pipecat.services.tts_service import TTSService
from pipecat.utils.tracing.service_decorators import traced_tts
@dataclass
class DeepgramSageMakerTTSSettings(TTSSettings):
"""Settings for Deepgram SageMaker TTS service.
Parameters:
encoding: Audio encoding format (e.g. "linear16").
"""
encoding: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
class DeepgramSageMakerTTSService(TTSService):
"""Deepgram text-to-speech service for AWS SageMaker.
Provides real-time speech synthesis using Deepgram models deployed on
AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
audio generation with support for interruptions via the Clear message.
Requirements:
- AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
- A deployed SageMaker endpoint with Deepgram TTS model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker
- ``pipecat-ai[sagemaker]`` installed
Example::
tts = DeepgramSageMakerTTSService(
endpoint_name="my-deepgram-tts-endpoint",
region="us-east-2",
voice="aura-2-helena-en",
)
"""
_settings: DeepgramSageMakerTTSSettings
def __init__(
self,
*,
endpoint_name: str,
region: str,
voice: str = "aura-2-helena-en",
sample_rate: Optional[int] = None,
encoding: str = "linear16",
**kwargs,
):
"""Initialize the Deepgram SageMaker TTS service.
Args:
endpoint_name: Name of the SageMaker endpoint with Deepgram TTS model
deployed (e.g., "my-deepgram-tts-endpoint").
region: AWS region where the endpoint is deployed (e.g., "us-east-2").
voice: Voice model to use for synthesis. Defaults to "aura-2-helena-en".
sample_rate: Audio sample rate in Hz. If None, uses the value from StartFrame.
encoding: Audio encoding format. Defaults to "linear16".
**kwargs: Additional arguments passed to the parent TTSService.
"""
super().__init__(
sample_rate=sample_rate,
push_stop_frames=True,
pause_frame_processing=True,
append_trailing_space=True,
settings=DeepgramSageMakerTTSSettings(
model=voice,
voice=voice,
language=None,
encoding=encoding,
),
**kwargs,
)
self._endpoint_name = endpoint_name
self._region = region
self._client: Optional[SageMakerBidiClient] = None
self._response_task: Optional[asyncio.Task] = None
self._context_id: Optional[str] = None
self._ttfb_started: bool = False
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
True, as Deepgram SageMaker TTS service supports metrics generation.
"""
return True
async def start(self, frame: StartFrame):
"""Start the Deepgram SageMaker TTS service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the Deepgram SageMaker TTS service.
Args:
frame: The end frame.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the Deepgram SageMaker TTS service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
await self._disconnect()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames with special handling for LLM response end.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
await self.flush_audio()
elif isinstance(frame, BotStoppedSpeakingFrame):
self._ttfb_started = False
async def _connect(self):
"""Connect to the SageMaker endpoint and start the BiDi session.
Builds the Deepgram TTS query string, creates the BiDi client,
starts the streaming session, and launches a background task for processing
responses.
"""
logger.debug("Connecting to Deepgram TTS on SageMaker...")
query_string = (
f"model={self._settings.voice}&encoding={self._settings.encoding}"
f"&sample_rate={self.sample_rate}"
)
self._client = SageMakerBidiClient(
endpoint_name=self._endpoint_name,
region=self._region,
model_invocation_path="v1/speak",
model_query_string=query_string,
)
try:
await self._client.start_session()
self._response_task = self.create_task(self._process_responses())
logger.debug("Connected to Deepgram TTS on SageMaker")
await self._call_event_handler("on_connected")
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
await self._call_event_handler("on_connection_error", str(e))
async def _disconnect(self):
"""Disconnect from the SageMaker endpoint.
Sends a Close message to Deepgram, cancels the response processing task,
and closes the BiDi session. Safe to call multiple times.
"""
if self._client and self._client.is_active:
logger.debug("Disconnecting from Deepgram TTS on SageMaker...")
try:
await self._client.send_json({"type": "Close"})
except Exception as e:
logger.warning(f"Failed to send Close message: {e}")
if self._response_task and not self._response_task.done():
await self.cancel_task(self._response_task)
await self._client.close_session()
logger.debug("Disconnected from Deepgram TTS on SageMaker")
await self._call_event_handler("on_disconnected")
async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]:
"""Apply a settings delta and reconnect if necessary.
Since all settings are part of the SageMaker session query string,
any setting change requires reconnecting to apply the new values.
"""
changed = await super()._update_settings(delta)
if not changed:
return changed
# Deepgram uses voice as the model, so keep them in sync for metrics
if "voice" in changed:
self._settings.model = self._settings.voice
self._sync_model_name_to_metrics()
# TODO: someday we could reconnect here to apply updated settings.
# Code might look something like the below:
# await self._disconnect()
# await self._connect()
self._warn_unhandled_updated_settings(changed)
return changed
async def _process_responses(self):
"""Process streaming responses from Deepgram TTS on SageMaker.
Continuously receives responses from the BiDi stream. Attempts to decode
each payload as UTF-8 JSON for control messages (Flushed, Cleared, Metadata,
Warning). If decoding fails, treats the payload as raw audio bytes and pushes
a TTSAudioRawFrame downstream.
"""
try:
while self._client and self._client.is_active:
result = await self._client.receive_response()
if result is None:
break
if hasattr(result, "value") and hasattr(result.value, "bytes_"):
if result.value.bytes_:
payload = result.value.bytes_
# Try to decode as JSON control message first
try:
response_data = payload.decode("utf-8")
parsed = json.loads(response_data)
msg_type = parsed.get("type")
if msg_type == "Metadata":
logger.trace(f"Received metadata: {parsed}")
elif msg_type == "Flushed":
logger.trace(f"Received Flushed: {parsed}")
elif msg_type == "Cleared":
logger.trace(f"Received Cleared: {parsed}")
elif msg_type == "Warning":
logger.warning(
f"{self} warning: "
f"{parsed.get('description', 'Unknown warning')}"
)
else:
logger.debug(f"Received unknown message type: {parsed}")
except (UnicodeDecodeError, json.JSONDecodeError):
# Not JSON — treat as raw audio bytes
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
payload,
self.sample_rate,
1,
context_id=self._context_id,
)
await self.push_frame(frame)
except asyncio.CancelledError:
logger.debug("TTS response processor cancelled")
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
logger.debug("TTS response processor stopped")
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
"""Handle interruption by sending Clear message to Deepgram.
The Clear message will clear Deepgram's internal text buffer and stop
sending audio, allowing for a new response to be generated.
"""
await super()._handle_interruption(frame, direction)
self._ttfb_started = False
if self._client and self._client.is_active:
try:
await self._client.send_json({"type": "Clear"})
except Exception as e:
logger.error(f"{self} error sending Clear message: {e}")
async def flush_audio(self):
"""Flush any pending audio synthesis by sending Flush command.
This should be called when the LLM finishes a complete response to force
generation of audio from Deepgram's internal text buffer.
"""
if self._client and self._client.is_active:
try:
await self._client.send_json({"type": "Flush"})
except Exception as e:
logger.error(f"{self} error sending Flush message: {e}")
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Deepgram TTS on SageMaker.
Args:
text: The text to synthesize into speech.
context_id: The context ID for tracking audio frames.
Yields:
Frame: TTSStartedFrame, then None (audio comes asynchronously via
the response processor).
"""
logger.debug(f"{self}: Generating TTS [{text}]")
try:
if not self._ttfb_started:
await self.start_ttfb_metrics()
self._ttfb_started = True
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
self._context_id = context_id
await self._client.send_json({"type": "Speak", "text": text})
yield None
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
from pipecat.services.deepgram.sagemaker.tts import * # noqa: E402, F401, F403