From 63a65627a29ad29dfd554ea3139977b4eb10f511 Mon Sep 17 00:00:00 2001 From: vipyne Date: Wed, 30 Apr 2025 11:25:10 -0500 Subject: [PATCH] Riva Service: add magpie-tts-multilingual model --- .../07r-interruptible-riva-nim.py | 8 +- src/pipecat/services/riva/stt.py | 204 +++++++++++++++++- src/pipecat/services/riva/tts.py | 54 ++++- 3 files changed, 251 insertions(+), 15 deletions(-) diff --git a/examples/foundational/07r-interruptible-riva-nim.py b/examples/foundational/07r-interruptible-riva-nim.py index 915beda51..ca74dfd7a 100644 --- a/examples/foundational/07r-interruptible-riva-nim.py +++ b/examples/foundational/07r-interruptible-riva-nim.py @@ -16,8 +16,8 @@ from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.services.nim.llm import NimLLMService -from pipecat.services.riva.stt import ParakeetSTTService -from pipecat.services.riva.tts import FastPitchTTSService +from pipecat.services.riva.stt import ParakeetSTTService, RivaOfflineSTTService +from pipecat.services.riva.tts import RivaTTSService, FastPitchTTSService from pipecat.transports.base_transport import TransportParams from pipecat.transports.network.small_webrtc import SmallWebRTCTransport from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection @@ -37,11 +37,13 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac ), ) - stt = ParakeetSTTService(api_key=os.getenv("NVIDIA_API_KEY")) + stt = RivaOfflineSTTService(api_key=os.getenv("NVIDIA_API_KEY")) + # stt = ParakeetSTTService(api_key=os.getenv("NVIDIA_API_KEY")) llm = NimLLMService(api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct") tts = FastPitchTTSService(api_key=os.getenv("NVIDIA_API_KEY")) + # tts = RivaTTSService(api_key=os.getenv("NVIDIA_API_KEY")) messages = [ { diff --git a/src/pipecat/services/riva/stt.py 
b/src/pipecat/services/riva/stt.py index 6328bcb65..3676befa3 100644 --- a/src/pipecat/services/riva/stt.py +++ b/src/pipecat/services/riva/stt.py @@ -13,11 +13,13 @@ from pydantic import BaseModel from pipecat.frames.frames import ( CancelFrame, EndFrame, + ErrorFrame, Frame, InterimTranscriptionFrame, StartFrame, TranscriptionFrame, ) +from pipecat.services.stt_service import SegmentedSTTService from pipecat.services.stt_service import STTService from pipecat.transcriptions.language import Language from pipecat.utils.time import time_now_iso8601 @@ -27,20 +29,21 @@ try: except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[riva]`.") + logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[riva]`. Also set NVIDIA_API_KEY env var.") raise Exception(f"Missing module: {e}") -class ParakeetSTTService(STTService): +class RivaSTTService(STTService): class InputParams(BaseModel): language: Optional[Language] = Language.EN_US def __init__( self, *, - api_key: str, + api_key: str = None, server: str = "grpc.nvcf.nvidia.com:443", function_id: str = "1598d209-5e27-4d3c-8079-4751568b1081", + model_name: str = "parakeet-ctc-1.1b-asr", sample_rate: Optional[int] = None, params: InputParams = InputParams(), **kwargs, @@ -61,7 +64,7 @@ class ParakeetSTTService(STTService): self._stop_threshold_eou = -1.0 self._custom_configuration = "" - self.set_model_name("parakeet-ctc-1.1b-asr") + self.set_model_name(model_name) metadata = [ ["function-id", function_id], @@ -196,3 +199,196 @@ class ParakeetSTTService(STTService): def __iter__(self): return self + +class RivaOfflineSTTService(SegmentedSTTService): + """Speech-to-text service using NVIDIA Riva's offline recognition API. + + This service uses Riva's offline ASR to perform speech-to-text transcription on audio + segments. It inherits from SegmentedSTTService to handle audio buffering and speech detection. 
+ + Args: + api_key: NVIDIA_API_KEY. + sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate. + params: Configuration parameters for Riva. + **kwargs: Additional arguments passed to SegmentedSTTService. + """ + + class InputParams(BaseModel): + """Configuration parameters for Riva offline speech recognition. + + Attributes: + language: Language of the audio input. Defaults to English. + task: Task to perform ('transcribe' or 'translate'). Defaults to 'transcribe'. + chunk_level: Level of chunking ('segment'). Defaults to 'segment'. + version: Version of the model to use. Defaults to '3'. + """ + + language: Optional[Language] = Language.EN + task: str = "transcribe" + chunk_level: str = "segment" + version: str = "3" + + def __init__( + self, + *, + api_key: str = None, + server: str = "grpc.nvcf.nvidia.com:443", + function_id: str = "ee8dc628-76de-4acc-8595-1836e7e857bd", + model_name: str = "canary-1b-asr", + sample_rate: Optional[int] = None, + params: InputParams = InputParams(), + **kwargs, + ): + super().__init__(sample_rate=sample_rate, **kwargs) + self._api_key = api_key + self._profanity_filter = False + self._automatic_punctuation = False + self._no_verbatim_transcripts = False + self._language_code = params.language + self._boosted_lm_words = None + self._boosted_lm_score = 4.0 + self._start_history = -1 + self._start_threshold = -1.0 + self._stop_history = -1 + self._stop_threshold = -1.0 + self._stop_history_eou = -1 + self._stop_threshold_eou = -1.0 + self._custom_configuration = "" + + self.set_model_name(model_name) + + metadata = [ + ["function-id", function_id], + ["authorization", f"Bearer {api_key}"], + ] + auth = riva.client.Auth(None, True, server, metadata) + + self._asr_service = riva.client.ASRService(auth) + + self._queue = asyncio.Queue() + self._config = None + self._thread_task = None + self._response_task = None + + def can_generate_metrics(self) -> bool: + return False + + async def start(self, frame: StartFrame): + 
await super().start(frame) + + if self._config: + return + + # config = riva.client.StreamingRecognitionConfig( + config=riva.client.RecognitionConfig( + # encoding=riva.client.AudioEncoding.LINEAR_PCM, + language_code=self._language_code, + # model="", + max_alternatives=1, + profanity_filter=self._profanity_filter, + enable_automatic_punctuation=self._automatic_punctuation, + verbatim_transcripts=not self._no_verbatim_transcripts, + # sample_rate_hertz=self.sample_rate, + # audio_channel_count=1, + # enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization,?? + # ), + # interim_results=True, + ) + + riva.client.add_word_boosting_to_config( + config, self._boosted_lm_words, self._boosted_lm_score + ) + + riva.client.add_endpoint_parameters_to_config( + config, + self._start_history, + self._start_threshold, + self._stop_history, + self._stop_history_eou, + self._stop_threshold, + self._stop_threshold_eou, + ) + riva.client.add_custom_configuration_to_config(config, self._custom_configuration) + + self._config = config + + + + async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: + """Transcribe an audio segment + + Args: + audio: Raw audio bytes in WAV format (already converted by base class). + + Yields: + Frame: TranscriptionFrame containing the transcribed text. + + Note: + The audio is already in WAV format from the SegmentedSTTService. + Only non-empty transcriptions are yielded. 
+ """ + try: + response = self._asr_service.offline_recognize(audio, self._config) + # response = riva.client.print_offline(response=self._asr_service.offline_recognize(audio, self._config)) + print(f"_____stt.py * response: {response}") + # # Send to Fal directly (audio is already in WAV format from base class) + # data_uri = fal_client.encode(audio, "audio/x-wav") + # response = await self._fal_client.run( + # "fal-ai/wizper", + # arguments={"audio_url": data_uri, **self._settings}, + # ) + + if response and "text" in response: + text = response["text"].strip() + if text: # Only yield non-empty text + logger.debug(f"Transcription: [{text}]") + yield TranscriptionFrame( + text, "", time_now_iso8601(), Language(self._settings["language"]) + ) + + except Exception as e: + logger.error(f"Riva Offline STT error: {e}") + yield ErrorFrame(f"Riva Offline STT error: {str(e)}") + + def __next__(self) -> bytes: + if not self._thread_running: + raise StopIteration + future = asyncio.run_coroutine_threadsafe(self._queue.get(), self.get_event_loop()) + return future.result() + + def __iter__(self): + return self + +class ParakeetSTTService(RivaSTTService): + class InputParams(BaseModel): + language: Optional[Language] = Language.EN_US + + def __init__( + self, + *, + api_key: str = None, + server: str = "grpc.nvcf.nvidia.com:443", + function_id: str = "1598d209-5e27-4d3c-8079-4751568b1081", + model_name: str = "parakeet-ctc-1.1b-asr", + sample_rate: Optional[int] = None, + params: InputParams = InputParams(), + **kwargs, + ): + super().__init__( + api_key=api_key, + server=server, + function_id=function_id, + model_name=model_name, + sample_rate=sample_rate, + params=params, + **kwargs, + ) + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`ParakeetSTTService` is deprecated, use `RivaSTTService` instead.", + DeprecationWarning, + ) + diff --git a/src/pipecat/services/riva/tts.py b/src/pipecat/services/riva/tts.py index 
0fd0c3ce0..f37064acb 100644 --- a/src/pipecat/services/riva/tts.py +++ b/src/pipecat/services/riva/tts.py @@ -27,10 +27,10 @@ except ModuleNotFoundError as e: logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[riva]`.") raise Exception(f"Missing module: {e}") -FASTPITCH_TIMEOUT_SECS = 5 +RIVA_TTS_TIMEOUT_SECS = 5 -class FastPitchTTSService(TTSService): +class RivaTTSService(TTSService): class InputParams(BaseModel): language: Optional[Language] = Language.EN_US quality: Optional[int] = 20 @@ -38,11 +38,12 @@ class FastPitchTTSService(TTSService): def __init__( self, *, - api_key: str, + api_key: str = None, server: str = "grpc.nvcf.nvidia.com:443", - voice_id: str = "English-US.Female-1", + voice_id: str = "Magpie-Multilingual.EN-US.Male.Male-1", sample_rate: Optional[int] = None, - function_id: str = "0149dedb-2be8-4195-b9a0-e57e0e14f972", + function_id: str = "877104f7-e885-42b9-8de8-f6e4c6303969", + model_name: str = "magpie-tts-multilingual", params: InputParams = InputParams(), **kwargs, ): @@ -52,7 +53,7 @@ class FastPitchTTSService(TTSService): self._language_code = params.language self._quality = params.quality - self.set_model_name("fastpitch-hifigan-tts") + self.set_model_name(model_name) self.set_voice(voice_id) metadata = [ @@ -100,7 +101,7 @@ class FastPitchTTSService(TTSService): await asyncio.to_thread(read_audio_responses, queue) # Wait for the thread to start. 
- resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS) + resp = await asyncio.wait_for(queue.get(), RIVA_TTS_TIMEOUT_SECS) while resp: await self.stop_ttfb_metrics() frame = TTSAudioRawFrame( @@ -109,9 +110,46 @@ class FastPitchTTSService(TTSService): num_channels=1, ) yield frame - resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS) + resp = await asyncio.wait_for(queue.get(), RIVA_TTS_TIMEOUT_SECS) except asyncio.TimeoutError: logger.error(f"{self} timeout waiting for audio response") await self.start_tts_usage_metrics(text) yield TTSStoppedFrame() + + +class FastPitchTTSService(RivaTTSService): + class InputParams(BaseModel): + language: Optional[Language] = Language.EN_US + quality: Optional[int] = 20 + + def __init__( + self, + *, + api_key: str = None, + server: str = "grpc.nvcf.nvidia.com:443", + voice_id: str = "English-US.Female-1", + sample_rate: Optional[int] = None, + function_id: str = "0149dedb-2be8-4195-b9a0-e57e0e14f972", + model_name: str = "fastpitch-hifigan-tts", + params: InputParams = InputParams(), + **kwargs, + ): + super().__init__( + api_key=api_key, + voice_id=voice_id, + sample_rate=sample_rate, + function_id=function_id, + model_name=model_name, + params=params, + **kwargs, + ) + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`FastPitchTTSService` is deprecated, use `RivaTTSService` instead.", + DeprecationWarning, + ) +