Merge pull request #4167 from pipecat-ai/filipi/inworld_improvements
InworldTTSService improvements.
This commit is contained in:
1
changelog/4167.fixed.2.md
Normal file
1
changelog/4167.fixed.2.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed an issue in `InworldTTSService` where, after a fast interruption, the service would continue to receive audio from the previous context.
|
||||
1
changelog/4167.fixed.md
Normal file
1
changelog/4167.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed a word-timestamp interleaving issue in `InworldTTSService` when processing multiple sentences.
|
||||
@@ -646,7 +646,7 @@ class InworldTTSService(WebsocketTTSService):
|
||||
|
||||
super().__init__(
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
push_stop_frames=False,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
@@ -742,21 +742,10 @@ class InworldTTSService(WebsocketTTSService):
|
||||
logger.trace(f"Flushing audio for context {flush_id}")
|
||||
await self._send_flush(flush_id)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame and handle state changes.
|
||||
|
||||
Args:
|
||||
frame: The frame to push.
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
|
||||
logger.trace(
|
||||
f"{self}: Resetting timestamp tracking due to {type(frame).__name__} - "
|
||||
f"cumulative_time was {self._cumulative_time}"
|
||||
)
|
||||
self._cumulative_time = 0.0
|
||||
self._generation_end_time = 0.0
|
||||
def _reset_generation_timing(self):
|
||||
"""Reset the cumulative time and generation end time for a new generation."""
|
||||
self._cumulative_time = 0.0
|
||||
self._generation_end_time = 0.0
|
||||
|
||||
async def on_turn_context_created(self, context_id: str):
|
||||
"""Eagerly open the context on the server when a new turn starts.
|
||||
@@ -815,8 +804,6 @@ class InworldTTSService(WebsocketTTSService):
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
self._sent_context_ids.discard(context_id)
|
||||
self._cumulative_time = 0.0
|
||||
self._generation_end_time = 0.0
|
||||
|
||||
async def on_turn_context_completed(self):
|
||||
"""Close the server-side context at end of turn.
|
||||
@@ -834,10 +821,6 @@ class InworldTTSService(WebsocketTTSService):
|
||||
await self._close_context(context_id)
|
||||
await super().on_audio_context_interrupted(context_id)
|
||||
|
||||
async def on_audio_context_completed(self, context_id: str):
|
||||
"""Callback invoked when an audio context has been completed."""
|
||||
await self._close_context(context_id)
|
||||
|
||||
async def _maybe_push_fallback_text(self, context_id: str):
|
||||
"""Push the full text as fallback when no timestamps were received.
|
||||
|
||||
@@ -966,8 +949,7 @@ class InworldTTSService(WebsocketTTSService):
|
||||
await self.remove_active_audio_context()
|
||||
self._websocket = None
|
||||
self._sent_context_ids.clear()
|
||||
self._cumulative_time = 0.0
|
||||
self._generation_end_time = 0.0
|
||||
self._reset_generation_timing()
|
||||
self._context_texts.clear()
|
||||
self._contexts_with_timestamps.clear()
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -1015,10 +997,6 @@ class InworldTTSService(WebsocketTTSService):
|
||||
# Handle context created confirmation
|
||||
if "contextCreated" in result:
|
||||
logger.trace(f"{self}: Context created on server: {ctx_id}")
|
||||
# If the context isn't available, recreate it (handles race conditions during interruption recovery).
|
||||
elif ctx_id and not self.audio_context_available(ctx_id):
|
||||
logger.trace(f"{self}: Recreating audio context for current context: {ctx_id}")
|
||||
await self.create_audio_context(ctx_id)
|
||||
|
||||
# Process audio chunk
|
||||
audio_chunk = result.get("audioChunk", {})
|
||||
@@ -1053,7 +1031,7 @@ class InworldTTSService(WebsocketTTSService):
|
||||
|
||||
# Handle context closed - context no longer exists on server
|
||||
if "contextClosed" in result:
|
||||
logger.trace(f"{self}: Context closed on server: {ctx_id}")
|
||||
logger.debug(f"{self}: Context closed on server: {ctx_id}")
|
||||
await self._maybe_push_fallback_text(ctx_id)
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.append_to_audio_context(ctx_id, TTSStoppedFrame(context_id=ctx_id))
|
||||
@@ -1166,7 +1144,7 @@ class InworldTTSService(WebsocketTTSService):
|
||||
Returns:
|
||||
An asynchronous generator of frames.
|
||||
"""
|
||||
logger.debug(f"{self}: Generating WebSocket TTS [{text}]")
|
||||
logger.debug(f"{self}: Generating WebSocket TTS [{text}, for context: {context_id}]")
|
||||
|
||||
try:
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
@@ -1174,6 +1152,7 @@ class InworldTTSService(WebsocketTTSService):
|
||||
|
||||
try:
|
||||
if not self.audio_context_available(context_id):
|
||||
self._reset_generation_timing()
|
||||
await self.create_audio_context(context_id)
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
@@ -1223,7 +1223,7 @@ class TTSService(AIService):
|
||||
logger.trace(f"{self} appending audio {frame} to audio context {context_id}")
|
||||
await self._audio_contexts[context_id].put(frame)
|
||||
else:
|
||||
logger.warning(f"{self} unable to append audio to context {context_id}")
|
||||
logger.debug(f"{self} unable to append audio to context {context_id}")
|
||||
|
||||
async def remove_audio_context(self, context_id: str):
|
||||
"""Remove an existing audio context.
|
||||
|
||||
Reference in New Issue
Block a user