Automated via ruff UP006, UP007, UP035, UP045 rules (target: py311): - Replace `typing.List`, `Dict`, `Tuple`, `Set`, `FrozenSet`, `Type` with their built-in equivalents (`list`, `dict`, `tuple`, etc.) - Replace `typing.Optional[X]` with `X | None` - Replace `typing.Union[X, Y]` with `X | Y` - Move `Mapping`, `Sequence`, `Callable`, `Awaitable`, `MutableMapping`, `MutableSequence`, `Iterator`, `AsyncIterator`, `AsyncGenerator` imports from `typing` to `collections.abc` - Remove now-unused `typing` imports - Add `from __future__ import annotations` to 5 files that use forward-reference strings in `X | "Y"` annotations
411 lines
14 KiB
Python
411 lines
14 KiB
Python
#
|
|
# Copyright (c) 2024-2026, Daily
|
|
#
|
|
# SPDX-License-Identifier: BSD 2-Clause License
|
|
#
|
|
|
|
"""Tests for frame ordering across TTS service types.
|
|
|
|
Covers three patterns:
|
|
- HTTP TTS services (e.g. CartesiaHttpTTSService): yield audio frames synchronously.
|
|
- WebSocket TTS services without pause (e.g. CartesiaTTSService): deliver audio via
|
|
append_to_audio_context from a background receive loop, no frame-processing pause.
|
|
- WebSocket TTS services with pause (e.g. ElevenLabsTTSService): same delivery
|
|
mechanism, but pause downstream frame processing while audio is in flight.
|
|
|
|
For all three patterns we verify:
|
|
AggregatedTextFrame → TTSStartedFrame → TTSAudioRawFrame (1+) → TTSStoppedFrame → FooFrame
|
|
|
|
repeated for each TTSSpeakFrame, with no cross-group contamination.
|
|
|
|
Also covers LLM response flow with push_text_frames=True (non-word-timestamp TTS):
|
|
verifies TTSTextFrame ordering relative to LLMFullResponseEndFrame.
|
|
"""
|
|
|
|
import asyncio
|
|
import unittest
|
|
from collections.abc import AsyncGenerator, Sequence
|
|
from dataclasses import dataclass
|
|
|
|
import pytest
|
|
|
|
from pipecat.frames.frames import (
|
|
AggregatedTextFrame,
|
|
DataFrame,
|
|
Frame,
|
|
LLMFullResponseEndFrame,
|
|
LLMFullResponseStartFrame,
|
|
TextFrame,
|
|
TTSAudioRawFrame,
|
|
TTSSpeakFrame,
|
|
TTSStartedFrame,
|
|
TTSStoppedFrame,
|
|
TTSTextFrame,
|
|
)
|
|
from pipecat.services.tts_service import TTSService
|
|
from pipecat.tests.utils import run_test
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Test-only frame
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_FAKE_AUDIO = b"\x00\x01" * 320 # 320 bytes of silence
|
|
_SAMPLE_RATE = 16000
|
|
|
|
|
|
@dataclass
class FooFrame(DataFrame):
    """Sentinel data frame queued right after each ``TTSSpeakFrame`` in these tests.

    The ordering assertions check where it lands relative to the TTS frames of
    the utterance it follows.
    """

    # Identifies which utterance this marker follows; compared against the expected label.
    label: str = ""
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Mock TTS services
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class MockHttpTTSService(TTSService):
    """Stand-in for an HTTP-based TTS service such as CartesiaHttpTTSService.

    ``run_tts`` yields its one audio frame directly, so all audio for a request
    exists before the next downstream frame is handled. The closing
    TTSStoppedFrame comes from the base class, which appends it in
    ``on_turn_context_completed()`` once it detects that
    ``_is_yielding_frames_synchronously`` is True.
    """

    def __init__(self, **kwargs):
        base_options = {
            "push_start_frame": True,
            "push_stop_frames": True,
            "push_text_frames": False,
            "sample_rate": _SAMPLE_RATE,
        }
        super().__init__(**base_options, **kwargs)

    def can_generate_metrics(self) -> bool:
        """Metrics are not produced by this mock."""
        return False

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        # ``text`` is ignored: every request produces the same fake audio chunk.
        audio_frame = TTSAudioRawFrame(
            audio=_FAKE_AUDIO,
            sample_rate=_SAMPLE_RATE,
            num_channels=1,
            context_id=context_id,
        )
        yield audio_frame
