diff --git a/changelog/4167.fixed.2.md b/changelog/4167.fixed.2.md new file mode 100644 index 000000000..6b894cc7f --- /dev/null +++ b/changelog/4167.fixed.2.md @@ -0,0 +1 @@ +- Fixed an issue in `InworldTTSService` where, in cases of fast interruption, we would continue receiving audio from the previous context. \ No newline at end of file diff --git a/changelog/4167.fixed.md b/changelog/4167.fixed.md new file mode 100644 index 000000000..2784d4a41 --- /dev/null +++ b/changelog/4167.fixed.md @@ -0,0 +1 @@ +- Fixed a word timestamp interleaving issue in `InworldTTSService` when processing multiple sentences. diff --git a/src/pipecat/services/inworld/tts.py b/src/pipecat/services/inworld/tts.py index cc0350abf..157130442 100644 --- a/src/pipecat/services/inworld/tts.py +++ b/src/pipecat/services/inworld/tts.py @@ -646,7 +646,7 @@ class InworldTTSService(WebsocketTTSService): super().__init__( push_text_frames=False, - push_stop_frames=True, + push_stop_frames=False, pause_frame_processing=True, sample_rate=sample_rate, aggregate_sentences=aggregate_sentences, @@ -742,21 +742,10 @@ class InworldTTSService(WebsocketTTSService): logger.trace(f"Flushing audio for context {flush_id}") await self._send_flush(flush_id) - async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): - """Push a frame and handle state changes. - - Args: - frame: The frame to push. - direction: The direction to push the frame. - """ - await super().push_frame(frame, direction) - if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)): - logger.trace( - f"{self}: Resetting timestamp tracking due to {type(frame).__name__} - " - f"cumulative_time was {self._cumulative_time}" - ) - self._cumulative_time = 0.0 - self._generation_end_time = 0.0 + def _reset_generation_timing(self): + """Reset the cumulative time and generation end time for a new generation.""" + self._cumulative_time = 0.0 + self._generation_end_time = 0.0 async def on_turn_context_created(self, context_id: str): """Eagerly open the context on the server when a new turn starts. @@ -815,8 +804,6 @@ class InworldTTSService(WebsocketTTSService): except Exception as e: await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) self._sent_context_ids.discard(context_id) - self._cumulative_time = 0.0 - self._generation_end_time = 0.0 async def on_turn_context_completed(self): """Close the server-side context at end of turn. @@ -834,10 +821,6 @@ class InworldTTSService(WebsocketTTSService): await self._close_context(context_id) await super().on_audio_context_interrupted(context_id) - async def on_audio_context_completed(self, context_id: str): - """Callback invoked when an audio context has been completed.""" - await self._close_context(context_id) - async def _maybe_push_fallback_text(self, context_id: str): """Push the full text as fallback when no timestamps were received. @@ -966,8 +949,7 @@ class InworldTTSService(WebsocketTTSService): await self.remove_active_audio_context() self._websocket = None self._sent_context_ids.clear() - self._cumulative_time = 0.0 - self._generation_end_time = 0.0 + self._reset_generation_timing() self._context_texts.clear() self._contexts_with_timestamps.clear() await self._call_event_handler("on_disconnected") @@ -1015,10 +997,6 @@ class InworldTTSService(WebsocketTTSService): # Handle context created confirmation if "contextCreated" in result: logger.trace(f"{self}: Context created on server: {ctx_id}") - # If the context isn't available recreate it (handles race conditions during interruption recovery). - elif ctx_id and not self.audio_context_available(ctx_id): - logger.trace(f"{self}: Recreating audio context for current context: {ctx_id}") - await self.create_audio_context(ctx_id) # Process audio chunk audio_chunk = result.get("audioChunk", {}) @@ -1053,7 +1031,7 @@ class InworldTTSService(WebsocketTTSService): # Handle context closed - context no longer exists on server if "contextClosed" in result: - logger.trace(f"{self}: Context closed on server: {ctx_id}") + logger.debug(f"{self}: Context closed on server: {ctx_id}") await self._maybe_push_fallback_text(ctx_id) await self.stop_ttfb_metrics() await self.append_to_audio_context(ctx_id, TTSStoppedFrame(context_id=ctx_id)) @@ -1166,7 +1144,7 @@ class InworldTTSService(WebsocketTTSService): Returns: An asynchronous generator of frames. """ - logger.debug(f"{self}: Generating WebSocket TTS [{text}]") + logger.debug(f"{self}: Generating WebSocket TTS [{text}, for context: {context_id}]") try: if not self._websocket or self._websocket.state is State.CLOSED: @@ -1174,6 +1152,7 @@ class InworldTTSService(WebsocketTTSService): try: if not self.audio_context_available(context_id): + self._reset_generation_timing() await self.create_audio_context(context_id) await self.start_ttfb_metrics() yield TTSStartedFrame(context_id=context_id) diff --git a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index a65a5e244..1cd68540d 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -1223,7 +1223,7 @@ class TTSService(AIService): logger.trace(f"{self} appending audio {frame} to audio context {context_id}") await self._audio_contexts[context_id].put(frame) else: - logger.warning(f"{self} unable to append audio to context {context_id}") + logger.debug(f"{self} unable to append audio to context {context_id}") async def remove_audio_context(self, context_id: str): """Remove an existing audio context.