Add stitching support and enhancements for NvidiaTTSService
This commit is contained in:
@@ -4,16 +4,29 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""NVIDIA Riva text-to-speech service implementation.
|
||||
"""NVIDIA Nemotron Speech text-to-speech service implementation.
|
||||
|
||||
This module provides integration with NVIDIA Riva's TTS services through
|
||||
This module provides integration with NVIDIA Nemotron Speech's TTS services through
|
||||
gRPC API for high-quality speech synthesis.
|
||||
|
||||
Refer to the NVIDIA TTS NIM documentation for usage, customization,
|
||||
and local deployment steps:
|
||||
https://docs.nvidia.com/nim/speech/latest/tts/
|
||||
|
||||
For zero-shot voice cloning, request access to the Magpie TTS Zero-Shot model:
|
||||
https://developer.nvidia.com/riva-tts-zeroshot-models
|
||||
|
||||
Local or private cloud deployment is recommended for best zero-shot performance.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import queue
|
||||
import textwrap
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Generator, Mapping, Optional
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncGenerator, Mapping, Optional
|
||||
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -24,21 +37,28 @@ from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
)
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import grpc
|
||||
import riva.client
|
||||
import riva.client.proto.riva_tts_pb2 as rtts
|
||||
from riva.client.proto.riva_audio_pb2 import AudioEncoding
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[nvidia]`.")
|
||||
logger.error(
|
||||
"In order to use NVIDIA Nemotron Speech TTS, you need to `pip install pipecat-ai[nvidia]`."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@@ -54,18 +74,19 @@ class NvidiaTTSSettings(TTSSettings):
|
||||
|
||||
|
||||
class NvidiaTTSService(TTSService):
|
||||
"""NVIDIA Riva text-to-speech service.
|
||||
"""NVIDIA Nemotron Speech text-to-speech service.
|
||||
|
||||
Provides high-quality text-to-speech synthesis using NVIDIA Riva's
|
||||
Provides high-quality text-to-speech synthesis using NVIDIA Nemotron Speech's
|
||||
cloud-based TTS models. Supports multiple voices, languages, and
|
||||
configurable quality settings.
|
||||
"""
|
||||
|
||||
Settings = NvidiaTTSSettings
|
||||
_settings: Settings
|
||||
_MAX_CHUNK_LEN = 200
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Input parameters for Riva TTS configuration.
|
||||
"""Input parameters for Nemotron Speech TTS configuration.
|
||||
|
||||
.. deprecated:: 0.0.105
|
||||
Use ``NvidiaTTSService.Settings`` directly via the ``settings`` parameter instead.
|
||||
@@ -81,7 +102,7 @@ class NvidiaTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
api_key: Optional[str] = None,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
voice_id: Optional[str] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
@@ -92,13 +113,19 @@ class NvidiaTTSService(TTSService):
|
||||
params: Optional[InputParams] = None,
|
||||
settings: Optional[Settings] = None,
|
||||
use_ssl: bool = True,
|
||||
custom_dictionary: Optional[dict] = None,
|
||||
encoding: Optional[AudioEncoding] = AudioEncoding.LINEAR_PCM,
|
||||
zero_shot_audio_prompt_file: Optional[Path] = None,
|
||||
audio_prompt_encoding: Optional[AudioEncoding] = AudioEncoding.ENCODING_UNSPECIFIED,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva TTS service.
|
||||
"""Initialize the NVIDIA Nemotron Speech TTS service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
api_key: NVIDIA API key for authentication. Required when using the
|
||||
cloud endpoint. Not needed for local deployments.
|
||||
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
|
||||
For local deployments, pass the local address (e.g. ``localhost:50051``).
|
||||
voice_id: Voice model identifier. Defaults to multilingual Aria voice.
|
||||
|
||||
.. deprecated:: 0.0.105
|
||||
@@ -113,7 +140,21 @@ class NvidiaTTSService(TTSService):
|
||||
|
||||
settings: Runtime-updatable settings. When provided alongside deprecated
|
||||
parameters, ``settings`` values take precedence.
|
||||
use_ssl: Whether to use SSL for the NVIDIA Riva server. Defaults to True.
|
||||
use_ssl: Whether to use SSL for the gRPC connection. Defaults to True
|
||||
for the NVIDIA cloud endpoint. Set to False for local deployments.
|
||||
custom_dictionary: Custom pronunciation dictionary mapping words
|
||||
(graphemes) to IPA phonetic representations (phonemes),
|
||||
e.g. ``{"NVIDIA": "ɛn.vɪ.diː.ʌ"}``. See
|
||||
https://docs.nvidia.com/nim/speech/latest/tts/phoneme-support.html
|
||||
for the list of supported IPA phonemes.
|
||||
encoding: Output audio encoding format. Defaults to ``AudioEncoding.LINEAR_PCM``.
|
||||
zero_shot_audio_prompt_file: Path to audio prompt file for zero-shot voice
|
||||
cloning. Audio length should be between 3-10 seconds. The file
|
||||
is read once at init time and cached in memory. Requires the
|
||||
Magpie TTS Zero-Shot model. See
|
||||
https://docs.nvidia.com/nim/speech/latest/tts/voice-cloning.html
|
||||
audio_prompt_encoding: Encoding of the zero-shot audio prompt file.
|
||||
Defaults to ``AudioEncoding.ENCODING_UNSPECIFIED``.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
"""
|
||||
# 1. Initialize default_settings with hardcoded defaults
|
||||
@@ -144,7 +185,7 @@ class NvidiaTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_start_frame=False,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
@@ -155,18 +196,55 @@ class NvidiaTTSService(TTSService):
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
self._use_ssl = use_ssl
|
||||
|
||||
self._custom_dictionary: Optional[str] = None
|
||||
if custom_dictionary:
|
||||
entries = [f"{k} {v}" for k, v in custom_dictionary.items()]
|
||||
self._custom_dictionary = ",".join(entries)
|
||||
self._encoding = encoding
|
||||
self._audio_prompt_encoding = audio_prompt_encoding
|
||||
|
||||
self._zero_shot_audio_prompt_file = zero_shot_audio_prompt_file
|
||||
self._zero_shot_audio_prompt: Optional[bytes] = None
|
||||
if self._zero_shot_audio_prompt_file is not None:
|
||||
if not self._zero_shot_audio_prompt_file.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Zero-shot audio prompt file not found: {self._zero_shot_audio_prompt_file}"
|
||||
)
|
||||
with self._zero_shot_audio_prompt_file.open("rb") as f:
|
||||
self._zero_shot_audio_prompt = f.read()
|
||||
logger.debug(
|
||||
f"Loaded zero-shot audio prompt from {self._zero_shot_audio_prompt_file} "
|
||||
f"({len(self._zero_shot_audio_prompt)} bytes)"
|
||||
)
|
||||
|
||||
self._service = None
|
||||
self._config = None
|
||||
|
||||
# Persistent gRPC stream state for cross-sentence stitching
|
||||
self._text_queue: Optional[queue.Queue] = None
|
||||
self._synth_thread: Optional[threading.Thread] = None
|
||||
self._response_task: Optional[asyncio.Task] = None
|
||||
self._response_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._active_context_id: Optional[str] = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate metrics.
|
||||
|
||||
Returns:
|
||||
True as this service supports metric generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the TTS model.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Model cannot be changed after initialization for NVIDIA Riva TTS.
|
||||
Model cannot be changed after initialization for NVIDIA Nemotron Speech TTS.
|
||||
Set model and function id in the constructor instead.
|
||||
|
||||
Example::
|
||||
|
||||
|
||||
NvidiaTTSService(
|
||||
api_key=...,
|
||||
model_function_map={"function_id": "<UUID>", "model_name": "<model_name>"},
|
||||
@@ -181,7 +259,7 @@ class NvidiaTTSService(TTSService):
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"'set_model' is deprecated. Model cannot be changed after initialization"
|
||||
" for NVIDIA Riva TTS. Set model and function id in the constructor"
|
||||
" for NVIDIA Nemotron Speech TTS. Set model and function id in the constructor"
|
||||
" instead, e.g.: NvidiaTTSService(api_key=..., model_function_map="
|
||||
"{'function_id': '<UUID>', 'model_name': '<model_name>'})",
|
||||
DeprecationWarning,
|
||||
@@ -191,13 +269,13 @@ class NvidiaTTSService(TTSService):
|
||||
async def _update_settings(self, delta: Settings) -> dict[str, Any]:
|
||||
"""Apply a settings delta.
|
||||
|
||||
Settings are stored but not applied to the active connection.
|
||||
Settings are stored and will take effect on the next synthesis turn.
|
||||
Mid-stream changes cannot be applied to the active gRPC connection.
|
||||
"""
|
||||
changed = await super()._update_settings(delta)
|
||||
if not changed:
|
||||
return changed
|
||||
# TODO: reconnect gRPC client to apply changed settings.
|
||||
self._warn_unhandled_updated_settings(changed)
|
||||
if changed:
|
||||
fields = ", ".join(sorted(changed))
|
||||
logger.debug(f"{self.name}: settings updated [{fields}], will apply on next turn")
|
||||
return changed
|
||||
|
||||
def _initialize_client(self):
|
||||
@@ -216,14 +294,21 @@ class NvidiaTTSService(TTSService):
|
||||
if not self._service:
|
||||
return
|
||||
|
||||
# warm up the service
|
||||
config = self._service.stub.GetRivaSynthesisConfig(
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
return config
|
||||
try:
|
||||
config = self._service.stub.GetRivaSynthesisConfig(
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
return config
|
||||
except grpc.RpcError as e:
|
||||
status = e.code().name if hasattr(e, "code") else "UNKNOWN"
|
||||
details = e.details() if hasattr(e, "details") else str(e)
|
||||
logger.error(
|
||||
f"{self} failed to get synthesis config from server (gRPC {status}): {details}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Cartesia TTS service.
|
||||
"""Start the NVIDIA Nemotron Speech TTS service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
@@ -233,65 +318,222 @@ class NvidiaTTSService(TTSService):
|
||||
self._config = self._create_synthesis_config()
|
||||
logger.debug(f"Initialized NvidiaTTSService with model: {self._settings.model}")
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the NVIDIA Nemotron Speech TTS service and clean up resources.
|
||||
|
||||
Args:
|
||||
frame: EndFrame indicating pipeline stop.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._close_synthesis_stream()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the NVIDIA Nemotron Speech TTS service operation.
|
||||
|
||||
Args:
|
||||
frame: CancelFrame indicating operation cancellation.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._close_synthesis_stream()
|
||||
|
||||
def _start_synthesis_stream(self, context_id: str):
|
||||
"""Start a persistent gRPC synthesis stream for the current turn.
|
||||
|
||||
Creates a queue-backed generator that feeds text to
|
||||
``synthesize_online``. The gRPC stream stays open until a ``None``
|
||||
sentinel is pushed into the queue.
|
||||
"""
|
||||
self._text_queue = queue.Queue()
|
||||
self._active_context_id = context_id
|
||||
self._response_queue = asyncio.Queue()
|
||||
|
||||
self._synth_thread = threading.Thread(
|
||||
target=self._synthesis_thread_handler,
|
||||
daemon=True,
|
||||
name="nvidia-tts-synth",
|
||||
)
|
||||
self._synth_thread.start()
|
||||
self._response_task = self.create_task(
|
||||
self._response_consumer(), name="nvidia-tts-response"
|
||||
)
|
||||
|
||||
def _build_base_request(self) -> rtts.SynthesizeSpeechRequest:
|
||||
"""Build a reusable ``SynthesizeSpeechRequest`` with current settings."""
|
||||
req = rtts.SynthesizeSpeechRequest(
|
||||
text="",
|
||||
language_code=str(self._settings.language or "en-US"),
|
||||
sample_rate_hz=self.sample_rate,
|
||||
encoding=self._encoding,
|
||||
)
|
||||
voice = self._settings.voice
|
||||
if voice:
|
||||
req.voice_name = voice
|
||||
if self._zero_shot_audio_prompt is not None:
|
||||
req.zero_shot_data.audio_prompt = self._zero_shot_audio_prompt
|
||||
req.zero_shot_data.encoding = self._audio_prompt_encoding
|
||||
req.zero_shot_data.quality = self._settings.quality
|
||||
if self._custom_dictionary:
|
||||
req.custom_dictionary = self._custom_dictionary
|
||||
return req
|
||||
|
||||
def _synthesis_thread_handler(self):
|
||||
"""Run ``SynthesizeOnline`` gRPC stream in a background thread.
|
||||
|
||||
Builds request objects directly to avoid a Python 3.12 compatibility
|
||||
bug in ``riva.client.SpeechSynthesisService.synthesize_online``.
|
||||
"""
|
||||
base_req = self._build_base_request()
|
||||
|
||||
def request_generator():
|
||||
while True:
|
||||
text = self._text_queue.get()
|
||||
if text is None:
|
||||
break
|
||||
base_req.text = text
|
||||
yield base_req
|
||||
|
||||
try:
|
||||
responses = self._service.stub.SynthesizeOnline(
|
||||
request_generator(),
|
||||
metadata=self._service.auth.get_auth_metadata(),
|
||||
)
|
||||
for resp in responses:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._response_queue.put(resp), self.get_event_loop()
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} gRPC synthesis stream error: {e}")
|
||||
asyncio.run_coroutine_threadsafe(self._response_queue.put(e), self.get_event_loop())
|
||||
finally:
|
||||
asyncio.run_coroutine_threadsafe(self._response_queue.put(None), self.get_event_loop())
|
||||
|
||||
async def _response_consumer(self):
|
||||
"""Consume gRPC responses and append audio to the active audio context."""
|
||||
while True:
|
||||
item = await self._response_queue.get()
|
||||
if item is None:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
await self.push_error(f"{self} synthesis error: {item}")
|
||||
break
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=item.audio,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=self._active_context_id,
|
||||
)
|
||||
await self.append_to_audio_context(self._active_context_id, frame)
|
||||
|
||||
async def _close_synthesis_stream(self):
|
||||
"""Close the active gRPC synthesis stream.
|
||||
|
||||
Sends a sentinel to end the request generator, waits for the gRPC
|
||||
thread to finish producing all remaining audio, then lets the
|
||||
response consumer drain naturally before cleaning up.
|
||||
"""
|
||||
if self._text_queue is not None:
|
||||
self._text_queue.put(None)
|
||||
|
||||
if self._synth_thread is not None:
|
||||
await asyncio.to_thread(self._synth_thread.join)
|
||||
self._synth_thread = None
|
||||
|
||||
self._text_queue = None
|
||||
|
||||
if self._response_task is not None:
|
||||
try:
|
||||
await self._response_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._response_task = None
|
||||
|
||||
self._active_context_id = None
|
||||
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush the synthesis stream at the end of an LLM turn.
|
||||
|
||||
Sends a sentinel to the gRPC stream, waits for remaining audio,
|
||||
then delegates to the base class for audio context cleanup.
|
||||
|
||||
Args:
|
||||
context_id: The audio context to flush.
|
||||
"""
|
||||
await self._close_synthesis_stream()
|
||||
await super().flush_audio(context_id)
|
||||
|
||||
async def on_audio_context_interrupted(self, context_id: str):
|
||||
"""Handle interruption by closing the active synthesis stream."""
|
||||
await self.stop_all_metrics()
|
||||
await self._close_synthesis_stream()
|
||||
|
||||
@staticmethod
|
||||
def _split_text_into_chunks(text: str) -> list[str]:
|
||||
"""Split text into <=200-character chunks at whitespace boundaries.
|
||||
|
||||
Magpie stitches chunks seamlessly in the gRPC stream, so splitting
|
||||
conservatively at 200 chars avoids max char limits without affecting audio
|
||||
quality.
|
||||
|
||||
Args:
|
||||
text: Input text to split.
|
||||
|
||||
Returns:
|
||||
List of text chunks, each at most 200 characters.
|
||||
"""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
return textwrap.wrap(
|
||||
text,
|
||||
width=NvidiaTTSService._MAX_CHUNK_LEN,
|
||||
break_long_words=True,
|
||||
break_on_hyphens=False,
|
||||
)
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using NVIDIA Riva TTS.
|
||||
"""Generate speech from text using NVIDIA Nemotron Speech TTS.
|
||||
|
||||
On the first call for a turn, starts a persistent ``synthesize_online``
|
||||
gRPC stream. Subsequent calls within the same turn feed text into the
|
||||
existing stream, enabling Magpie's cross-sentence stitching.
|
||||
|
||||
Text is split into chunks respecting Magpie's per-request limits. Each chunk becomes
|
||||
a separate request in the gRPC stream, stitched seamlessly by Magpie.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech.
|
||||
context_id: The context ID for tracking audio frames.
|
||||
|
||||
Yields:
|
||||
Frame: Audio frames containing the synthesized speech data.
|
||||
None on success. Audio is delivered asynchronously via the
|
||||
response consumer. ErrorFrame on failure.
|
||||
"""
|
||||
|
||||
def read_audio_responses() -> Generator[rtts.SynthesizeSpeechResponse, None, None]:
|
||||
responses = self._service.synthesize_online(
|
||||
text,
|
||||
self._settings.voice,
|
||||
self._settings.language,
|
||||
sample_rate_hz=self.sample_rate,
|
||||
zero_shot_audio_prompt_file=None,
|
||||
zero_shot_quality=self._settings.quality,
|
||||
custom_dictionary={},
|
||||
)
|
||||
return responses
|
||||
|
||||
def async_next(it):
|
||||
try:
|
||||
return next(it)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
async def async_iterator(iterator) -> AsyncIterator[rtts.SynthesizeSpeechResponse]:
|
||||
while True:
|
||||
item = await asyncio.to_thread(async_next, iterator)
|
||||
if item is None:
|
||||
return
|
||||
yield item
|
||||
text = text.strip()
|
||||
if not text or not any(c.isalnum() for c in text):
|
||||
return
|
||||
|
||||
try:
|
||||
assert self._service is not None, "TTS service not initialized"
|
||||
assert self._config is not None, "Synthesis configuration not created"
|
||||
|
||||
# First call for this turn: create audio context and start gRPC stream
|
||||
if not self.audio_context_available(context_id):
|
||||
await self.create_audio_context(context_id)
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._start_synthesis_stream(context_id)
|
||||
logger.trace(f"{self}: Started synthesis stream for context {context_id}")
|
||||
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
responses = await asyncio.to_thread(read_audio_responses)
|
||||
|
||||
async for resp in async_iterator(responses):
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=resp.audio,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=context_id,
|
||||
)
|
||||
yield frame
|
||||
for chunk in self._split_text_into_chunks(text):
|
||||
if any(c.isalnum() for c in chunk):
|
||||
self._text_queue.put(chunk)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except asyncio.TimeoutError as e:
|
||||
logger.error(f"{self} timeout waiting for audio response")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
Reference in New Issue
Block a user