When push_text_frames is enabled, route LLMFullResponseEndFrame through the serialization queue instead of pushing it directly downstream. This ensures the frame is emitted only after the audio context has been fully drained, preserving its ordering relative to TTSTextFrames. Previously, the TTSTextFrame for the final sentence would arrive at the LLMAssistantAggregator after LLMFullResponseEndFrame and be dropped from the conversation context — most visibly with RTVI text input, where no subsequent interruption flushes the orphaned text.
411 lines
14 KiB
Python
411 lines
14 KiB
Python
#
|
|
# Copyright (c) 2024-2026, Daily
|
|
#
|
|
# SPDX-License-Identifier: BSD 2-Clause License
|
|
#
|
|
|
|
"""Tests for frame ordering across TTS service types.
|
|
|
|
Covers three patterns:
|
|
- HTTP TTS services (e.g. CartesiaHttpTTSService): yield audio frames synchronously.
|
|
- WebSocket TTS services without pause (e.g. CartesiaTTSService): deliver audio via
|
|
append_to_audio_context from a background receive loop, no frame-processing pause.
|
|
- WebSocket TTS services with pause (e.g. ElevenLabsTTSService): same delivery
|
|
mechanism, but pause downstream frame processing while audio is in flight.
|
|
|
|
For all three patterns we verify:
|
|
AggregatedTextFrame → TTSStartedFrame → TTSAudioRawFrame (1+) → TTSStoppedFrame → FooFrame
|
|
|
|
repeated for each TTSSpeakFrame, with no cross-group contamination.
|
|
|
|
Also covers LLM response flow with push_text_frames=True (non-word-timestamp TTS):
|
|
verifies TTSTextFrame ordering relative to LLMFullResponseEndFrame.
|
|
"""
|
|
|
|
import asyncio
|
|
import unittest
|
|
from dataclasses import dataclass
|
|
from typing import AsyncGenerator, List, Sequence, Tuple
|
|
|
|
import pytest
|
|
|
|
from pipecat.frames.frames import (
|
|
AggregatedTextFrame,
|
|
DataFrame,
|
|
Frame,
|
|
LLMFullResponseEndFrame,
|
|
LLMFullResponseStartFrame,
|
|
TextFrame,
|
|
TTSAudioRawFrame,
|
|
TTSSpeakFrame,
|
|
TTSStartedFrame,
|
|
TTSStoppedFrame,
|
|
TTSTextFrame,
|
|
)
|
|
from pipecat.services.tts_service import TTSService
|
|
from pipecat.tests.utils import run_test
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Test-only constants and marker frame
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Dummy audio payload: 640 bytes (320 samples, assuming 16-bit PCM — i.e. 20 ms
# at _SAMPLE_RATE). The bytes are non-zero, so this is NOT silence; the ordering
# tests only need audio frames to exist, not to contain anything meaningful.
_FAKE_AUDIO = b"\x00\x01" * 320
_SAMPLE_RATE = 16000
|
|
|
|
|
|
@dataclass
class FooFrame(DataFrame):
    """Test-only data frame that acts as an ordering marker.

    The ordering tests send one right after each TTSSpeakFrame and then check
    that it shows up downstream only after that utterance's TTS frames.
    """

    # Identifies which TTSSpeakFrame group this marker closes.
    label: str = ""
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Mock TTS services
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class MockHttpTTSService(TTSService):
    """Stand-in for an HTTP TTS service such as CartesiaHttpTTSService.

    run_tts() yields its single audio frame directly, so the audio context is
    already populated before the next downstream frame is handled. The
    TTSStoppedFrame is not produced here: the base class appends it in
    on_turn_context_completed() once it sees _is_yielding_frames_synchronously
    is True.
    """

    def __init__(self, **kwargs):
        base_options = dict(
            push_start_frame=True,
            push_stop_frames=True,
            push_text_frames=False,
            sample_rate=_SAMPLE_RATE,
        )
        super().__init__(**base_options, **kwargs)

    def can_generate_metrics(self) -> bool:
        # The ordering tests do not exercise metrics.
        return False

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        # `text` is ignored: every request produces the same fixed audio chunk.
        audio_frame = TTSAudioRawFrame(
            audio=_FAKE_AUDIO,
            sample_rate=_SAMPLE_RATE,
            num_channels=1,
            context_id=context_id,
        )
        yield audio_frame
|
|
|
|
|
|
class MockHttpPushTextTTSService(TTSService):
    """HTTP-style TTS stand-in that has the base class emit TTSTextFrames.

    It is identical in delivery to an HTTP service (audio is yielded directly
    from run_tts()) but is built with push_text_frames=True. That lets a test
    verify that LLMFullResponseEndFrame is emitted only after every TTSTextFrame
    when the TTS service itself generates the text frames (no word timestamps).
    """

    def __init__(self, **kwargs):
        base_options = dict(
            push_start_frame=True,
            push_stop_frames=True,
            push_text_frames=True,
            sample_rate=_SAMPLE_RATE,
        )
        super().__init__(**base_options, **kwargs)

    def can_generate_metrics(self) -> bool:
        # The ordering tests do not exercise metrics.
        return False

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        # One fixed audio chunk per request, regardless of `text`.
        audio_frame = TTSAudioRawFrame(
            audio=_FAKE_AUDIO,
            sample_rate=_SAMPLE_RATE,
            num_channels=1,
            context_id=context_id,
        )
        yield audio_frame
|
|
|
|
|
|
class MockWebSocketTTSService(TTSService):
    """Simulates a WebSocket TTS service without frame-processing pause (e.g. CartesiaTTSService).

    run_tts() yields nothing, which signals asynchronous delivery. Instead it
    spawns a background task that, after a short delay, appends one audio frame
    and a TTSStoppedFrame to the audio context and then removes the context —
    imitating a real WebSocket receive loop. The service is constructed with
    pause_frame_processing=False, so frame processing is not paused while that
    audio is in flight.
    """

    def __init__(self, **kwargs):
        base_options = dict(
            push_start_frame=True,
            push_text_frames=False,
            pause_frame_processing=False,
            sample_rate=_SAMPLE_RATE,
        )
        super().__init__(**base_options, **kwargs)

    def can_generate_metrics(self) -> bool:
        # The ordering tests do not exercise metrics.
        return False

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        async def _simulate_receive_loop():
            # Short delay stands in for network latency before audio "arrives".
            await asyncio.sleep(0.01)
            audio_frame = TTSAudioRawFrame(
                audio=_FAKE_AUDIO,
                sample_rate=_SAMPLE_RATE,
                num_channels=1,
                context_id=context_id,
            )
            await self.append_to_audio_context(context_id, audio_frame)
            await self.append_to_audio_context(context_id, TTSStoppedFrame(context_id=context_id))
            await self.remove_audio_context(context_id)

        self.create_task(_simulate_receive_loop(), name=f"mock_ws_deliver_{context_id}")
        # The unreachable `yield` after `return` keeps this an async generator
        # that finishes without producing any frames.
        return
        yield
|
|
|
|
|
|
class MockWebSocketPauseTTSService(TTSService):
    """Simulates a WebSocket TTS service WITH frame-processing pause (e.g. ElevenLabsTTSService).

    Audio is delivered the same way as in MockWebSocketTTSService (a background
    task appends audio + TTSStoppedFrame to the audio context), but the service
    is built with pause_frame_processing=True. When an audio context has been
    fully pushed, on_audio_context_completed() resumes downstream processing,
    which is what guarantees FooFrame lands after TTSStoppedFrame.
    """

    def __init__(self, **kwargs):
        base_options = dict(
            push_start_frame=True,
            push_text_frames=False,
            pause_frame_processing=True,
            sample_rate=_SAMPLE_RATE,
        )
        super().__init__(**base_options, **kwargs)

    def can_generate_metrics(self) -> bool:
        # The ordering tests do not exercise metrics.
        return False

    async def on_audio_context_completed(self, context_id: str):
        # The audio context is fully played out: resume frame processing.
        await self._maybe_resume_frame_processing()

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        async def _simulate_receive_loop():
            # Short delay stands in for network latency before audio "arrives".
            await asyncio.sleep(0.01)
            audio_frame = TTSAudioRawFrame(
                audio=_FAKE_AUDIO,
                sample_rate=_SAMPLE_RATE,
                num_channels=1,
                context_id=context_id,
            )
            await self.append_to_audio_context(context_id, audio_frame)
            await self.append_to_audio_context(context_id, TTSStoppedFrame(context_id=context_id))
            await self.remove_audio_context(context_id)

        self.create_task(_simulate_receive_loop(), name=f"mock_ws_pause_deliver_{context_id}")
        # The unreachable `yield` after `return` keeps this an async generator
        # that finishes without producing any frames.
        return
        yield
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Assertion helper
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _assert_group_ordering(
    down_frames: Sequence[Frame],
    expected_groups: List[Tuple[str, str]],
) -> None:
    """Check that each TTSSpeakFrame's output forms a self-contained, ordered group.

    Only AggregatedTextFrame, TTSStartedFrame, TTSAudioRawFrame, TTSStoppedFrame
    and FooFrame are considered. The filtered stream is cut after every FooFrame,
    and each resulting group must:

    * contain AggregatedTextFrame, TTSStartedFrame, TTSAudioRawFrame and
      TTSStoppedFrame;
    * have TTSStartedFrame before TTSStoppedFrame, and TTSStoppedFrame before
      the FooFrame;
    * hold only TTSAudioRawFrames between TTSStartedFrame and TTSStoppedFrame;
    * end with a FooFrame whose label matches the expected one.

    Frames after the last FooFrame are not examined.

    Args:
        down_frames: Every downstream frame captured by the test sink.
        expected_groups: One (tts_text, foo_label) pair per TTSSpeakFrame, in
            send order. Only foo_label is asserted on; tts_text is there for
            readability.

    Raises:
        AssertionError: If the FooFrame count, or any group's contents or
            order, differs from what is expected.
    """
    tracked = (AggregatedTextFrame, TTSStartedFrame, TTSAudioRawFrame, TTSStoppedFrame, FooFrame)
    relevant = [frame for frame in down_frames if isinstance(frame, tracked)]

    # Split into groups, each closed by (and including) a FooFrame. Anything
    # after the final FooFrame is discarded.
    groups: List[List[Frame]] = []
    current: List[Frame] = []
    for frame in relevant:
        current.append(frame)
        if isinstance(frame, FooFrame):
            groups.append(current)
            current = []

    assert len(groups) == len(expected_groups), (
        f"Expected {len(expected_groups)} FooFrames, got {len(groups)}.\n"
        f"Relevant frames: {[type(f).__name__ for f in relevant]}"
    )

    required_types = (AggregatedTextFrame, TTSStartedFrame, TTSAudioRawFrame, TTSStoppedFrame)

    for (_, foo_label), group in zip(expected_groups, groups):
        types = [type(frame) for frame in group]
        type_names = [t.__name__ for t in types]

        for required in required_types:
            assert required in types, (
                f"Group {foo_label!r}: missing {required.__name__}. Got: {type_names}"
            )

        started_idx = types.index(TTSStartedFrame)
        stopped_idx = types.index(TTSStoppedFrame)
        foo_idx = types.index(FooFrame)

        assert started_idx < stopped_idx, (
            f"Group {foo_label!r}: TTSStartedFrame (pos {started_idx}) must precede "
            f"TTSStoppedFrame (pos {stopped_idx}). Got: {type_names}"
        )
        assert stopped_idx < foo_idx, (
            f"Group {foo_label!r}: TTSStoppedFrame (pos {stopped_idx}) must precede "
            f"FooFrame (pos {foo_idx}). Got: {type_names}"
        )

        # Only audio may sit between TTSStartedFrame and TTSStoppedFrame; report
        # the first offender, if any.
        intruder = next(
            (t for t in types[started_idx + 1 : stopped_idx] if t is not TTSAudioRawFrame),
            None,
        )
        assert intruder is None, (
            f"Group {foo_label!r}: unexpected frame {intruder.__name__!r} between "
            f"TTSStartedFrame and TTSStoppedFrame. Got: {type_names}"
        )

        actual_label = group[foo_idx].label
        assert actual_label == foo_label, (
            f"Expected FooFrame(label={foo_label!r}), got label={actual_label!r}"
        )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# (tts_text, foo_label) for each TTSSpeakFrame/FooFrame pair in the ordering tests.
_GROUPS = [
    ("test 1", "1"),
    ("test 2", "2"),
]
|
|
|
|
|
|
def _make_frames_no_sleep(groups: Sequence[Tuple[str, str]] = _GROUPS) -> List[Frame]:
    """Return one TTSSpeakFrame + FooFrame pair per group, sent back-to-back.

    There is no delay between pairs. Each ordering test that uses these frames
    therefore relies on the TTS service under test to keep FooFrame behind the
    preceding TTSSpeakFrame's TTSStoppedFrame (see _assert_group_ordering).

    Args:
        groups: (tts_text, foo_label) pairs, in send order. Defaults to _GROUPS
            so the frames sent always match the groups the tests assert on.

    Returns:
        [TTSSpeakFrame(text_1), FooFrame(label_1), TTSSpeakFrame(text_2), ...]
    """
    frames: List[Frame] = []
    for text, label in groups:
        frames.append(TTSSpeakFrame(text=text, append_to_context=False))
        frames.append(FooFrame(label=label))
    return frames
|
|
|
|
|
|
def _print_frames_received(frames_received) -> None:
    """Print the name of every downstream frame from a run_test() result.

    Args:
        frames_received: The value returned by run_test(); element 0 holds the
            downstream frames.
    """
    print("FRAMES RECEIVED:")
    names = (downstream_frame.name for downstream_frame in frames_received[0])
    for name in names:
        print(name)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_http_tts_frame_ordering():
    """Ordering check for an HTTP TTS service, whose audio is yielded synchronously."""
    received = await run_test(MockHttpTTSService(), frames_to_send=_make_frames_no_sleep())

    # Debugging aid only.
    _print_frames_received(received)

    _assert_group_ordering(received[0], _GROUPS)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_websocket_tts_no_pause_frame_ordering():
    """Ordering check for a WebSocket TTS service built with pause_frame_processing=False."""
    service = MockWebSocketTTSService()
    outgoing = _make_frames_no_sleep()
    received = await run_test(service, frames_to_send=outgoing)
    downstream = received[0]
    _assert_group_ordering(downstream, _GROUPS)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_websocket_tts_with_pause_frame_ordering():
    """Ordering check for a WebSocket TTS service built with pause_frame_processing=True."""
    service = MockWebSocketPauseTTSService()
    outgoing = _make_frames_no_sleep()
    received = await run_test(service, frames_to_send=outgoing)
    downstream = received[0]
    _assert_group_ordering(downstream, _GROUPS)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_http_push_text_llm_response_end_after_tts_text():
    """LLMFullResponseEndFrame must arrive after all TTSTextFrames.

    Simulates an LLM response of two sentences going through an HTTP TTS
    service with push_text_frames=True:

    * "Hello there. " ends with a period, so the sentence aggregator is
      expected to flush it on its own.
    * "How are you?" is the last text before LLMFullResponseEndFrame, so it is
      expected to be flushed by the LLMFullResponseEndFrame itself. Its
      TTSTextFrame is the one that previously arrived after the end frame.

    Expected downstream ordering:
        LLMFullResponseStartFrame
        ... TTSTextFrame (per sentence) ...
        LLMFullResponseEndFrame  <- must come AFTER all TTSTextFrames
    """
    tts = MockHttpPushTextTTSService()

    frames_to_send = [
        LLMFullResponseStartFrame(),
        TextFrame(text="Hello there. "),
        TextFrame(text="How are you?"),
        LLMFullResponseEndFrame(),
    ]
    frames_received = await run_test(tts, frames_to_send=frames_to_send)
    down = frames_received[0]

    # Keep only the frame types whose relative order we care about.
    relevant = [
        f
        for f in down
        if isinstance(f, (LLMFullResponseStartFrame, TTSTextFrame, LLMFullResponseEndFrame))
    ]
    type_names = [type(f).__name__ for f in relevant]

    # Positions within `relevant`, found by enumeration so the lookup does not
    # depend on Frame equality semantics (and avoids repeated list.index scans).
    start_positions = [
        i for i, f in enumerate(relevant) if isinstance(f, LLMFullResponseStartFrame)
    ]
    tts_text_positions = [i for i, f in enumerate(relevant) if isinstance(f, TTSTextFrame)]
    end_positions = [i for i, f in enumerate(relevant) if isinstance(f, LLMFullResponseEndFrame)]

    # Exactly one start frame, two TTSTextFrames (one per sentence), one end frame.
    assert len(start_positions) == 1, (
        f"Expected 1 LLMFullResponseStartFrame, got {len(start_positions)}: {type_names}"
    )
    assert len(tts_text_positions) == 2, (
        f"Expected 2 TTSTextFrames, got {len(tts_text_positions)}: {type_names}"
    )
    assert len(end_positions) == 1, (
        f"Expected 1 LLMFullResponseEndFrame, got {len(end_positions)}: {type_names}"
    )

    # The critical check: LLMFullResponseEndFrame must come after ALL TTSTextFrames.
    end_idx = end_positions[0]
    last_tts_text_idx = tts_text_positions[-1]

    assert last_tts_text_idx < end_idx, (
        f"LLMFullResponseEndFrame (pos {end_idx}) must come after the last "
        f"TTSTextFrame (pos {last_tts_text_idx}). Got: {type_names}"
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # The tests in this module are plain pytest coroutine functions (there are
    # no unittest.TestCase subclasses), so unittest.main() would discover and
    # run nothing. Delegate to pytest so running the file directly executes them.
    raise SystemExit(pytest.main([__file__]))
|