refactor/simplify NvidiaTTSService synthesis stream shutdown
This commit is contained in:
@@ -269,9 +269,10 @@ class NvidiaTTSService(TTSService):
|
||||
|
||||
self._service = riva.client.SpeechSynthesisService(auth)
|
||||
|
||||
def _create_synthesis_config(self):
|
||||
def _create_synthesis_config(self) -> rtts.RivaSynthesisConfigResponse:
|
||||
"""Fetch and validate synthesis configuration from the server."""
|
||||
if not self._service:
|
||||
return
|
||||
raise RuntimeError("TTS service not initialized")
|
||||
|
||||
try:
|
||||
config = self._service.stub.GetRivaSynthesisConfig(
|
||||
@@ -284,7 +285,9 @@ class NvidiaTTSService(TTSService):
|
||||
logger.error(
|
||||
f"{self} failed to get synthesis config from server (gRPC {status}): {details}"
|
||||
)
|
||||
return None
|
||||
raise RuntimeError(
|
||||
f"{self}: startup failed while fetching synthesis config (gRPC {status})"
|
||||
) from e
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the NVIDIA Nemotron Speech TTS service.
|
||||
@@ -304,7 +307,7 @@ class NvidiaTTSService(TTSService):
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._close_synthesis_stream()
|
||||
await self._abort_synthesis_stream()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the NVIDIA Nemotron Speech TTS service.
|
||||
@@ -313,7 +316,7 @@ class NvidiaTTSService(TTSService):
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._close_synthesis_stream()
|
||||
await self._abort_synthesis_stream()
|
||||
|
||||
def _start_synthesis_stream(self, context_id: str):
|
||||
"""Start a persistent gRPC synthesis stream for the current turn.
|
||||
@@ -412,8 +415,11 @@ class NvidiaTTSService(TTSService):
|
||||
if item is None:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
# Ignore stale exceptions from interrupted streams.
|
||||
if self._stream_state is state:
|
||||
# Treat stream exceptions as terminal for this stream. Once
|
||||
# SynthesizeOnline raises, no further reliable audio is expected.
|
||||
# Ignore stale or interruption-driven exceptions to avoid noisy
|
||||
# errors during handoff to a new stream.
|
||||
if self._stream_state is state and not state.stop_event.is_set():
|
||||
await self.push_error(f"{self} synthesis error: {item}")
|
||||
break
|
||||
|
||||
@@ -438,74 +444,12 @@ class NvidiaTTSService(TTSService):
|
||||
"""Signal the active synthesis request generator to close."""
|
||||
state.text_queue.put(None)
|
||||
|
||||
async def _close_synthesis_stream(self):
|
||||
"""Close the active gRPC synthesis stream gracefully.
|
||||
|
||||
Sends a sentinel to end the request generator, waits for the
|
||||
synthesis task to finish producing all remaining audio, then lets
|
||||
the response task drain naturally before cleaning up.
|
||||
"""
|
||||
state = self._stream_state
|
||||
if state is None:
|
||||
return
|
||||
|
||||
self._signal_synthesis_close(state)
|
||||
|
||||
if state.synth_task is not None:
|
||||
try:
|
||||
await state.synth_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
state.synth_task = None
|
||||
|
||||
if state.response_task is not None:
|
||||
try:
|
||||
await state.response_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
state.response_task = None
|
||||
|
||||
if self._stream_state is state:
|
||||
self._stream_state = None
|
||||
|
||||
async def _wait_for_synthesis_close_interruptibly(self, state: _SynthesisStreamState):
|
||||
"""Wait for synthesis close unless interruption preempts this stream."""
|
||||
while True:
|
||||
if self._stream_state is not state or state.stop_event.is_set():
|
||||
# Interruption took ownership of stream shutdown.
|
||||
return
|
||||
|
||||
synth_done = state.synth_task is None or state.synth_task.done()
|
||||
response_done = state.response_task is None or state.response_task.done()
|
||||
if synth_done and response_done:
|
||||
break
|
||||
|
||||
# Poll in short intervals to keep this wait interruptible.
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
if state.synth_task is not None:
|
||||
try:
|
||||
await state.synth_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
state.synth_task = None
|
||||
|
||||
if state.response_task is not None:
|
||||
try:
|
||||
await state.response_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
state.response_task = None
|
||||
|
||||
if self._stream_state is state:
|
||||
self._stream_state = None
|
||||
|
||||
async def _abort_synthesis_stream(self):
|
||||
"""Abort the active gRPC synthesis stream immediately.
|
||||
|
||||
Cancels the response task first to stop delivering audio, then
|
||||
drains the text queue and signals the synthesis handler to stop.
|
||||
Unlike ``_close_synthesis_stream``, pending audio is discarded.
|
||||
Pending audio is discarded.
|
||||
"""
|
||||
state = self._stream_state
|
||||
if state is None:
|
||||
@@ -543,7 +487,6 @@ class NvidiaTTSService(TTSService):
|
||||
state = self._stream_state
|
||||
if state is not None:
|
||||
self._signal_synthesis_close(state)
|
||||
await self._wait_for_synthesis_close_interruptibly(state)
|
||||
await super().flush_audio(context_id)
|
||||
|
||||
async def on_audio_context_interrupted(self, context_id: str):
|
||||
@@ -605,7 +548,6 @@ class NvidiaTTSService(TTSService):
|
||||
|
||||
try:
|
||||
assert self._service is not None, "TTS service not initialized"
|
||||
assert self._config is not None, "Synthesis configuration not created"
|
||||
|
||||
# First call for this turn: create audio context and start gRPC stream
|
||||
if not self.audio_context_available(context_id):
|
||||
@@ -629,4 +571,4 @@ class NvidiaTTSService(TTSService):
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
Reference in New Issue
Block a user