Riva Service: add magpie-tts-multilingual model
This commit is contained in:
@@ -16,8 +16,8 @@ from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.nim.llm import NimLLMService
|
||||
from pipecat.services.riva.stt import ParakeetSTTService
|
||||
from pipecat.services.riva.tts import FastPitchTTSService
|
||||
from pipecat.services.riva.stt import ParakeetSTTService, RivaOfflineSTTService
|
||||
from pipecat.services.riva.tts import RivaTTSService, FastPitchTTSService
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
@@ -37,11 +37,13 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
|
||||
),
|
||||
)
|
||||
|
||||
stt = ParakeetSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
stt = RivaOfflineSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
# stt = ParakeetSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
llm = NimLLMService(api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct")
|
||||
|
||||
tts = FastPitchTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
# tts = RivaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
|
||||
@@ -13,11 +13,13 @@ from pydantic import BaseModel
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.stt_service import SegmentedSTTService
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -27,20 +29,21 @@ try:
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[riva]`.")
|
||||
logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[riva]`. Also set NVIDIA_API_KEY env var.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class ParakeetSTTService(STTService):
|
||||
class RivaSTTService(STTService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN_US
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
api_key: str = None,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
function_id: str = "1598d209-5e27-4d3c-8079-4751568b1081",
|
||||
model_name: str = "parakeet-ctc-1.1b-asr",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
@@ -61,7 +64,7 @@ class ParakeetSTTService(STTService):
|
||||
self._stop_threshold_eou = -1.0
|
||||
self._custom_configuration = ""
|
||||
|
||||
self.set_model_name("parakeet-ctc-1.1b-asr")
|
||||
self.set_model_name(model_name)
|
||||
|
||||
metadata = [
|
||||
["function-id", function_id],
|
||||
@@ -196,3 +199,196 @@ class ParakeetSTTService(STTService):
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
class RivaOfflineSTTService(SegmentedSTTService):
|
||||
"""Speech-to-text service using Fal's Wizper API.
|
||||
|
||||
This service uses Fal's Wizper API to perform speech-to-text transcription on audio
|
||||
segments. It inherits from SegmentedSTTService to handle audio buffering and speech detection.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA_API_KEY.
|
||||
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate.
|
||||
params: Configuration parameters for Riva.
|
||||
**kwargs: Additional arguments passed to SegmentedSTTService.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for Fal's Wizper API.
|
||||
|
||||
Attributes:
|
||||
language: Language of the audio input. Defaults to English.
|
||||
task: Task to perform ('transcribe' or 'translate'). Defaults to 'transcribe'.
|
||||
chunk_level: Level of chunking ('segment'). Defaults to 'segment'.
|
||||
version: Version of Wizper model to use. Defaults to '3'.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN
|
||||
task: str = "transcribe"
|
||||
chunk_level: str = "segment"
|
||||
version: str = "3"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str = None,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
function_id: str = "ee8dc628-76de-4acc-8595-1836e7e857bd",
|
||||
model_name: str = "canary-1b-asr",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
self._api_key = api_key
|
||||
self._profanity_filter = False
|
||||
self._automatic_punctuation = False
|
||||
self._no_verbatim_transcripts = False
|
||||
self._language_code = params.language
|
||||
self._boosted_lm_words = None
|
||||
self._boosted_lm_score = 4.0
|
||||
self._start_history = -1
|
||||
self._start_threshold = -1.0
|
||||
self._stop_history = -1
|
||||
self._stop_threshold = -1.0
|
||||
self._stop_history_eou = -1
|
||||
self._stop_threshold_eou = -1.0
|
||||
self._custom_configuration = ""
|
||||
|
||||
self.set_model_name(model_name)
|
||||
|
||||
metadata = [
|
||||
["function-id", function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
self._queue = asyncio.Queue()
|
||||
self._config = None
|
||||
self._thread_task = None
|
||||
self._response_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return False
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
|
||||
if self._config:
|
||||
return
|
||||
|
||||
# config = riva.client.StreamingRecognitionConfig(
|
||||
config=riva.client.RecognitionConfig(
|
||||
# encoding=riva.client.AudioEncoding.LINEAR_PCM,
|
||||
language_code=self._language_code,
|
||||
# model="",
|
||||
max_alternatives=1,
|
||||
profanity_filter=self._profanity_filter,
|
||||
enable_automatic_punctuation=self._automatic_punctuation,
|
||||
verbatim_transcripts=not self._no_verbatim_transcripts,
|
||||
# sample_rate_hertz=self.sample_rate,
|
||||
# audio_channel_count=1,
|
||||
# enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization,??
|
||||
# ),
|
||||
# interim_results=True,
|
||||
)
|
||||
|
||||
riva.client.add_word_boosting_to_config(
|
||||
config, self._boosted_lm_words, self._boosted_lm_score
|
||||
)
|
||||
|
||||
riva.client.add_endpoint_parameters_to_config(
|
||||
config,
|
||||
self._start_history,
|
||||
self._start_threshold,
|
||||
self._stop_history,
|
||||
self._stop_history_eou,
|
||||
self._stop_threshold,
|
||||
self._stop_threshold_eou,
|
||||
)
|
||||
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
|
||||
|
||||
self._config = config
|
||||
|
||||
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Transcribe an audio segment
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes in WAV format (already converted by base class).
|
||||
|
||||
Yields:
|
||||
Frame: TranscriptionFrame containing the transcribed text.
|
||||
|
||||
Note:
|
||||
The audio is already in WAV format from the SegmentedSTTService.
|
||||
Only non-empty transcriptions are yielded.
|
||||
"""
|
||||
try:
|
||||
response = self._asr_service.offline_recognize(audio, self._config)
|
||||
# response = riva.client.print_offline(response=self._asr_service.offline_recognize(audio, self._config))
|
||||
print(f"_____stt.py * response: {response}")
|
||||
# # Send to Fal directly (audio is already in WAV format from base class)
|
||||
# data_uri = fal_client.encode(audio, "audio/x-wav")
|
||||
# response = await self._fal_client.run(
|
||||
# "fal-ai/wizper",
|
||||
# arguments={"audio_url": data_uri, **self._settings},
|
||||
# )
|
||||
|
||||
if response and "text" in response:
|
||||
text = response["text"].strip()
|
||||
if text: # Only yield non-empty text
|
||||
logger.debug(f"Transcription: [{text}]")
|
||||
yield TranscriptionFrame(
|
||||
text, "", time_now_iso8601(), Language(self._settings["language"])
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Riva Offline STT error: {e}")
|
||||
yield ErrorFrame(f"Riva Offline STT error: {str(e)}")
|
||||
|
||||
def __next__(self) -> bytes:
|
||||
if not self._thread_running:
|
||||
raise StopIteration
|
||||
future = asyncio.run_coroutine_threadsafe(self._queue.get(), self.get_event_loop())
|
||||
return future.result()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
class ParakeetSTTService(RivaSTTService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN_US
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str = None,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
function_id: str = "1598d209-5e27-4d3c-8079-4751568b1081",
|
||||
model_name: str = "parakeet-ctc-1.1b-asr",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
server=server,
|
||||
function_id=function_id,
|
||||
model_name=model_name,
|
||||
sample_rate=sample_rate,
|
||||
params=params,
|
||||
**kwargs,
|
||||
)
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`ParakeetSTTService` is deprecated, use `RivaSTTService` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,10 +27,10 @@ except ModuleNotFoundError as e:
|
||||
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[riva]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
FASTPITCH_TIMEOUT_SECS = 5
|
||||
RIVA_TTS_TIMEOUT_SECS = 5
|
||||
|
||||
|
||||
class FastPitchTTSService(TTSService):
|
||||
class RivaTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN_US
|
||||
quality: Optional[int] = 20
|
||||
@@ -38,11 +38,12 @@ class FastPitchTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
api_key: str = None,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
voice_id: str = "English-US.Female-1",
|
||||
voice_id: str = "Magpie-Multilingual.EN-US.Male.Male-1",
|
||||
sample_rate: Optional[int] = None,
|
||||
function_id: str = "0149dedb-2be8-4195-b9a0-e57e0e14f972",
|
||||
function_id: str = "877104f7-e885-42b9-8de8-f6e4c6303969",
|
||||
model_name: str = "magpie-tts-multilingual",
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
@@ -52,7 +53,7 @@ class FastPitchTTSService(TTSService):
|
||||
self._language_code = params.language
|
||||
self._quality = params.quality
|
||||
|
||||
self.set_model_name("fastpitch-hifigan-tts")
|
||||
self.set_model_name(model_name)
|
||||
self.set_voice(voice_id)
|
||||
|
||||
metadata = [
|
||||
@@ -100,7 +101,7 @@ class FastPitchTTSService(TTSService):
|
||||
await asyncio.to_thread(read_audio_responses, queue)
|
||||
|
||||
# Wait for the thread to start.
|
||||
resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS)
|
||||
resp = await asyncio.wait_for(queue.get(), RIVA_TTS_TIMEOUT_SECS)
|
||||
while resp:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
@@ -109,9 +110,46 @@ class FastPitchTTSService(TTSService):
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS)
|
||||
resp = await asyncio.wait_for(queue.get(), RIVA_TTS_TIMEOUT_SECS)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"{self} timeout waiting for audio response")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
|
||||
class FastPitchTTSService(RivaTTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN_US
|
||||
quality: Optional[int] = 20
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str = None,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
voice_id: str = "English-US.Female-1",
|
||||
sample_rate: Optional[int] = None,
|
||||
function_id: str = "0149dedb-2be8-4195-b9a0-e57e0e14f972",
|
||||
model_name: str = "fastpitch-hifigan-tts",
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
voice_id=voice_id,
|
||||
sample_rate=sample_rate,
|
||||
function_id=function_id,
|
||||
model_name=model_name,
|
||||
params=params,
|
||||
**kwargs,
|
||||
)
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`FastPitchTTSService` is deprecated, use `RivaTTSService` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user