NvidiaTTSService: initialize client on StartFrame
Initialize client on StartFrame so errrors are reported within the pipeline.
This commit is contained in:
@@ -25,6 +25,7 @@ from pydantic import BaseModel
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -93,6 +94,7 @@ class NvidiaTTSService(TTSService):
|
||||
|
||||
params = params or NvidiaTTSService.InputParams()
|
||||
|
||||
self._server = server
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._language_code = params.language
|
||||
@@ -102,18 +104,8 @@ class NvidiaTTSService(TTSService):
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
self.set_voice(voice_id)
|
||||
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, self._use_ssl, server, metadata)
|
||||
|
||||
self._service = riva.client.SpeechSynthesisService(auth)
|
||||
|
||||
# warm up the service
|
||||
config_response = self._service.stub.GetRivaSynthesisConfig(
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
self._service = None
|
||||
self._config = None
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Attempt to set the TTS model.
|
||||
@@ -129,6 +121,39 @@ class NvidiaTTSService(TTSService):
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
def _initialize_client(self):
|
||||
if self._service is not None:
|
||||
return
|
||||
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {self._api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, self._use_ssl, self._server, metadata)
|
||||
|
||||
self._service = riva.client.SpeechSynthesisService(auth)
|
||||
|
||||
def _create_synthesis_config(self):
|
||||
if not self._service:
|
||||
return
|
||||
|
||||
# warm up the service
|
||||
config = self._service.stub.GetRivaSynthesisConfig(
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
return config
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Cartesia TTS service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._initialize_client()
|
||||
self._config = self._create_synthesis_config()
|
||||
logger.debug(f"Initialized NvidiaTTSService with model: {self.model_name}")
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using NVIDIA Riva TTS.
|
||||
@@ -161,12 +186,15 @@ class NvidiaTTSService(TTSService):
|
||||
logger.error(f"{self} exception: {e}")
|
||||
add_response(None)
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
assert self._service is not None, "TTS service not initialized"
|
||||
assert self._config is not None, "Synthesis configuration not created"
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
queue = asyncio.Queue()
|
||||
await asyncio.to_thread(read_audio_responses, queue)
|
||||
|
||||
@@ -181,9 +209,12 @@ class NvidiaTTSService(TTSService):
|
||||
)
|
||||
yield frame
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame()
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"{self} timeout waiting for audio response")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
Reference in New Issue
Block a user