Add stitching support and enhancements for NvidiaTTSService

This commit is contained in:
sathwika
2026-04-07 14:49:45 +05:30
parent 18852adc28
commit bc009d8f98

View File

@@ -4,16 +4,29 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""NVIDIA Riva text-to-speech service implementation.
"""NVIDIA Nemotron Speech text-to-speech service implementation.
This module provides integration with NVIDIA Riva's TTS services through
This module provides integration with NVIDIA Nemotron Speech's TTS services through
gRPC API for high-quality speech synthesis.
Refer to the NVIDIA TTS NIM documentation for usage, customization,
and local deployment steps:
https://docs.nvidia.com/nim/speech/latest/tts/
For zero-shot voice cloning, request access to the Magpie TTS Zero-Shot model:
https://developer.nvidia.com/riva-tts-zeroshot-models
Local or private cloud deployment is recommended for best zero-shot performance.
"""
import asyncio
import os
import queue
import textwrap
import threading
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, AsyncIterator, Generator, Mapping, Optional
from pathlib import Path
from typing import Any, AsyncGenerator, Mapping, Optional
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -24,21 +37,28 @@ from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language
try:
import grpc
import riva.client
import riva.client.proto.riva_tts_pb2 as rtts
from riva.client.proto.riva_audio_pb2 import AudioEncoding
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[nvidia]`.")
logger.error(
"In order to use NVIDIA Nemotron Speech TTS, you need to `pip install pipecat-ai[nvidia]`."
)
raise Exception(f"Missing module: {e}")
@@ -54,18 +74,19 @@ class NvidiaTTSSettings(TTSSettings):
class NvidiaTTSService(TTSService):
"""NVIDIA Riva text-to-speech service.
"""NVIDIA Nemotron Speech text-to-speech service.
Provides high-quality text-to-speech synthesis using NVIDIA Riva's
Provides high-quality text-to-speech synthesis using NVIDIA Nemotron Speech's
cloud-based TTS models. Supports multiple voices, languages, and
configurable quality settings.
"""
Settings = NvidiaTTSSettings
_settings: Settings
_MAX_CHUNK_LEN = 200
class InputParams(BaseModel):
"""Input parameters for Riva TTS configuration.
"""Input parameters for Nemotron Speech TTS configuration.
.. deprecated:: 0.0.105
Use ``NvidiaTTSService.Settings`` directly via the ``settings`` parameter instead.
@@ -81,7 +102,7 @@ class NvidiaTTSService(TTSService):
def __init__(
self,
*,
api_key: str,
api_key: Optional[str] = None,
server: str = "grpc.nvcf.nvidia.com:443",
voice_id: Optional[str] = None,
sample_rate: Optional[int] = None,
@@ -92,13 +113,19 @@ class NvidiaTTSService(TTSService):
params: Optional[InputParams] = None,
settings: Optional[Settings] = None,
use_ssl: bool = True,
custom_dictionary: Optional[dict] = None,
encoding: Optional[AudioEncoding] = AudioEncoding.LINEAR_PCM,
zero_shot_audio_prompt_file: Optional[Path] = None,
audio_prompt_encoding: Optional[AudioEncoding] = AudioEncoding.ENCODING_UNSPECIFIED,
**kwargs,
):
"""Initialize the NVIDIA Riva TTS service.
"""Initialize the NVIDIA Nemotron Speech TTS service.
Args:
api_key: NVIDIA API key for authentication.
api_key: NVIDIA API key for authentication. Required when using the
cloud endpoint. Not needed for local deployments.
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
For local deployments, pass the local address (e.g. ``localhost:50051``).
voice_id: Voice model identifier. Defaults to multilingual Aria voice.
.. deprecated:: 0.0.105
@@ -113,7 +140,21 @@ class NvidiaTTSService(TTSService):
settings: Runtime-updatable settings. When provided alongside deprecated
parameters, ``settings`` values take precedence.
use_ssl: Whether to use SSL for the NVIDIA Riva server. Defaults to True.
use_ssl: Whether to use SSL for the gRPC connection. Defaults to True
for the NVIDIA cloud endpoint. Set to False for local deployments.
custom_dictionary: Custom pronunciation dictionary mapping words
(graphemes) to IPA phonetic representations (phonemes),
e.g. ``{"NVIDIA": "ɛn.vɪ.diː"}``. See
https://docs.nvidia.com/nim/speech/latest/tts/phoneme-support.html
for the list of supported IPA phonemes.
encoding: Output audio encoding format. Defaults to ``AudioEncoding.LINEAR_PCM``.
zero_shot_audio_prompt_file: Path to audio prompt file for zero-shot voice
cloning. Audio length should be between 3-10 seconds. The file
is read once at init time and cached in memory. Requires the
Magpie TTS Zero-Shot model. See
https://docs.nvidia.com/nim/speech/latest/tts/voice-cloning.html
audio_prompt_encoding: Encoding of the zero-shot audio prompt file.
Defaults to ``AudioEncoding.ENCODING_UNSPECIFIED``.
**kwargs: Additional arguments passed to parent TTSService.
"""
# 1. Initialize default_settings with hardcoded defaults
@@ -144,7 +185,7 @@ class NvidiaTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_start_frame=False,
push_stop_frames=True,
settings=default_settings,
**kwargs,
@@ -155,18 +196,55 @@ class NvidiaTTSService(TTSService):
self._function_id = model_function_map.get("function_id")
self._use_ssl = use_ssl
self._custom_dictionary: Optional[str] = None
if custom_dictionary:
entries = [f"{k} {v}" for k, v in custom_dictionary.items()]
self._custom_dictionary = ",".join(entries)
self._encoding = encoding
self._audio_prompt_encoding = audio_prompt_encoding
self._zero_shot_audio_prompt_file = zero_shot_audio_prompt_file
self._zero_shot_audio_prompt: Optional[bytes] = None
if self._zero_shot_audio_prompt_file is not None:
if not self._zero_shot_audio_prompt_file.exists():
raise FileNotFoundError(
f"Zero-shot audio prompt file not found: {self._zero_shot_audio_prompt_file}"
)
with self._zero_shot_audio_prompt_file.open("rb") as f:
self._zero_shot_audio_prompt = f.read()
logger.debug(
f"Loaded zero-shot audio prompt from {self._zero_shot_audio_prompt_file} "
f"({len(self._zero_shot_audio_prompt)} bytes)"
)
self._service = None
self._config = None
# Persistent gRPC stream state for cross-sentence stitching
self._text_queue: Optional[queue.Queue] = None
self._synth_thread: Optional[threading.Thread] = None
self._response_task: Optional[asyncio.Task] = None
self._response_queue: asyncio.Queue = asyncio.Queue()
self._active_context_id: Optional[str] = None
def can_generate_metrics(self) -> bool:
    """Check if this service can generate metrics.

    Returns:
        Always ``True``: this service supports metric generation.
    """
    return True
async def set_model(self, model: str):
"""Set the TTS model.
.. deprecated:: 0.0.104
Model cannot be changed after initialization for NVIDIA Riva TTS.
Model cannot be changed after initialization for NVIDIA Nemotron Speech TTS.
Set model and function id in the constructor instead.
Example::
NvidiaTTSService(
api_key=...,
model_function_map={"function_id": "<UUID>", "model_name": "<model_name>"},
@@ -181,7 +259,7 @@ class NvidiaTTSService(TTSService):
warnings.simplefilter("always")
warnings.warn(
"'set_model' is deprecated. Model cannot be changed after initialization"
" for NVIDIA Riva TTS. Set model and function id in the constructor"
" for NVIDIA Nemotron Speech TTS. Set model and function id in the constructor"
" instead, e.g.: NvidiaTTSService(api_key=..., model_function_map="
"{'function_id': '<UUID>', 'model_name': '<model_name>'})",
DeprecationWarning,
@@ -191,13 +269,13 @@ class NvidiaTTSService(TTSService):
async def _update_settings(self, delta: Settings) -> dict[str, Any]:
"""Apply a settings delta.
Settings are stored but not applied to the active connection.
Settings are stored and will take effect on the next synthesis turn.
Mid-stream changes cannot be applied to the active gRPC connection.
"""
changed = await super()._update_settings(delta)
if not changed:
return changed
# TODO: reconnect gRPC client to apply changed settings.
self._warn_unhandled_updated_settings(changed)
if changed:
fields = ", ".join(sorted(changed))
logger.debug(f"{self.name}: settings updated [{fields}], will apply on next turn")
return changed
def _initialize_client(self):
@@ -216,14 +294,21 @@ class NvidiaTTSService(TTSService):
if not self._service:
return
# warm up the service
config = self._service.stub.GetRivaSynthesisConfig(
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
)
return config
try:
config = self._service.stub.GetRivaSynthesisConfig(
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
)
return config
except grpc.RpcError as e:
status = e.code().name if hasattr(e, "code") else "UNKNOWN"
details = e.details() if hasattr(e, "details") else str(e)
logger.error(
f"{self} failed to get synthesis config from server (gRPC {status}): {details}"
)
return None
async def start(self, frame: StartFrame):
"""Start the Cartesia TTS service.
"""Start the NVIDIA Nemotron Speech TTS service.
Args:
frame: The start frame containing initialization parameters.
@@ -233,65 +318,222 @@ class NvidiaTTSService(TTSService):
self._config = self._create_synthesis_config()
logger.debug(f"Initialized NvidiaTTSService with model: {self._settings.model}")
async def stop(self, frame: EndFrame):
    """Stop the NVIDIA Nemotron Speech TTS service and clean up resources.

    Runs the parent class's stop handling first, then closes the synthesis
    stream through ``_close_synthesis_stream``.

    Args:
        frame: EndFrame indicating pipeline stop.
    """
    await super().stop(frame)
    await self._close_synthesis_stream()
async def cancel(self, frame: CancelFrame):
    """Cancel the NVIDIA Nemotron Speech TTS service operation.

    Runs the parent class's cancel handling first, then closes the synthesis
    stream through ``_close_synthesis_stream``.

    Args:
        frame: CancelFrame indicating operation cancellation.
    """
    await super().cancel(frame)
    await self._close_synthesis_stream()
def _start_synthesis_stream(self, context_id: str):
    """Open a persistent gRPC synthesis stream for the current turn.

    Resets the per-stream state (text queue, response queue, active context
    id), launches the background thread that runs
    ``_synthesis_thread_handler`` and schedules the ``_response_consumer``
    task. Text pushed into the text queue is fed to the stream until a
    ``None`` sentinel is pushed, which ends the request generator.

    Args:
        context_id: Audio context that synthesized audio will be attached to.
    """
    # Fresh state for this stream; the worker thread and the consumer task
    # both read these attributes.
    self._active_context_id = context_id
    self._text_queue = queue.Queue()
    self._response_queue = asyncio.Queue()

    # Blocking gRPC iteration happens off the event loop, in a daemon thread.
    worker = threading.Thread(
        name="nvidia-tts-synth",
        target=self._synthesis_thread_handler,
        daemon=True,
    )
    self._synth_thread = worker
    worker.start()

    # The consumer runs on the event loop and forwards audio to the context.
    consumer = self.create_task(self._response_consumer(), name="nvidia-tts-response")
    self._response_task = consumer
def _build_base_request(self) -> rtts.SynthesizeSpeechRequest:
    """Build a reusable ``SynthesizeSpeechRequest`` with current settings.

    The request is created with empty text; callers fill in ``text`` per
    chunk. Voice, zero-shot prompt data and the custom dictionary are only
    set when configured.

    Returns:
        A ``SynthesizeSpeechRequest`` populated from the current settings.
    """
    req = rtts.SynthesizeSpeechRequest(
        text="",
        language_code=str(self._settings.language or "en-US"),
        sample_rate_hz=self.sample_rate,
        encoding=self._encoding,
    )

    voice = self._settings.voice
    if voice:
        req.voice_name = voice

    if self._zero_shot_audio_prompt is not None:
        zero_shot = req.zero_shot_data
        zero_shot.audio_prompt = self._zero_shot_audio_prompt
        # ``audio_prompt_encoding`` is declared Optional; assigning None to a
        # protobuf enum field raises TypeError, so leave the field at its
        # default when no encoding was given.
        if self._audio_prompt_encoding is not None:
            zero_shot.encoding = self._audio_prompt_encoding
        # NOTE(review): assumes settings.quality is always an int here —
        # confirm it can never be None.
        zero_shot.quality = self._settings.quality

    if self._custom_dictionary:
        req.custom_dictionary = self._custom_dictionary

    return req
def _synthesis_thread_handler(self):
    """Run ``SynthesizeOnline`` gRPC stream in a background thread.

    Builds request objects directly to avoid a Python 3.12 compatibility
    bug in ``riva.client.SpeechSynthesisService.synthesize_online``.

    Every response is forwarded to the response queue on the event loop.
    On failure the exception object is forwarded instead. A ``None``
    sentinel is always forwarded last so the consumer can terminate.
    """
    # Snapshot per-stream state once. Re-reading these attributes later would
    # let this thread pick up the queues of a newer stream if one is started
    # before this thread exits.
    text_queue = self._text_queue
    response_queue = self._response_queue
    loop = self.get_event_loop()

    def publish(item):
        # asyncio.Queue is not thread-safe; hand the put over to the loop.
        asyncio.run_coroutine_threadsafe(response_queue.put(item), loop)

    def request_generator(base_req):
        # Blocks on the text queue; a None sentinel ends the request stream.
        for text in iter(text_queue.get, None):
            # NOTE(review): the same message object is re-yielded with new
            # text; this relies on gRPC consuming each request before asking
            # for the next one — confirm.
            base_req.text = text
            yield base_req

    try:
        # Built inside the try block so that a failure here still reaches
        # the except/finally below; otherwise the sentinel would never be
        # sent and the response consumer would wait forever.
        base_req = self._build_base_request()
        responses = self._service.stub.SynthesizeOnline(
            request_generator(base_req),
            metadata=self._service.auth.get_auth_metadata(),
        )
        for resp in responses:
            publish(resp)
    except Exception as e:
        logger.error(f"{self} gRPC synthesis stream error: {e}")
        publish(e)
    finally:
        publish(None)
async def _response_consumer(self):
    """Consume gRPC responses and append audio to the active audio context.

    Reads items from the response queue until a ``None`` sentinel arrives.
    An ``Exception`` item is reported through ``push_error`` and ends the
    loop. Any other item is treated as a synthesis response whose ``audio``
    is wrapped in a ``TTSAudioRawFrame`` and appended to the audio context.
    """
    # Bind to the queue and context that were active when this consumer was
    # started, so a later stream restart cannot redirect this consumer to
    # another stream's responses or context id.
    responses = self._response_queue
    context_id = self._active_context_id

    while True:
        item = await responses.get()
        if item is None:
            break
        if isinstance(item, Exception):
            await self.push_error(f"{self} synthesis error: {item}")
            break

        await self.stop_ttfb_metrics()
        frame = TTSAudioRawFrame(
            audio=item.audio,
            sample_rate=self.sample_rate,
            num_channels=1,
            context_id=context_id,
        )
        await self.append_to_audio_context(context_id, frame)
async def _close_synthesis_stream(self):
    """Close the active gRPC synthesis stream.

    Sends a sentinel to end the request generator, waits for the gRPC
    thread to finish producing all remaining audio, then lets the
    response consumer drain naturally before cleaning up.

    Shared attributes are only reset if they still refer to the stream that
    this call set out to close, so a stream started while this coroutine was
    waiting is left untouched.
    """
    # Capture the handles of the stream being closed before any await.
    text_queue = self._text_queue
    synth_thread = self._synth_thread
    response_task = self._response_task

    if text_queue is not None:
        # Sentinel: ends the request generator once queued text is consumed.
        text_queue.put(None)

    if synth_thread is not None:
        # join() blocks, so run it off the event loop.
        await asyncio.to_thread(synth_thread.join)
        if self._synth_thread is synth_thread:
            self._synth_thread = None
        if self._text_queue is text_queue:
            self._text_queue = None

    if response_task is not None:
        try:
            await response_task
        except asyncio.CancelledError:
            pass
        if self._response_task is response_task:
            self._response_task = None

    # Drop the context id only when no stream is active any more.
    if self._synth_thread is None and self._response_task is None:
        self._active_context_id = None
async def flush_audio(self, context_id: Optional[str] = None):
    """Flush the synthesis stream at the end of an LLM turn.

    Sends a sentinel to the gRPC stream, waits for remaining audio,
    then delegates to the base class for audio context cleanup.

    Args:
        context_id: The audio context to flush.
    """
    # Close first so the stream has finished before the base class runs its
    # own flush handling.
    await self._close_synthesis_stream()
    await super().flush_audio(context_id)
async def on_audio_context_interrupted(self, context_id: str):
    """Handle interruption by closing the active synthesis stream.

    Stops all metrics, then closes the stream through
    ``_close_synthesis_stream``.

    Args:
        context_id: Identifier of the interrupted audio context. It is not
            used directly here; whichever stream is currently active is closed.
    """
    # NOTE(review): _close_synthesis_stream joins the synthesis thread, so
    # this waits for the stream to end rather than aborting it — confirm the
    # added latency on interruption is acceptable.
    await self.stop_all_metrics()
    await self._close_synthesis_stream()
@staticmethod
def _split_text_into_chunks(text: str) -> list[str]:
    """Split text into chunks of at most ``_MAX_CHUNK_LEN`` characters.

    Chunks are broken at whitespace where possible. Hyphens are not used as
    break points, and a single word longer than the limit is split across
    chunks. Magpie stitches chunks seamlessly in the gRPC stream, so
    splitting conservatively avoids max char limits without affecting audio
    quality.

    Args:
        text: Input text to split.

    Returns:
        List of text chunks, each at most ``_MAX_CHUNK_LEN`` characters.
        Empty when ``text`` is empty or only whitespace.
    """
    stripped = text.strip()
    if stripped == "":
        return []

    wrapper = textwrap.TextWrapper(
        width=NvidiaTTSService._MAX_CHUNK_LEN,
        break_long_words=True,
        break_on_hyphens=False,
    )
    return wrapper.wrap(stripped)
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using NVIDIA Riva TTS.
"""Generate speech from text using NVIDIA Nemotron Speech TTS.
On the first call for a turn, starts a persistent ``synthesize_online``
gRPC stream. Subsequent calls within the same turn feed text into the
existing stream, enabling Magpie's cross-sentence stitching.
Text is split into chunks respecting Magpie's per-request limits. Each chunk becomes
a separate request in the gRPC stream, stitched seamlessly by Magpie.
Args:
text: The text to synthesize into speech.
context_id: The context ID for tracking audio frames.
Yields:
Frame: Audio frames containing the synthesized speech data.
None on success. Audio is delivered asynchronously via the
response consumer. ErrorFrame on failure.
"""
def read_audio_responses() -> Generator[rtts.SynthesizeSpeechResponse, None, None]:
responses = self._service.synthesize_online(
text,
self._settings.voice,
self._settings.language,
sample_rate_hz=self.sample_rate,
zero_shot_audio_prompt_file=None,
zero_shot_quality=self._settings.quality,
custom_dictionary={},
)
return responses
def async_next(it):
try:
return next(it)
except StopIteration:
return None
async def async_iterator(iterator) -> AsyncIterator[rtts.SynthesizeSpeechResponse]:
while True:
item = await asyncio.to_thread(async_next, iterator)
if item is None:
return
yield item
text = text.strip()
if not text or not any(c.isalnum() for c in text):
return
try:
assert self._service is not None, "TTS service not initialized"
assert self._config is not None, "Synthesis configuration not created"
# First call for this turn: create audio context and start gRPC stream
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
self._start_synthesis_stream(context_id)
logger.trace(f"{self}: Started synthesis stream for context {context_id}")
logger.debug(f"{self}: Generating TTS [{text}]")
responses = await asyncio.to_thread(read_audio_responses)
async for resp in async_iterator(responses):
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=resp.audio,
sample_rate=self.sample_rate,
num_channels=1,
context_id=context_id,
)
yield frame
for chunk in self._split_text_into_chunks(text):
if any(c.isalnum() for c in chunk):
self._text_queue.put(chunk)
await self.start_tts_usage_metrics(text)
except asyncio.TimeoutError as e:
logger.error(f"{self} timeout waiting for audio response")
yield ErrorFrame(error=f"{self} error: {e}")
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"{self} error: {e}")