NvidiaTTSService: initialize client on StartFrame

Initialize client on StartFrame so errrors are reported within the pipeline.
This commit is contained in:
Aleix Conchillo Flaqué
2026-01-19 20:30:17 -08:00
parent 024809b39a
commit 9a718ded1e

View File

@@ -25,6 +25,7 @@ from pydantic import BaseModel
from pipecat.frames.frames import (
ErrorFrame,
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
@@ -93,6 +94,7 @@ class NvidiaTTSService(TTSService):
params = params or NvidiaTTSService.InputParams()
self._server = server
self._api_key = api_key
self._voice_id = voice_id
self._language_code = params.language
@@ -102,18 +104,8 @@ class NvidiaTTSService(TTSService):
self.set_model_name(model_function_map.get("model_name"))
self.set_voice(voice_id)
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {api_key}"],
]
auth = riva.client.Auth(None, self._use_ssl, server, metadata)
self._service = riva.client.SpeechSynthesisService(auth)
# warm up the service
config_response = self._service.stub.GetRivaSynthesisConfig(
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
)
self._service = None
self._config = None
async def set_model(self, model: str):
"""Attempt to set the TTS model.
@@ -129,6 +121,39 @@ class NvidiaTTSService(TTSService):
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
)
def _initialize_client(self):
if self._service is not None:
return
metadata = [
["function-id", self._function_id],
["authorization", f"Bearer {self._api_key}"],
]
auth = riva.client.Auth(None, self._use_ssl, self._server, metadata)
self._service = riva.client.SpeechSynthesisService(auth)
def _create_synthesis_config(self):
if not self._service:
return
# warm up the service
config = self._service.stub.GetRivaSynthesisConfig(
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
)
return config
async def start(self, frame: StartFrame):
"""Start the Cartesia TTS service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
self._initialize_client()
self._config = self._create_synthesis_config()
logger.debug(f"Initialized NvidiaTTSService with model: {self.model_name}")
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using NVIDIA Riva TTS.
@@ -161,12 +186,15 @@ class NvidiaTTSService(TTSService):
logger.error(f"{self} exception: {e}")
add_response(None)
await self.start_ttfb_metrics()
yield TTSStartedFrame()
logger.debug(f"{self}: Generating TTS [{text}]")
try:
assert self._service is not None, "TTS service not initialized"
assert self._config is not None, "Synthesis configuration not created"
await self.start_ttfb_metrics()
yield TTSStartedFrame()
logger.debug(f"{self}: Generating TTS [{text}]")
queue = asyncio.Queue()
await asyncio.to_thread(read_audio_responses, queue)
@@ -181,9 +209,12 @@ class NvidiaTTSService(TTSService):
)
yield frame
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
await self.start_tts_usage_metrics(text)
yield TTSStoppedFrame()
except asyncio.TimeoutError:
logger.error(f"{self} timeout waiting for audio response")
yield ErrorFrame(error=f"{self} error: {e}")
await self.start_tts_usage_metrics(text)
yield TTSStoppedFrame()
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")