Merge pull request #3300 from omChauhanDev/nvidia-expose-use_ssl-param
exposed use_ssl param in nvidia services
This commit is contained in:
1
changelog/3300.added.md
Normal file
1
changelog/3300.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `use_ssl` parameter to `NvidiaSTTService`, `NvidiaSegmentedSTTService` and `NvidiaTTSService`.
|
||||
@@ -116,6 +116,7 @@ class NvidiaSTTService(STTService):
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
use_ssl: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva STT service.
|
||||
@@ -126,6 +127,7 @@ class NvidiaSTTService(STTService):
|
||||
model_function_map: Mapping containing 'function_id' and 'model_name' for the ASR model.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
|
||||
params: Additional configuration parameters for NVIDIA Riva.
|
||||
use_ssl: Whether to use SSL for the NVIDIA Riva server. Defaults to True.
|
||||
**kwargs: Additional arguments passed to STTService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
@@ -133,6 +135,7 @@ class NvidiaSTTService(STTService):
|
||||
params = params or NvidiaSTTService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._use_ssl = use_ssl
|
||||
self._profanity_filter = False
|
||||
self._automatic_punctuation = True
|
||||
self._no_verbatim_transcripts = False
|
||||
@@ -163,7 +166,7 @@ class NvidiaSTTService(STTService):
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
auth = riva.client.Auth(None, self._use_ssl, server, metadata)
|
||||
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
@@ -421,6 +424,7 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
use_ssl: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva segmented STT service.
|
||||
@@ -431,6 +435,7 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
model_function_map: Mapping of model name and its corresponding NVIDIA Cloud Function ID
|
||||
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate
|
||||
params: Additional configuration parameters for NVIDIA Riva
|
||||
use_ssl: Whether to use SSL for the NVIDIA Riva server. Defaults to True.
|
||||
**kwargs: Additional arguments passed to SegmentedSTTService
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
@@ -443,6 +448,7 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
# Initialize NVIDIA Riva settings
|
||||
self._api_key = api_key
|
||||
self._server = server
|
||||
self._use_ssl = use_ssl
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
self._model_name = model_function_map.get("model_name")
|
||||
|
||||
@@ -494,7 +500,7 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
]
|
||||
|
||||
# Create authenticated client
|
||||
auth = riva.client.Auth(None, True, self._server, metadata)
|
||||
auth = riva.client.Auth(None, self._use_ssl, self._server, metadata)
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
logger.info(f"Initialized NvidiaSegmentedSTTService with model: {self.model_name}")
|
||||
|
||||
@@ -74,6 +74,7 @@ class NvidiaTTSService(TTSService):
|
||||
"model_name": "magpie-tts-multilingual",
|
||||
},
|
||||
params: Optional[InputParams] = None,
|
||||
use_ssl: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva TTS service.
|
||||
@@ -85,6 +86,7 @@ class NvidiaTTSService(TTSService):
|
||||
sample_rate: Audio sample rate. If None, uses service default.
|
||||
model_function_map: Dictionary containing function_id and model_name for the TTS model.
|
||||
params: Additional configuration parameters for TTS synthesis.
|
||||
use_ssl: Whether to use SSL for the NVIDIA Riva server. Defaults to True.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
@@ -96,7 +98,7 @@ class NvidiaTTSService(TTSService):
|
||||
self._language_code = params.language
|
||||
self._quality = params.quality
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
|
||||
self._use_ssl = use_ssl
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
self.set_voice(voice_id)
|
||||
|
||||
@@ -104,7 +106,7 @@ class NvidiaTTSService(TTSService):
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
auth = riva.client.Auth(None, self._use_ssl, server, metadata)
|
||||
|
||||
self._service = riva.client.SpeechSynthesisService(auth)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user