Merge pull request #3300 from omChauhanDev/nvidia-expose-use_ssl-param

exposed use_ssl param in nvidia services
This commit is contained in:
Mark Backman
2025-12-29 09:18:26 -05:00
committed by GitHub
3 changed files with 13 additions and 4 deletions

1
changelog/3300.added.md Normal file
View File

@@ -0,0 +1 @@
- Added `use_ssl` parameter to `NvidiaSTTService`, `NvidiaSegmentedSTTService` and `NvidiaTTSService`.

View File

@@ -116,6 +116,7 @@ class NvidiaSTTService(STTService):
},
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
use_ssl: bool = True,
**kwargs,
):
"""Initialize the NVIDIA Riva STT service.
@@ -126,6 +127,7 @@ class NvidiaSTTService(STTService):
model_function_map: Mapping containing 'function_id' and 'model_name' for the ASR model.
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
params: Additional configuration parameters for NVIDIA Riva.
use_ssl: Whether to use SSL for the NVIDIA Riva server. Defaults to True.
**kwargs: Additional arguments passed to STTService.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
@@ -133,6 +135,7 @@ class NvidiaSTTService(STTService):
params = params or NvidiaSTTService.InputParams()
self._api_key = api_key
self._use_ssl = use_ssl
self._profanity_filter = False
self._automatic_punctuation = True
self._no_verbatim_transcripts = False
@@ -163,7 +166,7 @@ class NvidiaSTTService(STTService):
["function-id", self._function_id],
["authorization", f"Bearer {api_key}"],
]
auth = riva.client.Auth(None, True, server, metadata)
auth = riva.client.Auth(None, self._use_ssl, server, metadata)
self._asr_service = riva.client.ASRService(auth)
@@ -421,6 +424,7 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
},
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
use_ssl: bool = True,
**kwargs,
):
"""Initialize the NVIDIA Riva segmented STT service.
@@ -431,6 +435,7 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
model_function_map: Mapping of model name and its corresponding NVIDIA Cloud Function ID
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate
params: Additional configuration parameters for NVIDIA Riva
use_ssl: Whether to use SSL for the NVIDIA Riva server. Defaults to True.
**kwargs: Additional arguments passed to SegmentedSTTService
"""
super().__init__(sample_rate=sample_rate, **kwargs)
@@ -443,6 +448,7 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
# Initialize NVIDIA Riva settings
self._api_key = api_key
self._server = server
self._use_ssl = use_ssl
self._function_id = model_function_map.get("function_id")
self._model_name = model_function_map.get("model_name")
@@ -494,7 +500,7 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
]
# Create authenticated client
auth = riva.client.Auth(None, True, self._server, metadata)
auth = riva.client.Auth(None, self._use_ssl, self._server, metadata)
self._asr_service = riva.client.ASRService(auth)
logger.info(f"Initialized NvidiaSegmentedSTTService with model: {self.model_name}")

View File

@@ -74,6 +74,7 @@ class NvidiaTTSService(TTSService):
"model_name": "magpie-tts-multilingual",
},
params: Optional[InputParams] = None,
use_ssl: bool = True,
**kwargs,
):
"""Initialize the NVIDIA Riva TTS service.
@@ -85,6 +86,7 @@ class NvidiaTTSService(TTSService):
sample_rate: Audio sample rate. If None, uses service default.
model_function_map: Dictionary containing function_id and model_name for the TTS model.
params: Additional configuration parameters for TTS synthesis.
use_ssl: Whether to use SSL for the NVIDIA Riva server. Defaults to True.
**kwargs: Additional arguments passed to parent TTSService.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
@@ -96,7 +98,7 @@ class NvidiaTTSService(TTSService):
self._language_code = params.language
self._quality = params.quality
self._function_id = model_function_map.get("function_id")
self._use_ssl = use_ssl
self.set_model_name(model_function_map.get("model_name"))
self.set_voice(voice_id)
@@ -104,7 +106,7 @@ class NvidiaTTSService(TTSService):
["function-id", self._function_id],
["authorization", f"Bearer {api_key}"],
]
auth = riva.client.Auth(None, True, server, metadata)
auth = riva.client.Auth(None, self._use_ssl, server, metadata)
self._service = riva.client.SpeechSynthesisService(auth)