Merge pull request #3902 from pipecat-ai/aleix/deepgram-sagemaker-move
Move Deepgram SageMaker modules to sagemaker/ subpackage
This commit is contained in:
1
changelog/3902.changed.md
Normal file
1
changelog/3902.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Moved `pipecat.services.deepgram.stt_sagemaker` and `pipecat.services.deepgram.tts_sagemaker` to `pipecat.services.deepgram.sagemaker.stt` and `pipecat.services.deepgram.sagemaker.tts`. The old import paths still work but emit a `DeprecationWarning`.
|
||||
@@ -23,8 +23,8 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
from pipecat.services.deepgram.stt_sagemaker import DeepgramSageMakerSTTService
|
||||
from pipecat.services.deepgram.tts_sagemaker import DeepgramSageMakerTTSService
|
||||
from pipecat.services.deepgram.sagemaker.stt import DeepgramSageMakerSTTService
|
||||
from pipecat.services.deepgram.sagemaker.tts import DeepgramSageMakerTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
@@ -24,7 +24,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt_sagemaker import (
|
||||
from pipecat.services.deepgram.sagemaker.stt import (
|
||||
DeepgramSageMakerSTTService,
|
||||
DeepgramSageMakerSTTSettings,
|
||||
)
|
||||
|
||||
@@ -22,11 +22,11 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.deepgram.tts_sagemaker import (
|
||||
from pipecat.services.deepgram.sagemaker.tts import (
|
||||
DeepgramSageMakerTTSService,
|
||||
DeepgramSageMakerTTSSettings,
|
||||
)
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
|
||||
@@ -9,6 +9,7 @@ import sys
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .flux import *
|
||||
from .sagemaker import *
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
|
||||
448
src/pipecat/services/deepgram/sagemaker/stt.py
Normal file
448
src/pipecat/services/deepgram/sagemaker/stt.py
Normal file
@@ -0,0 +1,448 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Deepgram speech-to-text service for AWS SageMaker.
|
||||
|
||||
This module provides a Pipecat STT service that connects to Deepgram models
|
||||
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
|
||||
low-latency real-time transcription with support for interim results, multiple
|
||||
languages, and various Deepgram features.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
|
||||
from pipecat.services.deepgram.stt import _DeepgramSTTSettingsBase
|
||||
from pipecat.services.settings import STTSettings
|
||||
from pipecat.services.stt_latency import DEEPGRAM_SAGEMAKER_TTFS_P99
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
from deepgram import LiveOptions
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use DeepgramSageMakerSTTService, you need to `pip install pipecat-ai[deepgram,sagemaker]`."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepgramSageMakerSTTSettings(_DeepgramSTTSettingsBase):
|
||||
"""Settings for the Deepgram SageMaker STT service.
|
||||
|
||||
See ``_DeepgramSTTSettingsBase`` for full documentation.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DeepgramSageMakerSTTService(STTService):
|
||||
"""Deepgram speech-to-text service for AWS SageMaker.
|
||||
|
||||
Provides real-time speech recognition using Deepgram models deployed on
|
||||
AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
|
||||
transcription with support for interim results, speaker diarization, and
|
||||
multiple languages.
|
||||
|
||||
Requirements:
|
||||
|
||||
- AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
|
||||
- A deployed SageMaker endpoint with Deepgram model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker
|
||||
- Deepgram SDK for LiveOptions configuration
|
||||
|
||||
Example::
|
||||
|
||||
stt = DeepgramSageMakerSTTService(
|
||||
endpoint_name="my-deepgram-endpoint",
|
||||
region="us-east-2",
|
||||
live_options=LiveOptions(
|
||||
model="nova-3",
|
||||
language="en",
|
||||
interim_results=True,
|
||||
punctuate=True,
|
||||
),
|
||||
)
|
||||
"""
|
||||
|
||||
_settings: DeepgramSageMakerSTTSettings
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
endpoint_name: str,
|
||||
region: str,
|
||||
sample_rate: Optional[int] = None,
|
||||
live_options: Optional[LiveOptions] = None,
|
||||
ttfs_p99_latency: Optional[float] = DEEPGRAM_SAGEMAKER_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the SageMaker endpoint with Deepgram model
|
||||
deployed (e.g., "my-deepgram-nova-3-endpoint").
|
||||
region: AWS region where the endpoint is deployed (e.g., "us-east-2").
|
||||
sample_rate: Audio sample rate in Hz. If None, uses value from
|
||||
live_options or defaults to the value from StartFrame.
|
||||
live_options: Deepgram LiveOptions configuration. Treated as a
|
||||
delta from a set of sensible defaults — only the fields you
|
||||
set are overridden; all others keep their default values.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to the parent STTService.
|
||||
"""
|
||||
sample_rate = sample_rate or (live_options.sample_rate if live_options else None)
|
||||
|
||||
default_options = LiveOptions(
|
||||
encoding="linear16",
|
||||
language=Language.EN,
|
||||
model="nova-3",
|
||||
channels=1,
|
||||
interim_results=True,
|
||||
punctuate=True,
|
||||
)
|
||||
|
||||
settings = DeepgramSageMakerSTTSettings(
|
||||
model=default_options.model,
|
||||
language=default_options.language,
|
||||
live_options=default_options,
|
||||
)
|
||||
if live_options:
|
||||
settings._merge_live_options_delta(live_options)
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
ttfs_p99_latency=ttfs_p99_latency,
|
||||
settings=settings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._endpoint_name = endpoint_name
|
||||
self._region = region
|
||||
|
||||
self._client: Optional[SageMakerBidiClient] = None
|
||||
self._response_task: Optional[asyncio.Task] = None
|
||||
self._keepalive_task: Optional[asyncio.Task] = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Deepgram SageMaker service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
|
||||
"""Apply a settings delta and warn about unhandled changes."""
|
||||
changed = await super()._update_settings(delta)
|
||||
|
||||
if not changed:
|
||||
return changed
|
||||
|
||||
# TODO: someday we could reconnect here to apply updated settings.
|
||||
# Code might look something like the below:
|
||||
# await self._disconnect()
|
||||
# await self._connect()
|
||||
|
||||
self._warn_unhandled_updated_settings(changed)
|
||||
|
||||
return changed
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Send audio data to Deepgram for transcription.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to transcribe.
|
||||
|
||||
Yields:
|
||||
Frame: None (transcription results come via BiDi stream callbacks).
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_audio_chunk(audio)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to the SageMaker endpoint and start the BiDi session.
|
||||
|
||||
Builds the Deepgram query string from settings, creates the BiDi client,
|
||||
starts the streaming session, and launches background tasks for processing
|
||||
responses and sending KeepAlive messages.
|
||||
"""
|
||||
logger.debug("Connecting to Deepgram on SageMaker...")
|
||||
|
||||
live_options = LiveOptions(
|
||||
**{**self._settings.live_options.to_dict(), "sample_rate": self.sample_rate}
|
||||
)
|
||||
|
||||
# Build query string from live_options, converting booleans to strings
|
||||
query_params = {}
|
||||
for key, value in live_options.to_dict().items():
|
||||
if value is not None:
|
||||
# Convert boolean values to lowercase strings for Deepgram API
|
||||
if isinstance(value, bool):
|
||||
query_params[key] = str(value).lower()
|
||||
else:
|
||||
query_params[key] = str(value)
|
||||
|
||||
query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
|
||||
|
||||
# Create BiDi client
|
||||
self._client = SageMakerBidiClient(
|
||||
endpoint_name=self._endpoint_name,
|
||||
region=self._region,
|
||||
model_invocation_path="v1/listen",
|
||||
model_query_string=query_string,
|
||||
)
|
||||
|
||||
try:
|
||||
# Start the session
|
||||
await self._client.start_session()
|
||||
|
||||
# Start processing responses in the background
|
||||
self._response_task = self.create_task(self._process_responses())
|
||||
|
||||
# Start keepalive task to maintain connection
|
||||
self._keepalive_task = self.create_task(self._send_keepalive())
|
||||
|
||||
logger.debug("Connected to Deepgram on SageMaker")
|
||||
await self._call_event_handler("on_connected")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self._call_event_handler("on_connection_error", str(e))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from the SageMaker endpoint.
|
||||
|
||||
Sends a CloseStream message to Deepgram, cancels background tasks
|
||||
(KeepAlive and response processing), and closes the BiDi session.
|
||||
Safe to call multiple times.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
logger.debug("Disconnecting from Deepgram on SageMaker...")
|
||||
|
||||
# Send CloseStream message to Deepgram
|
||||
try:
|
||||
await self._client.send_json({"type": "CloseStream"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send CloseStream message: {e}")
|
||||
|
||||
# Cancel keepalive task
|
||||
if self._keepalive_task and not self._keepalive_task.done():
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
|
||||
# Cancel response processing task
|
||||
if self._response_task and not self._response_task.done():
|
||||
await self.cancel_task(self._response_task)
|
||||
|
||||
# Close the BiDi session
|
||||
await self._client.close_session()
|
||||
|
||||
logger.debug("Disconnected from Deepgram on SageMaker")
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _send_keepalive(self):
|
||||
"""Send periodic KeepAlive messages to maintain the connection.
|
||||
|
||||
Sends a KeepAlive JSON message to Deepgram every 5 seconds while the
|
||||
connection is active. This prevents the connection from timing out during
|
||||
periods of silence.
|
||||
"""
|
||||
while self._client and self._client.is_active:
|
||||
await asyncio.sleep(5)
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "KeepAlive"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send KeepAlive: {e}")
|
||||
|
||||
async def _process_responses(self):
|
||||
"""Process streaming responses from Deepgram on SageMaker.
|
||||
|
||||
Continuously receives responses from the BiDi stream, decodes the payload,
|
||||
parses JSON responses from Deepgram, and processes transcription results.
|
||||
Runs as a background task until the connection is closed or cancelled.
|
||||
"""
|
||||
try:
|
||||
while self._client and self._client.is_active:
|
||||
result = await self._client.receive_response()
|
||||
|
||||
if result is None:
|
||||
break
|
||||
|
||||
# Check if this is a PayloadPart with bytes
|
||||
if hasattr(result, "value") and hasattr(result.value, "bytes_"):
|
||||
if result.value.bytes_:
|
||||
response_data = result.value.bytes_.decode("utf-8")
|
||||
|
||||
try:
|
||||
# Parse JSON response from Deepgram
|
||||
parsed = json.loads(response_data)
|
||||
|
||||
# Extract and process transcript if available
|
||||
if "channel" in parsed:
|
||||
await self._handle_transcript_response(parsed)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Non-JSON response: {response_data}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Response processor cancelled")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
logger.debug("Response processor stopped")
|
||||
|
||||
async def _handle_transcript_response(self, parsed: dict):
|
||||
"""Handle a transcript response from Deepgram.
|
||||
|
||||
Extracts the transcript text, determines if it's final or interim, extracts
|
||||
language information, and pushes the appropriate frame (TranscriptionFrame
|
||||
or InterimTranscriptionFrame) downstream.
|
||||
|
||||
Args:
|
||||
parsed: The parsed JSON response from Deepgram containing channel,
|
||||
alternatives, transcript, and metadata.
|
||||
"""
|
||||
alternatives = parsed.get("channel", {}).get("alternatives", [])
|
||||
if not alternatives or not alternatives[0].get("transcript"):
|
||||
return
|
||||
|
||||
transcript = alternatives[0]["transcript"]
|
||||
if not transcript.strip():
|
||||
return
|
||||
|
||||
is_final = parsed.get("is_final", False)
|
||||
|
||||
# Extract language if available
|
||||
language = None
|
||||
if alternatives[0].get("languages"):
|
||||
language = alternatives[0]["languages"][0]
|
||||
language = Language(language)
|
||||
|
||||
if is_final:
|
||||
# Check if this response is from a finalize() call.
|
||||
# Only mark as finalized when both we requested it AND Deepgram confirms it.
|
||||
from_finalize = parsed.get("from_finalize", False)
|
||||
if from_finalize:
|
||||
self.confirm_finalize()
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=parsed,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(transcript, is_final, language)
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
# Interim transcription
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=parsed,
|
||||
)
|
||||
)
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing.
|
||||
|
||||
This method is decorated with @traced_stt for observability and tracing
|
||||
integration. The actual transcription processing is handled by the parent
|
||||
class and observers.
|
||||
|
||||
Args:
|
||||
transcript: The transcribed text.
|
||||
is_final: Whether this is a final transcription result.
|
||||
language: The detected language of the transcription, if available.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def _start_metrics(self):
|
||||
"""Start processing metrics collection."""
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with Deepgram SageMaker-specific handling.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Start metrics when user starts speaking (if VAD is not provided by Deepgram)
|
||||
if isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self._start_metrics()
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
# https://developers.deepgram.com/docs/finalize
|
||||
# Mark that we're awaiting a from_finalize response
|
||||
self.request_finalize()
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Finalize"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending Finalize message: {e}")
|
||||
logger.trace(f"Triggered finalize event on: {frame.name=}, {direction=}")
|
||||
360
src/pipecat/services/deepgram/sagemaker/tts.py
Normal file
360
src/pipecat/services/deepgram/sagemaker/tts.py
Normal file
@@ -0,0 +1,360 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Deepgram text-to-speech service for AWS SageMaker.
|
||||
|
||||
This module provides a Pipecat TTS service that connects to Deepgram models
|
||||
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
|
||||
low-latency real-time speech synthesis with support for interruptions and
|
||||
streaming audio output.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepgramSageMakerTTSSettings(TTSSettings):
|
||||
"""Settings for Deepgram SageMaker TTS service.
|
||||
|
||||
Parameters:
|
||||
encoding: Audio encoding format (e.g. "linear16").
|
||||
"""
|
||||
|
||||
encoding: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
|
||||
|
||||
class DeepgramSageMakerTTSService(TTSService):
|
||||
"""Deepgram text-to-speech service for AWS SageMaker.
|
||||
|
||||
Provides real-time speech synthesis using Deepgram models deployed on
|
||||
AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
|
||||
audio generation with support for interruptions via the Clear message.
|
||||
|
||||
Requirements:
|
||||
|
||||
- AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
|
||||
- A deployed SageMaker endpoint with Deepgram TTS model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker
|
||||
- ``pipecat-ai[sagemaker]`` installed
|
||||
|
||||
Example::
|
||||
|
||||
tts = DeepgramSageMakerTTSService(
|
||||
endpoint_name="my-deepgram-tts-endpoint",
|
||||
region="us-east-2",
|
||||
voice="aura-2-helena-en",
|
||||
)
|
||||
"""
|
||||
|
||||
_settings: DeepgramSageMakerTTSSettings
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
endpoint_name: str,
|
||||
region: str,
|
||||
voice: str = "aura-2-helena-en",
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "linear16",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the SageMaker endpoint with Deepgram TTS model
|
||||
deployed (e.g., "my-deepgram-tts-endpoint").
|
||||
region: AWS region where the endpoint is deployed (e.g., "us-east-2").
|
||||
voice: Voice model to use for synthesis. Defaults to "aura-2-helena-en".
|
||||
sample_rate: Audio sample rate in Hz. If None, uses the value from StartFrame.
|
||||
encoding: Audio encoding format. Defaults to "linear16".
|
||||
**kwargs: Additional arguments passed to the parent TTSService.
|
||||
"""
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
append_trailing_space=True,
|
||||
settings=DeepgramSageMakerTTSSettings(
|
||||
model=voice,
|
||||
voice=voice,
|
||||
language=None,
|
||||
encoding=encoding,
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._endpoint_name = endpoint_name
|
||||
self._region = region
|
||||
|
||||
self._client: Optional[SageMakerBidiClient] = None
|
||||
self._response_task: Optional[asyncio.Task] = None
|
||||
self._context_id: Optional[str] = None
|
||||
self._ttfb_started: bool = False
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Deepgram SageMaker TTS service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with special handling for LLM response end.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
await self.flush_audio()
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._ttfb_started = False
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to the SageMaker endpoint and start the BiDi session.
|
||||
|
||||
Builds the Deepgram TTS query string, creates the BiDi client,
|
||||
starts the streaming session, and launches a background task for processing
|
||||
responses.
|
||||
"""
|
||||
logger.debug("Connecting to Deepgram TTS on SageMaker...")
|
||||
|
||||
query_string = (
|
||||
f"model={self._settings.voice}&encoding={self._settings.encoding}"
|
||||
f"&sample_rate={self.sample_rate}"
|
||||
)
|
||||
|
||||
self._client = SageMakerBidiClient(
|
||||
endpoint_name=self._endpoint_name,
|
||||
region=self._region,
|
||||
model_invocation_path="v1/speak",
|
||||
model_query_string=query_string,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._client.start_session()
|
||||
|
||||
self._response_task = self.create_task(self._process_responses())
|
||||
|
||||
logger.debug("Connected to Deepgram TTS on SageMaker")
|
||||
await self._call_event_handler("on_connected")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self._call_event_handler("on_connection_error", str(e))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from the SageMaker endpoint.
|
||||
|
||||
Sends a Close message to Deepgram, cancels the response processing task,
|
||||
and closes the BiDi session. Safe to call multiple times.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
logger.debug("Disconnecting from Deepgram TTS on SageMaker...")
|
||||
|
||||
try:
|
||||
await self._client.send_json({"type": "Close"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send Close message: {e}")
|
||||
|
||||
if self._response_task and not self._response_task.done():
|
||||
await self.cancel_task(self._response_task)
|
||||
|
||||
await self._client.close_session()
|
||||
|
||||
logger.debug("Disconnected from Deepgram TTS on SageMaker")
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]:
|
||||
"""Apply a settings delta and reconnect if necessary.
|
||||
|
||||
Since all settings are part of the SageMaker session query string,
|
||||
any setting change requires reconnecting to apply the new values.
|
||||
"""
|
||||
changed = await super()._update_settings(delta)
|
||||
|
||||
if not changed:
|
||||
return changed
|
||||
|
||||
# Deepgram uses voice as the model, so keep them in sync for metrics
|
||||
if "voice" in changed:
|
||||
self._settings.model = self._settings.voice
|
||||
self._sync_model_name_to_metrics()
|
||||
|
||||
# TODO: someday we could reconnect here to apply updated settings.
|
||||
# Code might look something like the below:
|
||||
# await self._disconnect()
|
||||
# await self._connect()
|
||||
|
||||
self._warn_unhandled_updated_settings(changed)
|
||||
|
||||
return changed
|
||||
|
||||
async def _process_responses(self):
|
||||
"""Process streaming responses from Deepgram TTS on SageMaker.
|
||||
|
||||
Continuously receives responses from the BiDi stream. Attempts to decode
|
||||
each payload as UTF-8 JSON for control messages (Flushed, Cleared, Metadata,
|
||||
Warning). If decoding fails, treats the payload as raw audio bytes and pushes
|
||||
a TTSAudioRawFrame downstream.
|
||||
"""
|
||||
try:
|
||||
while self._client and self._client.is_active:
|
||||
result = await self._client.receive_response()
|
||||
|
||||
if result is None:
|
||||
break
|
||||
|
||||
if hasattr(result, "value") and hasattr(result.value, "bytes_"):
|
||||
if result.value.bytes_:
|
||||
payload = result.value.bytes_
|
||||
|
||||
# Try to decode as JSON control message first
|
||||
try:
|
||||
response_data = payload.decode("utf-8")
|
||||
parsed = json.loads(response_data)
|
||||
msg_type = parsed.get("type")
|
||||
|
||||
if msg_type == "Metadata":
|
||||
logger.trace(f"Received metadata: {parsed}")
|
||||
elif msg_type == "Flushed":
|
||||
logger.trace(f"Received Flushed: {parsed}")
|
||||
elif msg_type == "Cleared":
|
||||
logger.trace(f"Received Cleared: {parsed}")
|
||||
elif msg_type == "Warning":
|
||||
logger.warning(
|
||||
f"{self} warning: "
|
||||
f"{parsed.get('description', 'Unknown warning')}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Received unknown message type: {parsed}")
|
||||
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
# Not JSON — treat as raw audio bytes
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
payload,
|
||||
self.sample_rate,
|
||||
1,
|
||||
context_id=self._context_id,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("TTS response processor cancelled")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
logger.debug("TTS response processor stopped")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by sending Clear message to Deepgram.
|
||||
|
||||
The Clear message will clear Deepgram's internal text buffer and stop
|
||||
sending audio, allowing for a new response to be generated.
|
||||
"""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
self._ttfb_started = False
|
||||
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Clear"})
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending Clear message: {e}")
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio synthesis by sending Flush command.
|
||||
|
||||
This should be called when the LLM finishes a complete response to force
|
||||
generation of audio from Deepgram's internal text buffer.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Flush"})
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending Flush message: {e}")
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Deepgram TTS on SageMaker.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech.
|
||||
context_id: The context ID for tracking audio frames.
|
||||
|
||||
Yields:
|
||||
Frame: TTSStartedFrame, then None (audio comes asynchronously via
|
||||
the response processor).
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
if not self._ttfb_started:
|
||||
await self.start_ttfb_metrics()
|
||||
self._ttfb_started = True
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._context_id = context_id
|
||||
|
||||
await self._client.send_json({"type": "Speak", "text": text})
|
||||
|
||||
yield None
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
@@ -4,445 +4,15 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Deepgram speech-to-text service for AWS SageMaker.
|
||||
"""Deprecated: use ``pipecat.services.deepgram.sagemaker.stt`` instead."""
|
||||
|
||||
This module provides a Pipecat STT service that connects to Deepgram models
|
||||
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
|
||||
low-latency real-time transcription with support for interim results, multiple
|
||||
languages, and various Deepgram features.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
warnings.warn(
|
||||
"Module `pipecat.services.deepgram.stt_sagemaker` is deprecated, "
|
||||
"use `pipecat.services.deepgram.sagemaker.stt` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
|
||||
from pipecat.services.deepgram.stt import _DeepgramSTTSettingsBase
|
||||
from pipecat.services.settings import STTSettings
|
||||
from pipecat.services.stt_latency import DEEPGRAM_SAGEMAKER_TTFS_P99
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
from deepgram import LiveOptions
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use DeepgramSageMakerSTTService, you need to `pip install pipecat-ai[deepgram,sagemaker]`."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepgramSageMakerSTTSettings(_DeepgramSTTSettingsBase):
|
||||
"""Settings for the Deepgram SageMaker STT service.
|
||||
|
||||
See ``_DeepgramSTTSettingsBase`` for full documentation.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DeepgramSageMakerSTTService(STTService):
|
||||
"""Deepgram speech-to-text service for AWS SageMaker.
|
||||
|
||||
Provides real-time speech recognition using Deepgram models deployed on
|
||||
AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
|
||||
transcription with support for interim results, speaker diarization, and
|
||||
multiple languages.
|
||||
|
||||
Requirements:
|
||||
|
||||
- AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
|
||||
- A deployed SageMaker endpoint with Deepgram model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker
|
||||
- Deepgram SDK for LiveOptions configuration
|
||||
|
||||
Example::
|
||||
|
||||
stt = DeepgramSageMakerSTTService(
|
||||
endpoint_name="my-deepgram-endpoint",
|
||||
region="us-east-2",
|
||||
live_options=LiveOptions(
|
||||
model="nova-3",
|
||||
language="en",
|
||||
interim_results=True,
|
||||
punctuate=True,
|
||||
),
|
||||
)
|
||||
"""
|
||||
|
||||
_settings: DeepgramSageMakerSTTSettings
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
endpoint_name: str,
|
||||
region: str,
|
||||
sample_rate: Optional[int] = None,
|
||||
live_options: Optional[LiveOptions] = None,
|
||||
ttfs_p99_latency: Optional[float] = DEEPGRAM_SAGEMAKER_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the SageMaker endpoint with Deepgram model
|
||||
deployed (e.g., "my-deepgram-nova-3-endpoint").
|
||||
region: AWS region where the endpoint is deployed (e.g., "us-east-2").
|
||||
sample_rate: Audio sample rate in Hz. If None, uses value from
|
||||
live_options or defaults to the value from StartFrame.
|
||||
live_options: Deepgram LiveOptions configuration. Treated as a
|
||||
delta from a set of sensible defaults — only the fields you
|
||||
set are overridden; all others keep their default values.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to the parent STTService.
|
||||
"""
|
||||
sample_rate = sample_rate or (live_options.sample_rate if live_options else None)
|
||||
|
||||
default_options = LiveOptions(
|
||||
encoding="linear16",
|
||||
language=Language.EN,
|
||||
model="nova-3",
|
||||
channels=1,
|
||||
interim_results=True,
|
||||
punctuate=True,
|
||||
)
|
||||
|
||||
settings = DeepgramSageMakerSTTSettings(
|
||||
model=default_options.model,
|
||||
language=default_options.language,
|
||||
live_options=default_options,
|
||||
)
|
||||
if live_options:
|
||||
settings._merge_live_options_delta(live_options)
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
ttfs_p99_latency=ttfs_p99_latency,
|
||||
settings=settings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._endpoint_name = endpoint_name
|
||||
self._region = region
|
||||
|
||||
self._client: Optional[SageMakerBidiClient] = None
|
||||
self._response_task: Optional[asyncio.Task] = None
|
||||
self._keepalive_task: Optional[asyncio.Task] = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Deepgram SageMaker service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
|
||||
"""Apply a settings delta and warn about unhandled changes."""
|
||||
changed = await super()._update_settings(delta)
|
||||
|
||||
if not changed:
|
||||
return changed
|
||||
|
||||
# TODO: someday we could reconnect here to apply updated settings.
|
||||
# Code might look something like the below:
|
||||
# await self._disconnect()
|
||||
# await self._connect()
|
||||
|
||||
self._warn_unhandled_updated_settings(changed)
|
||||
|
||||
return changed
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Deepgram SageMaker STT service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Send audio data to Deepgram for transcription.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to transcribe.
|
||||
|
||||
Yields:
|
||||
Frame: None (transcription results come via BiDi stream callbacks).
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_audio_chunk(audio)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to the SageMaker endpoint and start the BiDi session.
|
||||
|
||||
Builds the Deepgram query string from settings, creates the BiDi client,
|
||||
starts the streaming session, and launches background tasks for processing
|
||||
responses and sending KeepAlive messages.
|
||||
"""
|
||||
logger.debug("Connecting to Deepgram on SageMaker...")
|
||||
|
||||
live_options = LiveOptions(
|
||||
**{**self._settings.live_options.to_dict(), "sample_rate": self.sample_rate}
|
||||
)
|
||||
|
||||
# Build query string from live_options, converting booleans to strings
|
||||
query_params = {}
|
||||
for key, value in live_options.to_dict().items():
|
||||
if value is not None:
|
||||
# Convert boolean values to lowercase strings for Deepgram API
|
||||
if isinstance(value, bool):
|
||||
query_params[key] = str(value).lower()
|
||||
else:
|
||||
query_params[key] = str(value)
|
||||
|
||||
query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
|
||||
|
||||
# Create BiDi client
|
||||
self._client = SageMakerBidiClient(
|
||||
endpoint_name=self._endpoint_name,
|
||||
region=self._region,
|
||||
model_invocation_path="v1/listen",
|
||||
model_query_string=query_string,
|
||||
)
|
||||
|
||||
try:
|
||||
# Start the session
|
||||
await self._client.start_session()
|
||||
|
||||
# Start processing responses in the background
|
||||
self._response_task = self.create_task(self._process_responses())
|
||||
|
||||
# Start keepalive task to maintain connection
|
||||
self._keepalive_task = self.create_task(self._send_keepalive())
|
||||
|
||||
logger.debug("Connected to Deepgram on SageMaker")
|
||||
await self._call_event_handler("on_connected")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self._call_event_handler("on_connection_error", str(e))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from the SageMaker endpoint.
|
||||
|
||||
Sends a CloseStream message to Deepgram, cancels background tasks
|
||||
(KeepAlive and response processing), and closes the BiDi session.
|
||||
Safe to call multiple times.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
logger.debug("Disconnecting from Deepgram on SageMaker...")
|
||||
|
||||
# Send CloseStream message to Deepgram
|
||||
try:
|
||||
await self._client.send_json({"type": "CloseStream"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send CloseStream message: {e}")
|
||||
|
||||
# Cancel keepalive task
|
||||
if self._keepalive_task and not self._keepalive_task.done():
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
|
||||
# Cancel response processing task
|
||||
if self._response_task and not self._response_task.done():
|
||||
await self.cancel_task(self._response_task)
|
||||
|
||||
# Close the BiDi session
|
||||
await self._client.close_session()
|
||||
|
||||
logger.debug("Disconnected from Deepgram on SageMaker")
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _send_keepalive(self):
|
||||
"""Send periodic KeepAlive messages to maintain the connection.
|
||||
|
||||
Sends a KeepAlive JSON message to Deepgram every 5 seconds while the
|
||||
connection is active. This prevents the connection from timing out during
|
||||
periods of silence.
|
||||
"""
|
||||
while self._client and self._client.is_active:
|
||||
await asyncio.sleep(5)
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "KeepAlive"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send KeepAlive: {e}")
|
||||
|
||||
async def _process_responses(self):
|
||||
"""Process streaming responses from Deepgram on SageMaker.
|
||||
|
||||
Continuously receives responses from the BiDi stream, decodes the payload,
|
||||
parses JSON responses from Deepgram, and processes transcription results.
|
||||
Runs as a background task until the connection is closed or cancelled.
|
||||
"""
|
||||
try:
|
||||
while self._client and self._client.is_active:
|
||||
result = await self._client.receive_response()
|
||||
|
||||
if result is None:
|
||||
break
|
||||
|
||||
# Check if this is a PayloadPart with bytes
|
||||
if hasattr(result, "value") and hasattr(result.value, "bytes_"):
|
||||
if result.value.bytes_:
|
||||
response_data = result.value.bytes_.decode("utf-8")
|
||||
|
||||
try:
|
||||
# Parse JSON response from Deepgram
|
||||
parsed = json.loads(response_data)
|
||||
|
||||
# Extract and process transcript if available
|
||||
if "channel" in parsed:
|
||||
await self._handle_transcript_response(parsed)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Non-JSON response: {response_data}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Response processor cancelled")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
logger.debug("Response processor stopped")
|
||||
|
||||
async def _handle_transcript_response(self, parsed: dict):
|
||||
"""Handle a transcript response from Deepgram.
|
||||
|
||||
Extracts the transcript text, determines if it's final or interim, extracts
|
||||
language information, and pushes the appropriate frame (TranscriptionFrame
|
||||
or InterimTranscriptionFrame) downstream.
|
||||
|
||||
Args:
|
||||
parsed: The parsed JSON response from Deepgram containing channel,
|
||||
alternatives, transcript, and metadata.
|
||||
"""
|
||||
alternatives = parsed.get("channel", {}).get("alternatives", [])
|
||||
if not alternatives or not alternatives[0].get("transcript"):
|
||||
return
|
||||
|
||||
transcript = alternatives[0]["transcript"]
|
||||
if not transcript.strip():
|
||||
return
|
||||
|
||||
is_final = parsed.get("is_final", False)
|
||||
|
||||
# Extract language if available
|
||||
language = None
|
||||
if alternatives[0].get("languages"):
|
||||
language = alternatives[0]["languages"][0]
|
||||
language = Language(language)
|
||||
|
||||
if is_final:
|
||||
# Check if this response is from a finalize() call.
|
||||
# Only mark as finalized when both we requested it AND Deepgram confirms it.
|
||||
from_finalize = parsed.get("from_finalize", False)
|
||||
if from_finalize:
|
||||
self.confirm_finalize()
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=parsed,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(transcript, is_final, language)
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
# Interim transcription
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=parsed,
|
||||
)
|
||||
)
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing.
|
||||
|
||||
This method is decorated with @traced_stt for observability and tracing
|
||||
integration. The actual transcription processing is handled by the parent
|
||||
class and observers.
|
||||
|
||||
Args:
|
||||
transcript: The transcribed text.
|
||||
is_final: Whether this is a final transcription result.
|
||||
language: The detected language of the transcription, if available.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def _start_metrics(self):
|
||||
"""Start processing metrics collection."""
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with Deepgram SageMaker-specific handling.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Start metrics when user starts speaking (if VAD is not provided by Deepgram)
|
||||
if isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self._start_metrics()
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
# https://developers.deepgram.com/docs/finalize
|
||||
# Mark that we're awaiting a from_finalize response
|
||||
self.request_finalize()
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Finalize"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending Finalize message: {e}")
|
||||
logger.trace(f"Triggered finalize event on: {frame.name=}, {direction=}")
|
||||
from pipecat.services.deepgram.sagemaker.stt import * # noqa: E402, F401, F403
|
||||
|
||||
@@ -4,357 +4,15 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Deepgram text-to-speech service for AWS SageMaker.
|
||||
"""Deprecated: use ``pipecat.services.deepgram.sagemaker.tts`` instead."""
|
||||
|
||||
This module provides a Pipecat TTS service that connects to Deepgram models
|
||||
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
|
||||
low-latency real-time speech synthesis with support for interruptions and
|
||||
streaming audio output.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
warnings.warn(
|
||||
"Module `pipecat.services.deepgram.tts_sagemaker` is deprecated, "
|
||||
"use `pipecat.services.deepgram.sagemaker.tts` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepgramSageMakerTTSSettings(TTSSettings):
|
||||
"""Settings for Deepgram SageMaker TTS service.
|
||||
|
||||
Parameters:
|
||||
encoding: Audio encoding format (e.g. "linear16").
|
||||
"""
|
||||
|
||||
encoding: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
|
||||
|
||||
class DeepgramSageMakerTTSService(TTSService):
|
||||
"""Deepgram text-to-speech service for AWS SageMaker.
|
||||
|
||||
Provides real-time speech synthesis using Deepgram models deployed on
|
||||
AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
|
||||
audio generation with support for interruptions via the Clear message.
|
||||
|
||||
Requirements:
|
||||
|
||||
- AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
|
||||
- A deployed SageMaker endpoint with Deepgram TTS model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker
|
||||
- ``pipecat-ai[sagemaker]`` installed
|
||||
|
||||
Example::
|
||||
|
||||
tts = DeepgramSageMakerTTSService(
|
||||
endpoint_name="my-deepgram-tts-endpoint",
|
||||
region="us-east-2",
|
||||
voice="aura-2-helena-en",
|
||||
)
|
||||
"""
|
||||
|
||||
_settings: DeepgramSageMakerTTSSettings
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
endpoint_name: str,
|
||||
region: str,
|
||||
voice: str = "aura-2-helena-en",
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "linear16",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the SageMaker endpoint with Deepgram TTS model
|
||||
deployed (e.g., "my-deepgram-tts-endpoint").
|
||||
region: AWS region where the endpoint is deployed (e.g., "us-east-2").
|
||||
voice: Voice model to use for synthesis. Defaults to "aura-2-helena-en".
|
||||
sample_rate: Audio sample rate in Hz. If None, uses the value from StartFrame.
|
||||
encoding: Audio encoding format. Defaults to "linear16".
|
||||
**kwargs: Additional arguments passed to the parent TTSService.
|
||||
"""
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
append_trailing_space=True,
|
||||
settings=DeepgramSageMakerTTSSettings(
|
||||
model=voice,
|
||||
voice=voice,
|
||||
language=None,
|
||||
encoding=encoding,
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._endpoint_name = endpoint_name
|
||||
self._region = region
|
||||
|
||||
self._client: Optional[SageMakerBidiClient] = None
|
||||
self._response_task: Optional[asyncio.Task] = None
|
||||
self._context_id: Optional[str] = None
|
||||
self._ttfb_started: bool = False
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Deepgram SageMaker TTS service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Deepgram SageMaker TTS service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with special handling for LLM response end.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
await self.flush_audio()
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._ttfb_started = False
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to the SageMaker endpoint and start the BiDi session.
|
||||
|
||||
Builds the Deepgram TTS query string, creates the BiDi client,
|
||||
starts the streaming session, and launches a background task for processing
|
||||
responses.
|
||||
"""
|
||||
logger.debug("Connecting to Deepgram TTS on SageMaker...")
|
||||
|
||||
query_string = (
|
||||
f"model={self._settings.voice}&encoding={self._settings.encoding}"
|
||||
f"&sample_rate={self.sample_rate}"
|
||||
)
|
||||
|
||||
self._client = SageMakerBidiClient(
|
||||
endpoint_name=self._endpoint_name,
|
||||
region=self._region,
|
||||
model_invocation_path="v1/speak",
|
||||
model_query_string=query_string,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._client.start_session()
|
||||
|
||||
self._response_task = self.create_task(self._process_responses())
|
||||
|
||||
logger.debug("Connected to Deepgram TTS on SageMaker")
|
||||
await self._call_event_handler("on_connected")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
await self._call_event_handler("on_connection_error", str(e))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from the SageMaker endpoint.
|
||||
|
||||
Sends a Close message to Deepgram, cancels the response processing task,
|
||||
and closes the BiDi session. Safe to call multiple times.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
logger.debug("Disconnecting from Deepgram TTS on SageMaker...")
|
||||
|
||||
try:
|
||||
await self._client.send_json({"type": "Close"})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send Close message: {e}")
|
||||
|
||||
if self._response_task and not self._response_task.done():
|
||||
await self.cancel_task(self._response_task)
|
||||
|
||||
await self._client.close_session()
|
||||
|
||||
logger.debug("Disconnected from Deepgram TTS on SageMaker")
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]:
|
||||
"""Apply a settings delta and reconnect if necessary.
|
||||
|
||||
Since all settings are part of the SageMaker session query string,
|
||||
any setting change requires reconnecting to apply the new values.
|
||||
"""
|
||||
changed = await super()._update_settings(delta)
|
||||
|
||||
if not changed:
|
||||
return changed
|
||||
|
||||
# Deepgram uses voice as the model, so keep them in sync for metrics
|
||||
if "voice" in changed:
|
||||
self._settings.model = self._settings.voice
|
||||
self._sync_model_name_to_metrics()
|
||||
|
||||
# TODO: someday we could reconnect here to apply updated settings.
|
||||
# Code might look something like the below:
|
||||
# await self._disconnect()
|
||||
# await self._connect()
|
||||
|
||||
self._warn_unhandled_updated_settings(changed)
|
||||
|
||||
return changed
|
||||
|
||||
async def _process_responses(self):
|
||||
"""Process streaming responses from Deepgram TTS on SageMaker.
|
||||
|
||||
Continuously receives responses from the BiDi stream. Attempts to decode
|
||||
each payload as UTF-8 JSON for control messages (Flushed, Cleared, Metadata,
|
||||
Warning). If decoding fails, treats the payload as raw audio bytes and pushes
|
||||
a TTSAudioRawFrame downstream.
|
||||
"""
|
||||
try:
|
||||
while self._client and self._client.is_active:
|
||||
result = await self._client.receive_response()
|
||||
|
||||
if result is None:
|
||||
break
|
||||
|
||||
if hasattr(result, "value") and hasattr(result.value, "bytes_"):
|
||||
if result.value.bytes_:
|
||||
payload = result.value.bytes_
|
||||
|
||||
# Try to decode as JSON control message first
|
||||
try:
|
||||
response_data = payload.decode("utf-8")
|
||||
parsed = json.loads(response_data)
|
||||
msg_type = parsed.get("type")
|
||||
|
||||
if msg_type == "Metadata":
|
||||
logger.trace(f"Received metadata: {parsed}")
|
||||
elif msg_type == "Flushed":
|
||||
logger.trace(f"Received Flushed: {parsed}")
|
||||
elif msg_type == "Cleared":
|
||||
logger.trace(f"Received Cleared: {parsed}")
|
||||
elif msg_type == "Warning":
|
||||
logger.warning(
|
||||
f"{self} warning: "
|
||||
f"{parsed.get('description', 'Unknown warning')}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Received unknown message type: {parsed}")
|
||||
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
# Not JSON — treat as raw audio bytes
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
payload,
|
||||
self.sample_rate,
|
||||
1,
|
||||
context_id=self._context_id,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("TTS response processor cancelled")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
logger.debug("TTS response processor stopped")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by sending Clear message to Deepgram.
|
||||
|
||||
The Clear message will clear Deepgram's internal text buffer and stop
|
||||
sending audio, allowing for a new response to be generated.
|
||||
"""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
self._ttfb_started = False
|
||||
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Clear"})
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending Clear message: {e}")
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio synthesis by sending Flush command.
|
||||
|
||||
This should be called when the LLM finishes a complete response to force
|
||||
generation of audio from Deepgram's internal text buffer.
|
||||
"""
|
||||
if self._client and self._client.is_active:
|
||||
try:
|
||||
await self._client.send_json({"type": "Flush"})
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending Flush message: {e}")
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Deepgram TTS on SageMaker.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech.
|
||||
context_id: The context ID for tracking audio frames.
|
||||
|
||||
Yields:
|
||||
Frame: TTSStartedFrame, then None (audio comes asynchronously via
|
||||
the response processor).
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
if not self._ttfb_started:
|
||||
await self.start_ttfb_metrics()
|
||||
self._ttfb_started = True
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._context_id = context_id
|
||||
|
||||
await self._client.send_json({"type": "Speak", "text": text})
|
||||
|
||||
yield None
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
from pipecat.services.deepgram.sagemaker.tts import * # noqa: E402, F401, F403
|
||||
|
||||
Reference in New Issue
Block a user