Riva Service: add magpie-tts-multilingual model

This commit is contained in:
vipyne
2025-04-30 11:25:10 -05:00
parent 02c07755b0
commit 63a65627a2
3 changed files with 251 additions and 15 deletions

View File

@@ -16,8 +16,8 @@ from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.nim.llm import NimLLMService
from pipecat.services.riva.stt import ParakeetSTTService
from pipecat.services.riva.tts import FastPitchTTSService
from pipecat.services.riva.stt import ParakeetSTTService, RivaOfflineSTTService
from pipecat.services.riva.tts import RivaTTSService, FastPitchTTSService
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
@@ -37,11 +37,13 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
),
)
stt = ParakeetSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
stt = RivaOfflineSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
# stt = ParakeetSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
llm = NimLLMService(api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct")
tts = FastPitchTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
# tts = RivaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
messages = [
{

View File

@@ -13,11 +13,13 @@ from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
)
from pipecat.services.stt_service import SegmentedSTTService
from pipecat.services.stt_service import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
@@ -27,20 +29,21 @@ try:
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[riva]`.")
logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[riva]`. Also set NVIDIA_API_KEY env var.")
raise Exception(f"Missing module: {e}")
class ParakeetSTTService(STTService):
class RivaSTTService(STTService):
class InputParams(BaseModel):
language: Optional[Language] = Language.EN_US
def __init__(
self,
*,
api_key: str,
api_key: str = None,
server: str = "grpc.nvcf.nvidia.com:443",
function_id: str = "1598d209-5e27-4d3c-8079-4751568b1081",
model_name: str = "parakeet-ctc-1.1b-asr",
sample_rate: Optional[int] = None,
params: InputParams = InputParams(),
**kwargs,
@@ -61,7 +64,7 @@ class ParakeetSTTService(STTService):
self._stop_threshold_eou = -1.0
self._custom_configuration = ""
self.set_model_name("parakeet-ctc-1.1b-asr")
self.set_model_name(model_name)
metadata = [
["function-id", function_id],
@@ -196,3 +199,196 @@ class ParakeetSTTService(STTService):
def __iter__(self):
return self
class RivaOfflineSTTService(SegmentedSTTService):
"""Speech-to-text service using Fal's Wizper API.
This service uses Fal's Wizper API to perform speech-to-text transcription on audio
segments. It inherits from SegmentedSTTService to handle audio buffering and speech detection.
Args:
api_key: NVIDIA_API_KEY.
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate.
params: Configuration parameters for Riva.
**kwargs: Additional arguments passed to SegmentedSTTService.
"""
class InputParams(BaseModel):
"""Configuration parameters for Fal's Wizper API.
Attributes:
language: Language of the audio input. Defaults to English.
task: Task to perform ('transcribe' or 'translate'). Defaults to 'transcribe'.
chunk_level: Level of chunking ('segment'). Defaults to 'segment'.
version: Version of Wizper model to use. Defaults to '3'.
"""
language: Optional[Language] = Language.EN
task: str = "transcribe"
chunk_level: str = "segment"
version: str = "3"
def __init__(
self,
*,
api_key: str = None,
server: str = "grpc.nvcf.nvidia.com:443",
function_id: str = "ee8dc628-76de-4acc-8595-1836e7e857bd",
model_name: str = "canary-1b-asr",
sample_rate: Optional[int] = None,
params: InputParams = InputParams(),
**kwargs,
):
super().__init__(sample_rate=sample_rate, **kwargs)
self._api_key = api_key
self._profanity_filter = False
self._automatic_punctuation = False
self._no_verbatim_transcripts = False
self._language_code = params.language
self._boosted_lm_words = None
self._boosted_lm_score = 4.0
self._start_history = -1
self._start_threshold = -1.0
self._stop_history = -1
self._stop_threshold = -1.0
self._stop_history_eou = -1
self._stop_threshold_eou = -1.0
self._custom_configuration = ""
self.set_model_name(model_name)
metadata = [
["function-id", function_id],
["authorization", f"Bearer {api_key}"],
]
auth = riva.client.Auth(None, True, server, metadata)
self._asr_service = riva.client.ASRService(auth)
self._queue = asyncio.Queue()
self._config = None
self._thread_task = None
self._response_task = None
def can_generate_metrics(self) -> bool:
return False
async def start(self, frame: StartFrame):
await super().start(frame)
if self._config:
return
# config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
# encoding=riva.client.AudioEncoding.LINEAR_PCM,
language_code=self._language_code,
# model="",
max_alternatives=1,
profanity_filter=self._profanity_filter,
enable_automatic_punctuation=self._automatic_punctuation,
verbatim_transcripts=not self._no_verbatim_transcripts,
# sample_rate_hertz=self.sample_rate,
# audio_channel_count=1,
# enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization,??
# ),
# interim_results=True,
)
riva.client.add_word_boosting_to_config(
config, self._boosted_lm_words, self._boosted_lm_score
)
riva.client.add_endpoint_parameters_to_config(
config,
self._start_history,
self._start_threshold,
self._stop_history,
self._stop_history_eou,
self._stop_threshold,
self._stop_threshold_eou,
)
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
self._config = config
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Transcribe an audio segment
Args:
audio: Raw audio bytes in WAV format (already converted by base class).
Yields:
Frame: TranscriptionFrame containing the transcribed text.
Note:
The audio is already in WAV format from the SegmentedSTTService.
Only non-empty transcriptions are yielded.
"""
try:
response = self._asr_service.offline_recognize(audio, self._config)
# response = riva.client.print_offline(response=self._asr_service.offline_recognize(audio, self._config))
print(f"_____stt.py * response: {response}")
# # Send to Fal directly (audio is already in WAV format from base class)
# data_uri = fal_client.encode(audio, "audio/x-wav")
# response = await self._fal_client.run(
# "fal-ai/wizper",
# arguments={"audio_url": data_uri, **self._settings},
# )
if response and "text" in response:
text = response["text"].strip()
if text: # Only yield non-empty text
logger.debug(f"Transcription: [{text}]")
yield TranscriptionFrame(
text, "", time_now_iso8601(), Language(self._settings["language"])
)
except Exception as e:
logger.error(f"Riva Offline STT error: {e}")
yield ErrorFrame(f"Riva Offline STT error: {str(e)}")
def __next__(self) -> bytes:
if not self._thread_running:
raise StopIteration
future = asyncio.run_coroutine_threadsafe(self._queue.get(), self.get_event_loop())
return future.result()
def __iter__(self):
return self
class ParakeetSTTService(RivaSTTService):
class InputParams(BaseModel):
language: Optional[Language] = Language.EN_US
def __init__(
self,
*,
api_key: str = None,
server: str = "grpc.nvcf.nvidia.com:443",
function_id: str = "1598d209-5e27-4d3c-8079-4751568b1081",
model_name: str = "parakeet-ctc-1.1b-asr",
sample_rate: Optional[int] = None,
params: InputParams = InputParams(),
**kwargs,
):
super().__init__(
api_key=api_key,
server=server,
function_id=function_id,
model_name=model_name,
sample_rate=sample_rate,
params=params,
**kwargs,
)
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"`ParakeetSTTService` is deprecated, use `RivaSTTService` instead.",
DeprecationWarning,
)

View File

@@ -27,10 +27,10 @@ except ModuleNotFoundError as e:
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[riva]`.")
raise Exception(f"Missing module: {e}")
FASTPITCH_TIMEOUT_SECS = 5
RIVA_TTS_TIMEOUT_SECS = 5
class FastPitchTTSService(TTSService):
class RivaTTSService(TTSService):
class InputParams(BaseModel):
language: Optional[Language] = Language.EN_US
quality: Optional[int] = 20
@@ -38,11 +38,12 @@ class FastPitchTTSService(TTSService):
def __init__(
self,
*,
api_key: str,
api_key: str = None,
server: str = "grpc.nvcf.nvidia.com:443",
voice_id: str = "English-US.Female-1",
voice_id: str = "Magpie-Multilingual.EN-US.Male.Male-1",
sample_rate: Optional[int] = None,
function_id: str = "0149dedb-2be8-4195-b9a0-e57e0e14f972",
function_id: str = "877104f7-e885-42b9-8de8-f6e4c6303969",
model_name: str = "magpie-tts-multilingual",
params: InputParams = InputParams(),
**kwargs,
):
@@ -52,7 +53,7 @@ class FastPitchTTSService(TTSService):
self._language_code = params.language
self._quality = params.quality
self.set_model_name("fastpitch-hifigan-tts")
self.set_model_name(model_name)
self.set_voice(voice_id)
metadata = [
@@ -100,7 +101,7 @@ class FastPitchTTSService(TTSService):
await asyncio.to_thread(read_audio_responses, queue)
# Wait for the thread to start.
resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS)
resp = await asyncio.wait_for(queue.get(), RIVA_TTS_TIMEOUT_SECS)
while resp:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
@@ -109,9 +110,46 @@ class FastPitchTTSService(TTSService):
num_channels=1,
)
yield frame
resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS)
resp = await asyncio.wait_for(queue.get(), RIVA_TTS_TIMEOUT_SECS)
except asyncio.TimeoutError:
logger.error(f"{self} timeout waiting for audio response")
await self.start_tts_usage_metrics(text)
yield TTSStoppedFrame()
class FastPitchTTSService(RivaTTSService):
class InputParams(BaseModel):
language: Optional[Language] = Language.EN_US
quality: Optional[int] = 20
def __init__(
self,
*,
api_key: str = None,
server: str = "grpc.nvcf.nvidia.com:443",
voice_id: str = "English-US.Female-1",
sample_rate: Optional[int] = None,
function_id: str = "0149dedb-2be8-4195-b9a0-e57e0e14f972",
model_name: str = "fastpitch-hifigan-tts",
params: InputParams = InputParams(),
**kwargs,
):
super().__init__(
api_key=api_key,
voice_id=voice_id,
sample_rate=sample_rate,
function_id=function_id,
model_name=model_name,
params=params,
**kwargs,
)
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"`FastPitchTTSService` is deprecated, use `RivaTTSService` instead.",
DeprecationWarning,
)