[inworld] prewarm context on llm response start
This commit is contained in:
1
changelog/4013.changed.md
Normal file
1
changelog/4013.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added context prewarming path for `InworldTTSService` to improve first audio latency
|
||||
@@ -56,6 +56,7 @@ from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
@@ -653,6 +654,11 @@ class InworldTTSService(WebsocketTTSService):
|
||||
# Track the end time of the last word in the current generation
|
||||
self._generation_end_time = 0.0
|
||||
|
||||
# Context ID that was pre-opened on the server during process_frame
|
||||
# (LLMFullResponseStartFrame) to avoid context creation latency when
|
||||
# enough context for TTS is available.
|
||||
self._prewarmed_context_id: Optional[str] = None
|
||||
|
||||
# Init-only config (not runtime-updatable).
|
||||
self._audio_encoding = encoding
|
||||
self._audio_sample_rate = 0 # Set in start()
|
||||
@@ -726,6 +732,29 @@ class InworldTTSService(WebsocketTTSService):
|
||||
if isinstance(frame, TTSStoppedFrame):
|
||||
await self.add_word_timestamps([("Reset", 0)])
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process incoming frames and pre-open context on LLM response start.
|
||||
|
||||
Eagerly sends the context configuration to the server when
|
||||
LLMFullResponseStartFrame arrives, so the context is ready by the time
|
||||
enough context for TTS is available. The base class assigns ``_turn_context_id`` before
|
||||
this runs, which is reused for all ``run_tts`` calls within the turn.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, LLMFullResponseStartFrame):
|
||||
if self._prewarmed_context_id:
|
||||
try:
|
||||
await self._send_close_context(self._prewarmed_context_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"{self}: Failed to close previous prewarmed context: {e}")
|
||||
self._prewarmed_context_id = None
|
||||
try:
|
||||
await self._send_context(self._turn_context_id)
|
||||
self._prewarmed_context_id = self._turn_context_id
|
||||
except Exception as e:
|
||||
logger.warning(f"{self}: Failed to pre-open context: {e}")
|
||||
|
||||
def _calculate_word_times(self, timestamp_info: Dict[str, Any]) -> List[Tuple[str, float]]:
|
||||
"""Calculate word timestamps from Inworld WebSocket API response.
|
||||
|
||||
@@ -887,6 +916,7 @@ class InworldTTSService(WebsocketTTSService):
|
||||
finally:
|
||||
await self.remove_active_audio_context()
|
||||
self._websocket = None
|
||||
self._prewarmed_context_id = None
|
||||
self._cumulative_time = 0.0
|
||||
self._generation_end_time = 0.0
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -1001,9 +1031,16 @@ class InworldTTSService(WebsocketTTSService):
|
||||
async def _send_context(self, context_id: str):
|
||||
"""Send a context to the Inworld WebSocket TTS service.
|
||||
|
||||
Skips the send if this context was already pre-opened on the server
|
||||
(prewarmed during process_frame).
|
||||
|
||||
Args:
|
||||
context_id: The context ID.
|
||||
"""
|
||||
if context_id == self._prewarmed_context_id:
|
||||
self._prewarmed_context_id = None
|
||||
return
|
||||
|
||||
audio_config = {
|
||||
"audioEncoding": self._audio_encoding,
|
||||
"sampleRateHertz": self._audio_sample_rate,
|
||||
|
||||
Reference in New Issue
Block a user