|
|
|
|
|
|
class MockHttpPushTextTTSService(TTSService):
    """HTTP-style TTS mock configured with ``push_text_frames=True``.

    With this flag the TTS service produces the TTSTextFrames itself
    (non-word-timestamp mode). It is used to check that LLMFullResponseEndFrame
    is emitted only after every one of those TTSTextFrames.
    """

    def __init__(self, **kwargs):
        base_options = {
            "push_start_frame": True,
            "push_stop_frames": True,
            "push_text_frames": True,
            "sample_rate": _SAMPLE_RATE,
        }
        super().__init__(**base_options, **kwargs)

    def can_generate_metrics(self) -> bool:
        """Metrics are not produced by this mock."""
        return False

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        # One fixed audio chunk per request, independent of ``text``.
        audio_frame = TTSAudioRawFrame(
            audio=_FAKE_AUDIO,
            sample_rate=_SAMPLE_RATE,
            num_channels=1,
            context_id=context_id,
        )
        yield audio_frame
|
|
|
|
|
|
class MockWebSocketTTSService(TTSService):
    """WebSocket-style TTS mock that does NOT pause frame processing (cf. CartesiaTTSService).

    ``run_tts`` yields nothing, which signals asynchronous delivery. Instead it
    spawns a background task that, after a short sleep, appends an audio frame
    and a TTSStoppedFrame to the audio context. This mimics a WebSocket
    receive loop. Because ``pause_frame_processing=False``, downstream frames
    (FooFrame) are not held back by the service.
    """

    def __init__(self, **kwargs):
        base_options = {
            "push_start_frame": True,
            "push_text_frames": False,
            "pause_frame_processing": False,
            "sample_rate": _SAMPLE_RATE,
        }
        super().__init__(**base_options, **kwargs)

    def can_generate_metrics(self) -> bool:
        """Metrics are not produced by this mock."""
        return False

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        async def _stream_from_fake_socket():
            # Simulated latency before the "server" starts streaming.
            await asyncio.sleep(0.01)
            audio_frame = TTSAudioRawFrame(
                audio=_FAKE_AUDIO,
                sample_rate=_SAMPLE_RATE,
                num_channels=1,
                context_id=context_id,
            )
            await self.append_to_audio_context(context_id, audio_frame)
            await self.append_to_audio_context(context_id, TTSStoppedFrame(context_id=context_id))
            # Everything for this context has been queued; close it out.
            await self.remove_audio_context(context_id)

        self.create_task(_stream_from_fake_socket(), name=f"mock_ws_deliver_{context_id}")
        return
        yield  # unreachable, but its presence makes run_tts an async generator
