Merge pull request #4127 from pipecat-ai/mb/tts-text-frame-ordering
Fix LLMFullResponseEndFrame racing ahead of final TTSTextFrame
This commit is contained in:
1
changelog/4127.fixed.md
Normal file
1
changelog/4127.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed the final sentence being dropped from the conversation context when using RTVI text input with non-word-timestamp TTS services. The `LLMFullResponseEndFrame` was racing ahead of the last `TTSTextFrame`, causing the `LLMAssistantAggregator` to finalize the context before the final sentence arrived.
|
||||
@@ -750,7 +750,11 @@ class TTSService(AIService):
|
||||
self._processing_text = False
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._push_text_frames:
|
||||
await self.push_frame(frame, direction)
|
||||
# Route through the serialization queue so the frame is
|
||||
# emitted only after the audio context has been fully
|
||||
# drained (including the final TTSTextFrame). Pushing
|
||||
# directly would let it race ahead of queued text frames.
|
||||
await self._serialization_queue.put(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -17,6 +17,9 @@ For all three patterns we verify:
|
||||
AggregatedTextFrame → TTSStartedFrame → TTSAudioRawFrame (1+) → TTSStoppedFrame → FooFrame
|
||||
|
||||
repeated for each TTSSpeakFrame, with no cross-group contamination.
|
||||
|
||||
Also covers LLM response flow with push_text_frames=True (non-word-timestamp TTS):
|
||||
verifies TTSTextFrame ordering relative to LLMFullResponseEndFrame.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -30,10 +33,14 @@ from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
DataFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TextFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TTSTextFrame,
|
||||
)
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.tests.utils import run_test
|
||||
@@ -88,6 +95,34 @@ class MockHttpTTSService(TTSService):
|
||||
)
|
||||
|
||||
|
||||
class MockHttpPushTextTTSService(TTSService):
    """HTTP-style mock TTS service constructed with ``push_text_frames=True``.

    Lets tests check that LLMFullResponseEndFrame is emitted only after every
    TTSTextFrame when the TTS service generates the text frames itself
    (non-word-timestamp mode).
    """

    def __init__(self, **kwargs):
        """Set up the base service with start, stop and text frame pushing on.

        Args:
            **kwargs: Extra keyword arguments forwarded to ``TTSService``.
        """
        super().__init__(
            sample_rate=_SAMPLE_RATE,
            push_text_frames=True,
            push_stop_frames=True,
            push_start_frame=True,
            **kwargs,
        )

    def can_generate_metrics(self) -> bool:
        """Report that this mock does not generate metrics."""
        return False

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Yield a single fake mono audio frame tagged with ``context_id``.

        Args:
            text: Text to synthesize; not used by this mock.
            context_id: Identifier attached to the yielded audio frame.
        """
        fake_chunk = TTSAudioRawFrame(
            audio=_FAKE_AUDIO,
            sample_rate=_SAMPLE_RATE,
            num_channels=1,
            context_id=context_id,
        )
        yield fake_chunk
|
||||
|
||||
|
||||
class MockWebSocketTTSService(TTSService):
|
||||
"""Simulates a WebSocket TTS service without frame-processing pause (e.g. CartesiaTTSService).
|
||||
|
||||
@@ -311,5 +346,65 @@ async def test_websocket_tts_with_pause_frame_ordering():
|
||||
_assert_group_ordering(frames_received[0], _GROUPS)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_http_push_text_llm_response_end_after_tts_text():
    """Check that LLMFullResponseEndFrame is emitted after every TTSTextFrame.

    An LLM response made of two sentences is fed through an HTTP-style TTS
    mock created with push_text_frames=True. The first sentence ends with a
    period so the sentence aggregator flushes it; the last one has no trailing
    period and is flushed by the LLMFullResponseEndFrame itself.

    Expected downstream ordering:
        LLMFullResponseStartFrame
        ... TTSTextFrame (one per sentence) ...
        LLMFullResponseEndFrame   (strictly after all TTSTextFrames)
    """
    service = MockHttpPushTextTTSService()

    # "Hello there. " is flushed by its period; "How are you?" is only flushed
    # once the LLMFullResponseEndFrame arrives.
    llm_response = [
        LLMFullResponseStartFrame(),
        TextFrame(text="Hello there. "),
        TextFrame(text="How are you?"),
        LLMFullResponseEndFrame(),
    ]
    received = await run_test(service, frames_to_send=llm_response)
    downstream = received[0]

    # Keep only the frame types whose relative order is being verified.
    watched_types = (LLMFullResponseStartFrame, TTSTextFrame, LLMFullResponseEndFrame)
    relevant = [frame for frame in downstream if isinstance(frame, watched_types)]
    type_names = [type(frame).__name__ for frame in relevant]

    def _of_type(frame_type):
        """Return the relevant frames that are instances of ``frame_type``."""
        return [frame for frame in relevant if isinstance(frame, frame_type)]

    # Exactly one start frame, two TTS text frames and one end frame expected.
    text_frames = _of_type(TTSTextFrame)
    response_ends = _of_type(LLMFullResponseEndFrame)
    response_starts = _of_type(LLMFullResponseStartFrame)

    assert len(response_starts) == 1, (
        f"Expected 1 LLMFullResponseStartFrame, got {len(response_starts)}: {type_names}"
    )
    assert len(text_frames) == 2, (
        f"Expected 2 TTSTextFrames, got {len(text_frames)}: {type_names}"
    )
    assert len(response_ends) == 1, (
        f"Expected 1 LLMFullResponseEndFrame, got {len(response_ends)}: {type_names}"
    )

    # Core ordering check: the end frame sits after the last TTSTextFrame.
    end_pos = relevant.index(response_ends[0])
    last_text_pos = max(map(relevant.index, text_frames))

    assert last_text_pos < end_pos, (
        f"LLMFullResponseEndFrame (pos {end_pos}) must come after the last "
        f"TTSTextFrame (pos {last_text_pos}). Got: {type_names}"
    )
|
||||
|
||||
|
||||
# Script entry point: when this module is executed directly, hand control to
# unittest's command-line runner.
# NOTE(review): the tests visible in this file are pytest-style coroutines
# decorated with pytest.mark.asyncio -- confirm unittest.main() discovers them.
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user