From 05d65dfdd30c192979ef693bb719ff78a0f5d7a0 Mon Sep 17 00:00:00 2001 From: vipyne Date: Tue, 25 Nov 2025 13:41:26 -0600 Subject: [PATCH] Update NVIDIA NIM and Riva services to Nvidia - pip install pipecat-ai[nim] - pip install pipecat-ai[riva] + pip install pipecat-ai[nvidia] and - from pipecat.services.nim.llm import NimLLMService + from pipecat.services.nvidia.llm import NvidiaLLMService - from pipecat.services.riva.stt import RivaSTTService + from pipecat.services.nvidia.stt import NvidiaSTTService - from pipecat.services.riva.tts import RivaTTSService + from pipecat.services.nvidia.tts import NvidiaTTSService --- COMMUNITY_INTEGRATIONS.md | 2 +- pyproject.toml | 1 + src/pipecat/services/nim/llm.py | 4 +- src/pipecat/services/nvidia/__init__.py | 16 + src/pipecat/services/nvidia/llm.py | 105 ++++ src/pipecat/services/nvidia/stt.py | 712 ++++++++++++++++++++++++ src/pipecat/services/nvidia/tts.py | 239 ++++++++ src/pipecat/services/riva/stt.py | 16 +- src/pipecat/services/riva/tts.py | 16 +- 9 files changed, 1092 insertions(+), 19 deletions(-) create mode 100644 src/pipecat/services/nvidia/__init__.py create mode 100644 src/pipecat/services/nvidia/llm.py create mode 100644 src/pipecat/services/nvidia/stt.py create mode 100644 src/pipecat/services/nvidia/tts.py diff --git a/COMMUNITY_INTEGRATIONS.md b/COMMUNITY_INTEGRATIONS.md index 080d75ef2..a26836a52 100644 --- a/COMMUNITY_INTEGRATIONS.md +++ b/COMMUNITY_INTEGRATIONS.md @@ -79,7 +79,7 @@ Once your PR is submitted, post in the `#community-integrations` Discord channel **Examples:** -- [RivaSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/riva/stt.py) +- [NvidiaSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/nvidia/stt.py) - [FalSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/fal/stt.py) #### Key requirements: diff --git a/pyproject.toml b/pyproject.toml index 97552b708..14570a60d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,6 +94,7 @@ qwen = [] remote-smart-turn = [] rime = [ "pipecat-ai[websockets-base]" ] riva = [ "nvidia-riva-client~=2.21.1" ] +nvidia = [ "nvidia-riva-client~=2.21.1" ] runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.122.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"] sagemaker = ["aws_sdk_sagemaker_runtime_http2; python_version>='3.12'"] sambanova = [] diff --git a/src/pipecat/services/nim/llm.py b/src/pipecat/services/nim/llm.py index 07e970521..75bac586b 100644 --- a/src/pipecat/services/nim/llm.py +++ b/src/pipecat/services/nim/llm.py @@ -16,7 +16,7 @@ from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.services.openai.llm import OpenAILLMService -class NimLLMService(OpenAILLMService): +class NvidiaLLMService(OpenAILLMService): """A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API. This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining @@ -32,7 +32,7 @@ class NimLLMService(OpenAILLMService): model: str = "nvidia/llama-3.1-nemotron-70b-instruct", **kwargs, ): - """Initialize the NimLLMService. + """Initialize the NvidiaLLMService. Args: api_key: The API key for accessing NVIDIA's NIM API. diff --git a/src/pipecat/services/nvidia/__init__.py b/src/pipecat/services/nvidia/__init__.py new file mode 100644 index 000000000..8d7f00bb4 --- /dev/null +++ b/src/pipecat/services/nvidia/__init__.py @@ -0,0 +1,16 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .llm import * +from .stt import * +from .tts import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "riva.[stt,tts]", "nvidia.[stt,tts]") +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "nim.llm", "nvidia.llm") diff --git a/src/pipecat/services/nvidia/llm.py b/src/pipecat/services/nvidia/llm.py new file mode 100644 index 000000000..75bac586b --- /dev/null +++ b/src/pipecat/services/nvidia/llm.py @@ -0,0 +1,105 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""NVIDIA NIM API service implementation. + +This module provides a service for interacting with NVIDIA's NIM (NVIDIA Inference +Microservice) API while maintaining compatibility with the OpenAI-style interface. +""" + +from pipecat.metrics.metrics import LLMTokenUsage +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.openai.llm import OpenAILLMService + + +class NvidiaLLMService(OpenAILLMService): + """A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API. + + This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining + compatibility with the OpenAI-style interface. It specifically handles the difference + in token usage reporting between NIM (incremental) and OpenAI (final summary). + """ + + def __init__( + self, + *, + api_key: str, + base_url: str = "https://integrate.api.nvidia.com/v1", + model: str = "nvidia/llama-3.1-nemotron-70b-instruct", + **kwargs, + ): + """Initialize the NvidiaLLMService. + + Args: + api_key: The API key for accessing NVIDIA's NIM API. + base_url: The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1". + model: The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct". + **kwargs: Additional keyword arguments passed to OpenAILLMService. + """ + super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs) + # Counters for accumulating token usage metrics + self._prompt_tokens = 0 + self._completion_tokens = 0 + self._total_tokens = 0 + self._has_reported_prompt_tokens = False + self._is_processing = False + + async def _process_context(self, context: OpenAILLMContext | LLMContext): + """Process a context through the LLM and accumulate token usage metrics. + + This method overrides the parent class implementation to handle NVIDIA's + incremental token reporting style, accumulating the counts and reporting + them once at the end of processing. + + Args: + context: The context to process, containing messages and other information + needed for the LLM interaction. + """ + # Reset all counters and flags at the start of processing + self._prompt_tokens = 0 + self._completion_tokens = 0 + self._total_tokens = 0 + self._has_reported_prompt_tokens = False + self._is_processing = True + + try: + await super()._process_context(context) + finally: + self._is_processing = False + # Report final accumulated token usage at the end of processing + if self._prompt_tokens > 0 or self._completion_tokens > 0: + self._total_tokens = self._prompt_tokens + self._completion_tokens + tokens = LLMTokenUsage( + prompt_tokens=self._prompt_tokens, + completion_tokens=self._completion_tokens, + total_tokens=self._total_tokens, + ) + await super().start_llm_usage_metrics(tokens) + + async def start_llm_usage_metrics(self, tokens: LLMTokenUsage): + """Accumulate token usage metrics during processing. + + This method intercepts the incremental token updates from NVIDIA's API + and accumulates them instead of passing each update to the metrics system. + The final accumulated totals are reported at the end of processing. + + Args: + tokens: The token usage metrics for the current chunk of processing, + containing prompt_tokens and completion_tokens counts. + """ + # Only accumulate metrics during active processing + if not self._is_processing: + return + + # Record prompt tokens the first time we see them + if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0: + self._prompt_tokens = tokens.prompt_tokens + self._has_reported_prompt_tokens = True + + # Update completion tokens count if it has increased + if tokens.completion_tokens > self._completion_tokens: + self._completion_tokens = tokens.completion_tokens diff --git a/src/pipecat/services/nvidia/stt.py b/src/pipecat/services/nvidia/stt.py new file mode 100644 index 000000000..90634c749 --- /dev/null +++ b/src/pipecat/services/nvidia/stt.py @@ -0,0 +1,712 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""NVIDIA Riva Speech-to-Text service implementations for real-time and batch transcription.""" + +import asyncio +from concurrent.futures import CancelledError as FuturesCancelledError +from typing import AsyncGenerator, List, Mapping, Optional + +from loguru import logger +from pydantic import BaseModel + +from pipecat.frames.frames import ( + CancelFrame, + EndFrame, + ErrorFrame, + Frame, + InterimTranscriptionFrame, + StartFrame, + TranscriptionFrame, +) +from pipecat.services.stt_service import SegmentedSTTService, STTService +from pipecat.transcriptions.language import Language, resolve_language +from pipecat.utils.time import time_now_iso8601 +from pipecat.utils.tracing.service_decorators import traced_stt + +try: + import riva.client + +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[nvidia]`.") + raise Exception(f"Missing module: {e}") + + +def language_to_riva_language(language: Language) -> Optional[str]: + """Maps Language enum to Riva ASR language codes. + + Source: + https://docs.nvidia.com/deeplearning/riva/user-guide/docs/asr/asr-riva-build-table.html?highlight=fr%20fr + + Args: + language: Language enum value. + + Returns: + Optional[str]: Riva language code or None if not supported. + """ + LANGUAGE_MAP = { + # Arabic + Language.AR: "ar-AR", + # English + Language.EN: "en-US", # Default to US + Language.EN_US: "en-US", + Language.EN_GB: "en-GB", + # French + Language.FR: "fr-FR", + Language.FR_FR: "fr-FR", + # German + Language.DE: "de-DE", + Language.DE_DE: "de-DE", + # Hindi + Language.HI: "hi-IN", + Language.HI_IN: "hi-IN", + # Italian + Language.IT: "it-IT", + Language.IT_IT: "it-IT", + # Japanese + Language.JA: "ja-JP", + Language.JA_JP: "ja-JP", + # Korean + Language.KO: "ko-KR", + Language.KO_KR: "ko-KR", + # Portuguese + Language.PT: "pt-BR", # Default to Brazilian + Language.PT_BR: "pt-BR", + # Russian + Language.RU: "ru-RU", + Language.RU_RU: "ru-RU", + # Spanish + Language.ES: "es-ES", # Default to Spain + Language.ES_ES: "es-ES", + Language.ES_US: "es-US", # US Spanish + } + + return resolve_language(language, LANGUAGE_MAP, use_base_code=False) + + +class NvidiaSTTService(STTService): + """Real-time speech-to-text service using NVIDIA Riva streaming ASR. + + Provides real-time transcription capabilities using NVIDIA's Riva ASR models + through streaming recognition. Supports interim results and continuous audio + processing for low-latency applications. + """ + + class InputParams(BaseModel): + """Configuration parameters for Riva STT service. + + Parameters: + language: Target language for transcription. Defaults to EN_US. + """ + + language: Optional[Language] = Language.EN_US + + def __init__( + self, + *, + api_key: str, + server: str = "grpc.nvcf.nvidia.com:443", + model_function_map: Mapping[str, str] = { + "function_id": "1598d209-5e27-4d3c-8079-4751568b1081", + "model_name": "parakeet-ctc-1.1b-asr", + }, + sample_rate: Optional[int] = None, + params: Optional[InputParams] = None, + **kwargs, + ): + """Initialize the Riva STT service. + + Args: + api_key: NVIDIA API key for authentication. + server: Riva server address. Defaults to NVIDIA Cloud Function endpoint. + model_function_map: Mapping containing 'function_id' and 'model_name' for the ASR model. + sample_rate: Audio sample rate in Hz. If None, uses pipeline default. + params: Additional configuration parameters for Riva. + **kwargs: Additional arguments passed to STTService. + """ + super().__init__(sample_rate=sample_rate, **kwargs) + + params = params or NvidiaSTTService.InputParams() + + self._api_key = api_key + self._profanity_filter = False + self._automatic_punctuation = True + self._no_verbatim_transcripts = False + self._language_code = params.language + self._boosted_lm_words = None + self._boosted_lm_score = 4.0 + self._start_history = -1 + self._start_threshold = -1.0 + self._stop_history = -1 + self._stop_threshold = -1.0 + self._stop_history_eou = -1 + self._stop_threshold_eou = -1.0 + self._custom_configuration = "" + self._function_id = model_function_map.get("function_id") + + self._settings = { + "language": str(params.language), + "profanity_filter": self._profanity_filter, + "automatic_punctuation": self._automatic_punctuation, + "verbatim_transcripts": not self._no_verbatim_transcripts, + "boosted_lm_words": self._boosted_lm_words, + "boosted_lm_score": self._boosted_lm_score, + } + + self.set_model_name(model_function_map.get("model_name")) + + metadata = [ + ["function-id", self._function_id], + ["authorization", f"Bearer {api_key}"], + ] + auth = riva.client.Auth(None, True, server, metadata) + + self._asr_service = riva.client.ASRService(auth) + + self._queue = None + self._config = None + self._thread_task = None + self._response_task = None + + def can_generate_metrics(self) -> bool: + """Check if this service can generate processing metrics. + + Returns: + False - this service does not support metrics generation. + """ + return False + + async def set_model(self, model: str): + """Set the ASR model for transcription. + + Args: + model: Model name to set. + + Note: + Model cannot be changed after initialization. Use model_function_map + parameter in constructor instead. + """ + logger.warning(f"Cannot set model after initialization. Set model and function id like so:") + example = {"function_id": "", "model_name": ""} + logger.warning( + f"{self.__class__.__name__}(api_key=, model_function_map={example})" + ) + + async def start(self, frame: StartFrame): + """Start the Riva STT service and initialize streaming configuration. + + Args: + frame: StartFrame indicating pipeline start. + """ + await super().start(frame) + + if self._config: + return + + config = riva.client.StreamingRecognitionConfig( + config=riva.client.RecognitionConfig( + encoding=riva.client.AudioEncoding.LINEAR_PCM, + language_code=self._language_code, + model="", + max_alternatives=1, + profanity_filter=self._profanity_filter, + enable_automatic_punctuation=self._automatic_punctuation, + verbatim_transcripts=not self._no_verbatim_transcripts, + sample_rate_hertz=self.sample_rate, + audio_channel_count=1, + ), + interim_results=True, + ) + + riva.client.add_word_boosting_to_config( + config, self._boosted_lm_words, self._boosted_lm_score + ) + + riva.client.add_endpoint_parameters_to_config( + config, + self._start_history, + self._start_threshold, + self._stop_history, + self._stop_history_eou, + self._stop_threshold, + self._stop_threshold_eou, + ) + riva.client.add_custom_configuration_to_config(config, self._custom_configuration) + + self._config = config + self._queue = asyncio.Queue() + + if not self._thread_task: + self._thread_task = self.create_task(self._thread_task_handler()) + + if not self._response_task: + self._response_queue = asyncio.Queue() + self._response_task = self.create_task(self._response_task_handler()) + + async def stop(self, frame: EndFrame): + """Stop the Riva STT service and clean up resources. + + Args: + frame: EndFrame indicating pipeline stop. + """ + await super().stop(frame) + await self._stop_tasks() + + async def cancel(self, frame: CancelFrame): + """Cancel the Riva STT service operation. + + Args: + frame: CancelFrame indicating operation cancellation. + """ + await super().cancel(frame) + await self._stop_tasks() + + async def _stop_tasks(self): + if self._thread_task: + await self.cancel_task(self._thread_task) + self._thread_task = None + + if self._response_task: + await self.cancel_task(self._response_task) + self._response_task = None + + def _response_handler(self): + responses = self._asr_service.streaming_response_generator( + audio_chunks=self, + streaming_config=self._config, + ) + for response in responses: + if not response.results: + continue + asyncio.run_coroutine_threadsafe( + self._response_queue.put(response), self.get_event_loop() + ) + + async def _thread_task_handler(self): + try: + self._thread_running = True + await asyncio.to_thread(self._response_handler) + except asyncio.CancelledError: + self._thread_running = False + raise + + @traced_stt + async def _handle_transcription( + self, transcript: str, is_final: bool, language: Optional[Language] = None + ): + """Handle a transcription result with tracing.""" + pass + + async def _handle_response(self, response): + for result in response.results: + if result and not result.alternatives: + continue + + transcript = result.alternatives[0].transcript + if transcript and len(transcript) > 0: + await self.stop_ttfb_metrics() + if result.is_final: + await self.stop_processing_metrics() + await self.push_frame( + TranscriptionFrame( + transcript, + self._user_id, + time_now_iso8601(), + self._language_code, + result=result, + ) + ) + await self._handle_transcription( + transcript=transcript, + is_final=result.is_final, + language=self._language_code, + ) + else: + await self.push_frame( + InterimTranscriptionFrame( + transcript, + self._user_id, + time_now_iso8601(), + self._language_code, + result=result, + ) + ) + + async def _response_task_handler(self): + while True: + response = await self._response_queue.get() + await self._handle_response(response) + self._response_queue.task_done() + + async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: + """Process audio data for speech-to-text transcription. + + Args: + audio: Raw audio bytes to transcribe. + + Yields: + None - transcription results are pushed to the pipeline via frames. + """ + await self.start_ttfb_metrics() + await self.start_processing_metrics() + await self._queue.put(audio) + yield None + + def __next__(self) -> bytes: + """Get the next audio chunk for Riva processing. + + Returns: + Audio bytes from the queue. + + Raises: + StopIteration: When the thread is no longer running. + """ + if not self._thread_running: + raise StopIteration + + try: + future = asyncio.run_coroutine_threadsafe(self._queue.get(), self.get_event_loop()) + return future.result() + except FuturesCancelledError: + raise StopIteration + + def __iter__(self): + """Return iterator for audio chunk processing. + + Returns: + Self as iterator. + """ + return self + + +class RivaSegmentedSTTService(SegmentedSTTService): + """Speech-to-text service using NVIDIA Riva's offline/batch models. + + By default, his service uses NVIDIA's Riva Canary ASR API to perform speech-to-text + transcription on audio segments. It inherits from SegmentedSTTService to handle + audio buffering and speech detection. + """ + + class InputParams(BaseModel): + """Configuration parameters for Riva segmented STT service. + + Parameters: + language: Target language for transcription. Defaults to EN_US. + profanity_filter: Whether to filter profanity from results. + automatic_punctuation: Whether to add automatic punctuation. + verbatim_transcripts: Whether to return verbatim transcripts. + boosted_lm_words: List of words to boost in language model. + boosted_lm_score: Score boost for specified words. + """ + + language: Optional[Language] = Language.EN_US + profanity_filter: bool = False + automatic_punctuation: bool = True + verbatim_transcripts: bool = False + boosted_lm_words: Optional[List[str]] = None + boosted_lm_score: float = 4.0 + + def __init__( + self, + *, + api_key: str, + server: str = "grpc.nvcf.nvidia.com:443", + model_function_map: Mapping[str, str] = { + "function_id": "ee8dc628-76de-4acc-8595-1836e7e857bd", + "model_name": "canary-1b-asr", + }, + sample_rate: Optional[int] = None, + params: Optional[InputParams] = None, + **kwargs, + ): + """Initialize the Riva segmented STT service. + + Args: + api_key: NVIDIA API key for authentication + server: Riva server address (defaults to NVIDIA Cloud Function endpoint) + model_function_map: Mapping of model name and its corresponding NVIDIA Cloud Function ID + sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate + params: Additional configuration parameters for Riva + **kwargs: Additional arguments passed to SegmentedSTTService + """ + super().__init__(sample_rate=sample_rate, **kwargs) + + params = params or RivaSegmentedSTTService.InputParams() + + # Set model name + self.set_model_name(model_function_map.get("model_name")) + + # Initialize Riva settings + self._api_key = api_key + self._server = server + self._function_id = model_function_map.get("function_id") + self._model_name = model_function_map.get("model_name") + + # Store the language as a Language enum and as a string + self._language_enum = params.language or Language.EN_US + self._language = self.language_to_service_language(self._language_enum) or "en-US" + + # Configure transcription parameters + self._profanity_filter = params.profanity_filter + self._automatic_punctuation = params.automatic_punctuation + self._verbatim_transcripts = params.verbatim_transcripts + self._boosted_lm_words = params.boosted_lm_words + self._boosted_lm_score = params.boosted_lm_score + + # Voice activity detection thresholds (use Riva defaults) + self._start_history = -1 + self._start_threshold = -1.0 + self._stop_history = -1 + self._stop_threshold = -1.0 + self._stop_history_eou = -1 + self._stop_threshold_eou = -1.0 + self._custom_configuration = "" + + # Create Riva client + self._config = None + self._asr_service = None + self._settings = {"language": self._language_enum} + + def language_to_service_language(self, language: Language) -> Optional[str]: + """Convert pipecat Language enum to Riva's language code. + + Args: + language: Language enum value. + + Returns: + Riva language code or None if not supported. + """ + return language_to_riva_language(language) + + def _initialize_client(self): + """Initialize the Riva ASR client with authentication metadata.""" + if self._asr_service is not None: + return + + # Set up authentication metadata for NVIDIA Cloud Functions + metadata = [ + ["function-id", self._function_id], + ["authorization", f"Bearer {self._api_key}"], + ] + + # Create authenticated client + auth = riva.client.Auth(None, True, self._server, metadata) + self._asr_service = riva.client.ASRService(auth) + + logger.info(f"Initialized RivaSegmentedSTTService with model: {self.model_name}") + + def _create_recognition_config(self): + """Create the Riva ASR recognition configuration.""" + # Create base configuration + config = riva.client.RecognitionConfig( + language_code=self._language, # Now using the string, not a tuple + max_alternatives=1, + profanity_filter=self._profanity_filter, + enable_automatic_punctuation=self._automatic_punctuation, + verbatim_transcripts=self._verbatim_transcripts, + ) + + # Add word boosting if specified + if self._boosted_lm_words: + riva.client.add_word_boosting_to_config( + config, self._boosted_lm_words, self._boosted_lm_score + ) + + # Add voice activity detection parameters + riva.client.add_endpoint_parameters_to_config( + config, + self._start_history, + self._start_threshold, + self._stop_history, + self._stop_history_eou, + self._stop_threshold, + self._stop_threshold_eou, + ) + + # Add any custom configuration + if self._custom_configuration: + riva.client.add_custom_configuration_to_config(config, self._custom_configuration) + + return config + + def can_generate_metrics(self) -> bool: + """Check if this service can generate processing metrics. + + Returns: + True - this service supports metrics generation. + """ + return True + + async def set_model(self, model: str): + """Set the ASR model for transcription. + + Args: + model: Model name to set. + + Note: + Model cannot be changed after initialization. Use model_function_map + parameter in constructor instead. + """ + logger.warning(f"Cannot set model after initialization. Set model and function id like so:") + example = {"function_id": "", "model_name": ""} + logger.warning( + f"{self.__class__.__name__}(api_key=, model_function_map={example})" + ) + + async def start(self, frame: StartFrame): + """Initialize the service when the pipeline starts. + + Args: + frame: StartFrame indicating pipeline start. + """ + await super().start(frame) + self._initialize_client() + self._config = self._create_recognition_config() + + async def set_language(self, language: Language): + """Set the language for the STT service. + + Args: + language: Target language for transcription. + """ + logger.info(f"Switching STT language to: [{language}]") + self._language_enum = language + self._language = self.language_to_service_language(language) or "en-US" + self._settings["language"] = language + + # Update configuration with new language + if self._config: + self._config.language_code = self._language + + @traced_stt + async def _handle_transcription( + self, transcript: str, is_final: bool, language: Optional[Language] = None + ): + """Handle a transcription result with tracing.""" + pass + + async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: + """Transcribe an audio segment. + + Args: + audio: Raw audio bytes in WAV format (already converted by base class). + + Yields: + Frame: TranscriptionFrame containing the transcribed text. + """ + try: + await self.start_processing_metrics() + await self.start_ttfb_metrics() + + # Make sure the client is initialized + if self._asr_service is None: + self._initialize_client() + + # Make sure the config is created + if self._config is None: + self._config = self._create_recognition_config() + + # Type assertion to satisfy the IDE + assert self._asr_service is not None, "ASR service not initialized" + assert self._config is not None, "Recognition config not created" + + # Process audio with Riva ASR - explicitly request non-future response + raw_response = self._asr_service.offline_recognize(audio, self._config, future=False) + + await self.stop_ttfb_metrics() + await self.stop_processing_metrics() + + # Process the response - handle different possible return types + try: + # If it's a future-like object, get the result + if hasattr(raw_response, "result"): + response = raw_response.result() + else: + response = raw_response + + # Process transcription results + transcription_found = False + + # Now we can safely check results + # Type hint for the IDE + results = getattr(response, "results", []) + + for result in results: + alternatives = getattr(result, "alternatives", []) + if alternatives: + text = alternatives[0].transcript.strip() + if text: + logger.debug(f"Transcription: [{text}]") + yield TranscriptionFrame( + text, + self._user_id, + time_now_iso8601(), + self._language_enum, + ) + transcription_found = True + + await self._handle_transcription(text, True, self._language_enum) + + if not transcription_found: + logger.debug("No transcription results found in Riva response") + + except AttributeError as ae: + logger.error(f"Unexpected response structure from Riva: {ae}") + yield ErrorFrame(f"Unexpected Riva response format: {str(ae)}") + + except Exception as e: + logger.error(f"{self} exception: {e}") + yield ErrorFrame(error=f"{self} error: {e}") + + +class RivaSTTService(NvidiaSTTService): + """Deprecated speech-to-text service using NVIDIA Parakeet models. + + .. deprecated:: 0.0.96 + This class is deprecated. Use `NvidiaSTTService` instead for equivalent functionality + with Riva models by specifying the appropriate model_function_map. + """ + + def __init__( + self, + *, + api_key: str, + server: str = "grpc.nvcf.nvidia.com:443", + model_function_map: Mapping[str, str] = { + "function_id": "1598d209-5e27-4d3c-8079-4751568b1081", + "model_name": "parakeet-ctc-1.1b-asr", + }, + sample_rate: Optional[int] = None, + params: Optional[NvidiaSTTService.InputParams] = None, # Use parent class's type + **kwargs, + ): + """Initialize the Riva STT service. + + Args: + api_key: NVIDIA API key for authentication. + server: Riva server address. Defaults to NVIDIA Cloud Function endpoint. + model_function_map: Mapping containing 'function_id' and 'model_name' for Parakeet model. + sample_rate: Audio sample rate in Hz. If None, uses pipeline default. + params: Additional configuration parameters for Riva. + **kwargs: Additional arguments passed to NvidiaSTTService. + """ + super().__init__( + api_key=api_key, + server=server, + model_function_map=model_function_map, + sample_rate=sample_rate, + params=params, + **kwargs, + ) + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`RivaSTTService` is deprecated, use `NvidiaSTTService` instead.", + DeprecationWarning, + ) diff --git a/src/pipecat/services/nvidia/tts.py b/src/pipecat/services/nvidia/tts.py new file mode 100644 index 000000000..d78943680 --- /dev/null +++ b/src/pipecat/services/nvidia/tts.py @@ -0,0 +1,239 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""NVIDIA Riva text-to-speech service implementation. + +This module provides integration with NVIDIA Riva's TTS services through +gRPC API for high-quality speech synthesis. +""" + +import asyncio +import os +from typing import AsyncGenerator, Mapping, Optional + +from pipecat.utils.tracing.service_decorators import traced_tts + +# Suppress gRPC fork warnings +os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false" + +from loguru import logger +from pydantic import BaseModel + +from pipecat.frames.frames import ( + ErrorFrame, + Frame, + TTSAudioRawFrame, + TTSStartedFrame, + TTSStoppedFrame, +) +from pipecat.services.tts_service import TTSService +from pipecat.transcriptions.language import Language + +try: + import riva.client + +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[nvidia]`.") + raise Exception(f"Missing module: {e}") + +RIVA_TTS_TIMEOUT_SECS = 5 + + +class NvidiaTTSService(TTSService): + """NVIDIA Riva text-to-speech service. + + Provides high-quality text-to-speech synthesis using NVIDIA Riva's + cloud-based TTS models. Supports multiple voices, languages, and + configurable quality settings. + """ + + class InputParams(BaseModel): + """Input parameters for Riva TTS configuration. + + Parameters: + language: Language code for synthesis. Defaults to US English. + quality: Audio quality setting (0-100). Defaults to 20. + """ + + language: Optional[Language] = Language.EN_US + quality: Optional[int] = 20 + + def __init__( + self, + *, + api_key: str, + server: str = "grpc.nvcf.nvidia.com:443", + voice_id: str = "Magpie-Multilingual.EN-US.Aria", + sample_rate: Optional[int] = None, + model_function_map: Mapping[str, str] = { + "function_id": "877104f7-e885-42b9-8de8-f6e4c6303969", + "model_name": "magpie-tts-multilingual", + }, + params: Optional[InputParams] = None, + **kwargs, + ): + """Initialize the NVIDIA Riva TTS service. + + Args: + api_key: NVIDIA API key for authentication. + server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint. + voice_id: Voice model identifier. Defaults to multilingual Ray voice. + sample_rate: Audio sample rate. If None, uses service default. + model_function_map: Dictionary containing function_id and model_name for the TTS model. + params: Additional configuration parameters for TTS synthesis. + **kwargs: Additional arguments passed to parent TTSService. + """ + super().__init__(sample_rate=sample_rate, **kwargs) + + params = params or NvidiaTTSService.InputParams() + + self._api_key = api_key + self._voice_id = voice_id + self._language_code = params.language + self._quality = params.quality + self._function_id = model_function_map.get("function_id") + + self.set_model_name(model_function_map.get("model_name")) + self.set_voice(voice_id) + + metadata = [ + ["function-id", self._function_id], + ["authorization", f"Bearer {api_key}"], + ] + auth = riva.client.Auth(None, True, server, metadata) + + self._service = riva.client.SpeechSynthesisService(auth) + + # warm up the service + config_response = self._service.stub.GetRivaSynthesisConfig( + riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest() + ) + + async def set_model(self, model: str): + """Attempt to set the TTS model. + + Note: Model cannot be changed after initialization for Riva service. + + Args: + model: The model name to set (operation not supported). + """ + logger.warning(f"Cannot set model after initialization. Set model and function id like so:") + example = {"function_id": "", "model_name": ""} + logger.warning( + f"{self.__class__.__name__}(api_key=, model_function_map={example})" + ) + + @traced_tts + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + """Generate speech from text using NVIDIA Riva TTS. + + Args: + text: The text to synthesize into speech. + + Yields: + Frame: Audio frames containing the synthesized speech data. + """ + + def read_audio_responses(queue: asyncio.Queue): + def add_response(r): + asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop()) + + try: + responses = self._service.synthesize_online( + text, + self._voice_id, + self._language_code, + sample_rate_hz=self.sample_rate, + zero_shot_audio_prompt_file=None, + zero_shot_quality=self._quality, + custom_dictionary={}, + ) + for r in responses: + add_response(r) + add_response(None) + except Exception as e: + logger.error(f"{self} exception: {e}") + add_response(None) + + await self.start_ttfb_metrics() + yield TTSStartedFrame() + + logger.debug(f"{self}: Generating TTS [{text}]") + + try: + queue = asyncio.Queue() + await asyncio.to_thread(read_audio_responses, queue) + + # Wait for the thread to start. + resp = await asyncio.wait_for(queue.get(), timeout=RIVA_TTS_TIMEOUT_SECS) + while resp: + await self.stop_ttfb_metrics() + frame = TTSAudioRawFrame( + audio=resp.audio, + sample_rate=self.sample_rate, + num_channels=1, + ) + yield frame + resp = await asyncio.wait_for(queue.get(), timeout=RIVA_TTS_TIMEOUT_SECS) + except asyncio.TimeoutError: + logger.error(f"{self} timeout waiting for audio response") + yield ErrorFrame(error=f"{self} error: {e}") + + await self.start_tts_usage_metrics(text) + yield TTSStoppedFrame() + + +class RivaTTSService(NvidiaTTSService): + """Deprecated FastPitch TTS service. + + .. deprecated:: 0.0.96 + This class is deprecated. Use NvidiaTTSService instead for new implementations. + Provides backward compatibility for existing Riva TTS integrations. + """ + + def __init__( + self, + *, + api_key: str, + server: str = "grpc.nvcf.nvidia.com:443", + voice_id: str = "English-US.Female-1", + sample_rate: Optional[int] = None, + model_function_map: Mapping[str, str] = { + "function_id": "0149dedb-2be8-4195-b9a0-e57e0e14f972", + "model_name": "fastpitch-hifigan-tts", + }, + params: Optional[NvidiaTTSService.InputParams] = None, + **kwargs, + ): + """Initialize the deprecated Riva TTS service. + + Args: + api_key: NVIDIA API key for authentication. + server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint. + voice_id: Voice model identifier. Defaults to Female-1 voice. + sample_rate: Audio sample rate. If None, uses service default. + model_function_map: Dictionary containing function_id and model_name for FastPitch model. + params: Additional configuration parameters for TTS synthesis. + **kwargs: Additional arguments passed to parent NvidiaTTSService. + """ + super().__init__( + api_key=api_key, + server=server, + voice_id=voice_id, + sample_rate=sample_rate, + model_function_map=model_function_map, + params=params, + **kwargs, + ) + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`RivaTTSService` is deprecated, use `NvidiaTTSService` instead.", + DeprecationWarning, + ) diff --git a/src/pipecat/services/riva/stt.py b/src/pipecat/services/riva/stt.py index 4dba62bcb..314e0dbce 100644 --- a/src/pipecat/services/riva/stt.py +++ b/src/pipecat/services/riva/stt.py @@ -32,7 +32,7 @@ try: except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[riva]`.") + logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[nvidia]`.") raise Exception(f"Missing module: {e}") @@ -88,7 +88,7 @@ def language_to_riva_language(language: Language) -> Optional[str]: return resolve_language(language, LANGUAGE_MAP, use_base_code=False) -class RivaSTTService(STTService): +class NvidiaSTTService(STTService): """Real-time speech-to-text service using NVIDIA Riva streaming ASR. Provides real-time transcription capabilities using NVIDIA's Riva ASR models @@ -130,7 +130,7 @@ class RivaSTTService(STTService): """ super().__init__(sample_rate=sample_rate, **kwargs) - params = params or RivaSTTService.InputParams() + params = params or NvidiaSTTService.InputParams() self._api_key = api_key self._profanity_filter = False @@ -661,11 +661,11 @@ class RivaSegmentedSTTService(SegmentedSTTService): yield ErrorFrame(error=f"Unknown error occurred: {e}") -class ParakeetSTTService(RivaSTTService): +class ParakeetSTTService(NvidiaSTTService): """Deprecated speech-to-text service using NVIDIA Parakeet models. .. deprecated:: 0.0.66 - This class is deprecated. Use `RivaSTTService` instead for equivalent functionality + This class is deprecated. Use `NvidiaSTTService` instead for equivalent functionality with Parakeet models by specifying the appropriate model_function_map. """ @@ -679,7 +679,7 @@ class ParakeetSTTService(RivaSTTService): "model_name": "parakeet-ctc-1.1b-asr", }, sample_rate: Optional[int] = None, - params: Optional[RivaSTTService.InputParams] = None, # Use parent class's type + params: Optional[NvidiaSTTService.InputParams] = None, # Use parent class's type **kwargs, ): """Initialize the Parakeet STT service. @@ -690,7 +690,7 @@ class ParakeetSTTService(RivaSTTService): model_function_map: Mapping containing 'function_id' and 'model_name' for Parakeet model. sample_rate: Audio sample rate in Hz. If None, uses pipeline default. params: Additional configuration parameters for Riva. - **kwargs: Additional arguments passed to RivaSTTService. + **kwargs: Additional arguments passed to NvidiaSTTService. """ super().__init__( api_key=api_key, @@ -705,6 +705,6 @@ class ParakeetSTTService(RivaSTTService): with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( - "`ParakeetSTTService` is deprecated, use `RivaSTTService` instead.", + "`ParakeetSTTService` is deprecated, use `NvidiaSTTService` instead.", DeprecationWarning, ) diff --git a/src/pipecat/services/riva/tts.py b/src/pipecat/services/riva/tts.py index 370971068..7306e5d86 100644 --- a/src/pipecat/services/riva/tts.py +++ b/src/pipecat/services/riva/tts.py @@ -37,13 +37,13 @@ try: except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[riva]`.") + logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[nvidia]`.") raise Exception(f"Missing module: {e}") RIVA_TTS_TIMEOUT_SECS = 5 -class RivaTTSService(TTSService): +class NvidiaTTSService(TTSService): """NVIDIA Riva text-to-speech service. Provides high-quality text-to-speech synthesis using NVIDIA Riva's @@ -89,7 +89,7 @@ class RivaTTSService(TTSService): """ super().__init__(sample_rate=sample_rate, **kwargs) - params = params or RivaTTSService.InputParams() + params = params or NvidiaTTSService.InputParams() self._api_key = api_key self._voice_id = voice_id @@ -186,11 +186,11 @@ class RivaTTSService(TTSService): yield TTSStoppedFrame() -class FastPitchTTSService(RivaTTSService): +class FastPitchTTSService(NvidiaTTSService): """Deprecated FastPitch TTS service. .. deprecated:: 0.0.66 - This class is deprecated. Use RivaTTSService instead for new implementations. + This class is deprecated. Use NvidiaTTSService instead for new implementations. Provides backward compatibility for existing FastPitch TTS integrations. """ @@ -205,7 +205,7 @@ class FastPitchTTSService(RivaTTSService): "function_id": "0149dedb-2be8-4195-b9a0-e57e0e14f972", "model_name": "fastpitch-hifigan-tts", }, - params: Optional[RivaTTSService.InputParams] = None, + params: Optional[NvidiaTTSService.InputParams] = None, **kwargs, ): """Initialize the deprecated FastPitch TTS service. @@ -217,7 +217,7 @@ class FastPitchTTSService(RivaTTSService): sample_rate: Audio sample rate. If None, uses service default. model_function_map: Dictionary containing function_id and model_name for FastPitch model. params: Additional configuration parameters for TTS synthesis. - **kwargs: Additional arguments passed to parent RivaTTSService. + **kwargs: Additional arguments passed to parent NvidiaTTSService. """ super().__init__( api_key=api_key, @@ -233,6 +233,6 @@ class FastPitchTTSService(RivaTTSService): with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( - "`FastPitchTTSService` is deprecated, use `RivaTTSService` instead.", + "`FastPitchTTSService` is deprecated, use `NvidiaTTSService` instead.", DeprecationWarning, )