|
|
|
|
|
|
class MockWebSocketPauseTTSService(TTSService):
    """WebSocket-style TTS mock that DOES pause frame processing (cf. ElevenLabsTTSService).

    Delivery works exactly like MockWebSocketTTSService: a background task
    appends the audio and the TTSStoppedFrame. The difference is
    ``pause_frame_processing=True``, plus an ``on_audio_context_completed()``
    override that resumes downstream processing once the whole audio context
    has been pushed. That is what keeps FooFrame behind TTSStoppedFrame.
    """

    def __init__(self, **kwargs):
        base_options = {
            "push_start_frame": True,
            "push_text_frames": False,
            "pause_frame_processing": True,
            "sample_rate": _SAMPLE_RATE,
        }
        super().__init__(**base_options, **kwargs)

    def can_generate_metrics(self) -> bool:
        """Metrics are not produced by this mock."""
        return False

    async def on_audio_context_completed(self, context_id: str):
        """Let held-back downstream frames flow again once the context's audio is fully out."""
        await self._maybe_resume_frame_processing()

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        async def _stream_from_fake_socket():
            # Simulated latency before the "server" starts streaming.
            await asyncio.sleep(0.01)
            audio_frame = TTSAudioRawFrame(
                audio=_FAKE_AUDIO,
                sample_rate=_SAMPLE_RATE,
                num_channels=1,
                context_id=context_id,
            )
            await self.append_to_audio_context(context_id, audio_frame)
            await self.append_to_audio_context(context_id, TTSStoppedFrame(context_id=context_id))
            # Everything for this context has been queued; close it out.
            await self.remove_audio_context(context_id)

        self.create_task(_stream_from_fake_socket(), name=f"mock_ws_pause_deliver_{context_id}")
        return
        yield  # unreachable, but its presence makes run_tts an async generator
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Assertion helper
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _assert_group_ordering(
    down_frames: Sequence[Frame],
    expected_groups: list[tuple[str, str]],
) -> None:
    """Assert that downstream frames form strictly ordered, self-contained TTS groups.

    Each TTSSpeakFrame + FooFrame pair sent upstream must yield, in order::

        AggregatedTextFrame, TTSStartedFrame, TTSAudioRawFrame (1+), TTSStoppedFrame, FooFrame

    Frames of other types are ignored. Groups are delimited by FooFrames, and
    each group must hold exactly one TTSStartedFrame and one TTSStoppedFrame.
    No tracked frame may trail the last FooFrame. Together these checks report
    TTS output that leaks into a neighbouring group or past the end.

    Args:
        down_frames: All downstream frames received by the test sink.
        expected_groups: ``(tts_text, foo_label)`` pairs, one per TTSSpeakFrame,
            in send order. ``tts_text`` is not asserted on; it is kept for readability.

    Raises:
        AssertionError: If a group is missing a frame, has a duplicate start/stop
            frame, is out of order, or carries the wrong FooFrame label. Also
            raised if tracked frames follow the last FooFrame.
    """
    tracked_types = (
        AggregatedTextFrame,
        TTSStartedFrame,
        TTSAudioRawFrame,
        TTSStoppedFrame,
        FooFrame,
    )
    relevant = [f for f in down_frames if isinstance(f, tracked_types)]
    relevant_names = [type(f).__name__ for f in relevant]

    # Locate the FooFrames that delimit groups.
    foo_indices = [i for i, f in enumerate(relevant) if isinstance(f, FooFrame)]
    assert len(foo_indices) == len(expected_groups), (
        f"Expected {len(expected_groups)} FooFrames, got {len(foo_indices)}.\n"
        f"Relevant frames: {relevant_names}"
    )

    # Tracked frames after the last FooFrame belong to no group (e.g. audio that
    # leaked out of the final utterance). The per-group checks below never see
    # them, so reject them explicitly.
    tail_start = foo_indices[-1] + 1 if foo_indices else 0
    trailing = relevant_names[tail_start:]
    assert not trailing, (
        f"Unexpected frames after the last FooFrame: {trailing}.\n"
        f"Relevant frames: {relevant_names}"
    )

    # Build groups: everything up to and including each FooFrame.
    groups: list[list[Frame]] = []
    prev = 0
    for idx in foo_indices:
        groups.append(relevant[prev : idx + 1])
        prev = idx + 1

    for group, (_, foo_label) in zip(groups, expected_groups):
        # Exact type comparisons: a subclass instance does not count as its base type.
        types = [type(f) for f in group]
        type_names = [t.__name__ for t in types]

        for required in (AggregatedTextFrame, TTSStartedFrame, TTSAudioRawFrame, TTSStoppedFrame):
            assert required in types, (
                f"Group {foo_label!r}: missing {required.__name__}. Got: {type_names}"
            )

        # One TTS turn per group. A second start/stop frame means output from a
        # neighbouring TTSSpeakFrame ended up on the wrong side of a FooFrame.
        for singleton in (TTSStartedFrame, TTSStoppedFrame):
            count = types.count(singleton)
            assert count == 1, (
                f"Group {foo_label!r}: expected exactly one {singleton.__name__}, "
                f"got {count}. Got: {type_names}"
            )

        started_idx = types.index(TTSStartedFrame)
        stopped_idx = types.index(TTSStoppedFrame)
        foo_idx = types.index(FooFrame)

        assert started_idx < stopped_idx, (
            f"Group {foo_label!r}: TTSStartedFrame (pos {started_idx}) must precede "
            f"TTSStoppedFrame (pos {stopped_idx}). Got: {type_names}"
        )
        assert stopped_idx < foo_idx, (
            f"Group {foo_label!r}: TTSStoppedFrame (pos {stopped_idx}) must precede "
            f"FooFrame (pos {foo_idx}). Got: {type_names}"
        )

        # All frames between TTSStartedFrame and TTSStoppedFrame must be audio.
        for t in types[started_idx + 1 : stopped_idx]:
            assert t is TTSAudioRawFrame, (
                f"Group {foo_label!r}: unexpected frame {t.__name__!r} between "
                f"TTSStartedFrame and TTSStoppedFrame. Got: {type_names}"
            )

        # Check the FooFrame label.
        actual_label = group[foo_idx].label
        assert actual_label == foo_label, (
            f"Expected FooFrame(label={foo_label!r}), got label={actual_label!r}"
        )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_GROUPS = [("test 1", "1"), ("test 2", "2")]
|
|
|
|
|
|
def _make_frames_no_sleep() -> list[Frame]:
    """Return one ``TTSSpeakFrame`` + ``FooFrame`` pair per entry in ``_GROUPS``.

    The pairs are sent back-to-back with no delay in between. Every ordering
    test that uses this helper (HTTP, WebSocket with and without pause)
    therefore relies on the TTS service itself, not on test timing, to keep
    each utterance's TTS frames ahead of its FooFrame. Building the frames
    from ``_GROUPS`` keeps them in sync with what ``_assert_group_ordering``
    is told to expect.
    """
    frames: list[Frame] = []
    for tts_text, foo_label in _GROUPS:
        frames.append(TTSSpeakFrame(text=tts_text, append_to_context=False))
        frames.append(FooFrame(label=foo_label))
    return frames
