refactor/simplify NvidiaTTSService synthesis stream shutdown

This commit is contained in:
sathwika
2026-04-13 14:35:17 +05:30
parent 746fadc2b5
commit f45a410f56

View File

@@ -269,9 +269,10 @@ class NvidiaTTSService(TTSService):
self._service = riva.client.SpeechSynthesisService(auth)
def _create_synthesis_config(self):
def _create_synthesis_config(self) -> rtts.RivaSynthesisConfigResponse:
"""Fetch and validate synthesis configuration from the server."""
if not self._service:
return
raise RuntimeError("TTS service not initialized")
try:
config = self._service.stub.GetRivaSynthesisConfig(
@@ -284,7 +285,9 @@ class NvidiaTTSService(TTSService):
logger.error(
f"{self} failed to get synthesis config from server (gRPC {status}): {details}"
)
return None
raise RuntimeError(
f"{self}: startup failed while fetching synthesis config (gRPC {status})"
) from e
async def start(self, frame: StartFrame):
"""Start the NVIDIA Nemotron Speech TTS service.
@@ -304,7 +307,7 @@ class NvidiaTTSService(TTSService):
frame: The end frame.
"""
await super().stop(frame)
await self._close_synthesis_stream()
await self._abort_synthesis_stream()
async def cancel(self, frame: CancelFrame):
"""Cancel the NVIDIA Nemotron Speech TTS service.
@@ -313,7 +316,7 @@ class NvidiaTTSService(TTSService):
frame: The cancel frame.
"""
await super().cancel(frame)
await self._close_synthesis_stream()
await self._abort_synthesis_stream()
def _start_synthesis_stream(self, context_id: str):
"""Start a persistent gRPC synthesis stream for the current turn.
@@ -412,8 +415,11 @@ class NvidiaTTSService(TTSService):
if item is None:
break
if isinstance(item, Exception):
# Ignore stale exceptions from interrupted streams.
if self._stream_state is state:
# Treat stream exceptions as terminal for this stream. Once
# SynthesizeOnline raises, no further reliable audio is expected.
# Ignore stale or interruption-driven exceptions to avoid noisy
# errors during handoff to a new stream.
if self._stream_state is state and not state.stop_event.is_set():
await self.push_error(f"{self} synthesis error: {item}")
break
@@ -438,74 +444,12 @@ class NvidiaTTSService(TTSService):
"""Signal the active synthesis request generator to close."""
state.text_queue.put(None)
async def _close_synthesis_stream(self):
    """Gracefully shut down the active gRPC synthesis stream, if one exists.

    The request generator is signalled to close, then the synthesis task
    and the response task are awaited in that order before their
    references on the stream state are dropped. A ``CancelledError``
    surfacing from either await is swallowed. Does nothing when no stream
    is active.
    """
    stream = self._stream_state
    if stream is not None:
        # Signal first so the request generator can end before we wait on
        # the tasks that depend on it.
        self._signal_synthesis_close(stream)

        # Synthesis is awaited before the response task. Each attribute is
        # read only when its turn comes, not up front.
        for task_attr in ("synth_task", "response_task"):
            pending = getattr(stream, task_attr)
            if pending is None:
                continue
            try:
                await pending
            except asyncio.CancelledError:
                pass
            setattr(stream, task_attr, None)

        # Only clear the service-level reference if it still points at the
        # stream we just closed; it may have been replaced while awaiting.
        if self._stream_state is stream:
            self._stream_state = None
async def _wait_for_synthesis_close_interruptibly(self, state: _SynthesisStreamState):
    """Wait for a stream's tasks to finish, bailing out if it is preempted.

    Polls until both the synthesis task and the response task of ``state``
    are absent or done. If, at any poll, ``state`` is no longer the active
    stream or its stop event is set, returns immediately without touching
    the tasks. Otherwise awaits both tasks (swallowing ``CancelledError``),
    drops their references, and clears the active stream if it is still
    ``state``.

    Args:
        state: The synthesis stream state whose tasks are being waited on.
    """

    def _finished(task) -> bool:
        # A missing task counts as finished.
        return task is None or task.done()

    poll_seconds = 0.05
    while True:
        still_owner = self._stream_state is state and not state.stop_event.is_set()
        if not still_owner:
            # Interruption took ownership of stream shutdown.
            return
        if all((_finished(state.synth_task), _finished(state.response_task))):
            break
        # Sleep in short slices so the ownership check above runs often.
        await asyncio.sleep(poll_seconds)

    # Both tasks are done (or absent); await them in order, synthesis first,
    # then drop the references.
    for task_attr in ("synth_task", "response_task"):
        pending = getattr(state, task_attr)
        if pending is None:
            continue
        try:
            await pending
        except asyncio.CancelledError:
            pass
        setattr(state, task_attr, None)

    if self._stream_state is state:
        self._stream_state = None
async def _abort_synthesis_stream(self):
"""Abort the active gRPC synthesis stream immediately.
Cancels the response task first to stop delivering audio, then
drains the text queue and signals the synthesis handler to stop.
Unlike ``_close_synthesis_stream``, pending audio is discarded.
Pending audio is discarded.
"""
state = self._stream_state
if state is None:
@@ -543,7 +487,6 @@ class NvidiaTTSService(TTSService):
state = self._stream_state
if state is not None:
self._signal_synthesis_close(state)
await self._wait_for_synthesis_close_interruptibly(state)
await super().flush_audio(context_id)
async def on_audio_context_interrupted(self, context_id: str):
@@ -605,7 +548,6 @@ class NvidiaTTSService(TTSService):
try:
assert self._service is not None, "TTS service not initialized"
assert self._config is not None, "Synthesis configuration not created"
# First call for this turn: create audio context and start gRPC stream
if not self.audio_context_available(context_id):
@@ -629,4 +571,4 @@ class NvidiaTTSService(TTSService):
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
yield ErrorFrame(error=f"{self} error: {e}")