Merge pull request #4167 from pipecat-ai/filipi/inworld_improvements

InworldTTSService improvements.
This commit is contained in:
Filipi da Silva Fuchter
2026-03-27 11:15:14 -04:00
committed by GitHub
4 changed files with 12 additions and 31 deletions

View File

@@ -0,0 +1 @@
- Fixed an issue in `InworldTTSService` where, after a fast interruption, the service would continue receiving audio from the previous context.

1
changelog/4167.fixed.md Normal file
View File

@@ -0,0 +1 @@
- Fixed a word-timestamp interleaving issue in `InworldTTSService` when processing multiple sentences.

View File

@@ -646,7 +646,7 @@ class InworldTTSService(WebsocketTTSService):
super().__init__(
push_text_frames=False,
push_stop_frames=True,
push_stop_frames=False,
pause_frame_processing=True,
sample_rate=sample_rate,
aggregate_sentences=aggregate_sentences,
@@ -742,21 +742,10 @@ class InworldTTSService(WebsocketTTSService):
logger.trace(f"Flushing audio for context {flush_id}")
await self._send_flush(flush_id)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Push a frame and handle state changes.
Args:
frame: The frame to push.
direction: The direction to push the frame.
"""
await super().push_frame(frame, direction)
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
logger.trace(
f"{self}: Resetting timestamp tracking due to {type(frame).__name__} - "
f"cumulative_time was {self._cumulative_time}"
)
self._cumulative_time = 0.0
self._generation_end_time = 0.0
def _reset_generation_timing(self):
"""Reset the cumulative time and generation end time for a new generation."""
self._cumulative_time = 0.0
self._generation_end_time = 0.0
async def on_turn_context_created(self, context_id: str):
"""Eagerly open the context on the server when a new turn starts.
@@ -815,8 +804,6 @@ class InworldTTSService(WebsocketTTSService):
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._sent_context_ids.discard(context_id)
self._cumulative_time = 0.0
self._generation_end_time = 0.0
async def on_turn_context_completed(self):
"""Close the server-side context at end of turn.
@@ -834,10 +821,6 @@ class InworldTTSService(WebsocketTTSService):
await self._close_context(context_id)
await super().on_audio_context_interrupted(context_id)
async def on_audio_context_completed(self, context_id: str):
"""Callback invoked when an audio context has been completed."""
await self._close_context(context_id)
async def _maybe_push_fallback_text(self, context_id: str):
"""Push the full text as fallback when no timestamps were received.
@@ -966,8 +949,7 @@ class InworldTTSService(WebsocketTTSService):
await self.remove_active_audio_context()
self._websocket = None
self._sent_context_ids.clear()
self._cumulative_time = 0.0
self._generation_end_time = 0.0
self._reset_generation_timing()
self._context_texts.clear()
self._contexts_with_timestamps.clear()
await self._call_event_handler("on_disconnected")
@@ -1015,10 +997,6 @@ class InworldTTSService(WebsocketTTSService):
# Handle context created confirmation
if "contextCreated" in result:
logger.trace(f"{self}: Context created on server: {ctx_id}")
# If the context isn't available recreate it (handles race conditions during interruption recovery).
elif ctx_id and not self.audio_context_available(ctx_id):
logger.trace(f"{self}: Recreating audio context for current context: {ctx_id}")
await self.create_audio_context(ctx_id)
# Process audio chunk
audio_chunk = result.get("audioChunk", {})
@@ -1053,7 +1031,7 @@ class InworldTTSService(WebsocketTTSService):
# Handle context closed - context no longer exists on server
if "contextClosed" in result:
logger.trace(f"{self}: Context closed on server: {ctx_id}")
logger.debug(f"{self}: Context closed on server: {ctx_id}")
await self._maybe_push_fallback_text(ctx_id)
await self.stop_ttfb_metrics()
await self.append_to_audio_context(ctx_id, TTSStoppedFrame(context_id=ctx_id))
@@ -1166,7 +1144,7 @@ class InworldTTSService(WebsocketTTSService):
Returns:
An asynchronous generator of frames.
"""
logger.debug(f"{self}: Generating WebSocket TTS [{text}]")
logger.debug(f"{self}: Generating WebSocket TTS [{text}, for context: {context_id}]")
try:
if not self._websocket or self._websocket.state is State.CLOSED:
@@ -1174,6 +1152,7 @@ class InworldTTSService(WebsocketTTSService):
try:
if not self.audio_context_available(context_id):
self._reset_generation_timing()
await self.create_audio_context(context_id)
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)

View File

@@ -1223,7 +1223,7 @@ class TTSService(AIService):
logger.trace(f"{self} appending audio {frame} to audio context {context_id}")
await self._audio_contexts[context_id].put(frame)
else:
logger.warning(f"{self} unable to append audio to context {context_id}")
logger.debug(f"{self} unable to append audio to context {context_id}")
async def remove_audio_context(self, context_id: str):
"""Remove an existing audio context.