|
|
|
|
|
|
def _print_frames_received(frames_received) -> None:
|
|
print("FRAMES RECEIVED:")
|
|
for frame in frames_received[0]:
|
|
print(frame.name)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_http_tts_frame_ordering():
    """HTTP-style TTS: audio is yielded in-line from run_tts()."""
    result = await run_test(MockHttpTTSService(), frames_to_send=_make_frames_no_sleep())

    # Debug output only; with pytest's default capturing it is shown on failure.
    _print_frames_received(result)

    _assert_group_ordering(result[0], _GROUPS)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_websocket_tts_no_pause_frame_ordering():
    """WebSocket-style TTS whose audio arrives from a background task, without pausing."""
    result = await run_test(MockWebSocketTTSService(), frames_to_send=_make_frames_no_sleep())
    _assert_group_ordering(result[0], _GROUPS)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_websocket_tts_with_pause_frame_ordering():
    """WebSocket-style TTS that pauses downstream frames (pause_frame_processing=True)."""
    result = await run_test(MockWebSocketPauseTTSService(), frames_to_send=_make_frames_no_sleep())
    _assert_group_ordering(result[0], _GROUPS)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_http_push_text_llm_response_end_after_tts_text():
    """LLM response start/end frames must bracket every TTSTextFrame.

    Simulates a two-sentence LLM response flowing through an HTTP TTS service
    with ``push_text_frames=True``. Each sentence is a separate TextFrame:

    - The first ends with ". ", so the sentence aggregator can flush it on its own.
    - The second ends with "?" and is followed directly by
      LLMFullResponseEndFrame. That frame flushes whatever text the aggregator
      may still be holding.

    Expected downstream ordering::

        LLMFullResponseStartFrame
        ... TTSTextFrame (one per sentence) ...
        LLMFullResponseEndFrame   <- must come AFTER all TTSTextFrames
    """
    tts = MockHttpPushTextTTSService()

    # Two sentences: the first ends with ". " (lets the aggregator flush it early),
    # the second ends with "?" and is the last text before LLMFullResponseEndFrame.
    frames_to_send = [
        LLMFullResponseStartFrame(),
        TextFrame(text="Hello there. "),
        TextFrame(text="How are you?"),
        LLMFullResponseEndFrame(),
    ]
    frames_received = await run_test(tts, frames_to_send=frames_to_send)
    down = frames_received[0]

    # Collect relevant frame types for ordering check.
    relevant = [
        f
        for f in down
        if isinstance(f, (LLMFullResponseStartFrame, TTSTextFrame, LLMFullResponseEndFrame))
    ]
    type_names = [type(f).__name__ for f in relevant]

    def _positions(frame_type) -> list[int]:
        # Take indices positionally rather than via list.index(). index() matches
        # by ==, which for dataclass frames compares field values, and it costs
        # O(n) per lookup.
        return [i for i, f in enumerate(relevant) if isinstance(f, frame_type)]

    start_positions = _positions(LLMFullResponseStartFrame)
    tts_text_positions = _positions(TTSTextFrame)
    end_positions = _positions(LLMFullResponseEndFrame)

    # There should be exactly one LLMFullResponseStartFrame, 2 TTSTextFrames, 1 LLMFullResponseEndFrame.
    assert len(start_positions) == 1, (
        f"Expected 1 LLMFullResponseStartFrame, got {len(start_positions)}: {type_names}"
    )
    assert len(tts_text_positions) == 2, (
        f"Expected 2 TTSTextFrames, got {len(tts_text_positions)}: {type_names}"
    )
    assert len(end_positions) == 1, (
        f"Expected 1 LLMFullResponseEndFrame, got {len(end_positions)}: {type_names}"
    )

    start_idx = start_positions[0]
    end_idx = end_positions[0]
    first_tts_text_idx = tts_text_positions[0]
    last_tts_text_idx = tts_text_positions[-1]

    # The response must open before any spoken text is reported.
    assert start_idx < first_tts_text_idx, (
        f"LLMFullResponseStartFrame (pos {start_idx}) must come before the first "
        f"TTSTextFrame (pos {first_tts_text_idx}). Got: {type_names}"
    )

    # The critical check: LLMFullResponseEndFrame must come after ALL TTSTextFrames.
    assert last_tts_text_idx < end_idx, (
        f"LLMFullResponseEndFrame (pos {end_idx}) must come after the last "
        f"TTSTextFrame (pos {last_tts_text_idx}). Got: {type_names}"
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # The tests above are plain pytest coroutine functions, not unittest.TestCase
    # methods, so unittest.main() would collect nothing. Run this file through
    # pytest instead and propagate its exit status.
    raise SystemExit(pytest.main([__file__]))
|