Fixing TTS frame order.
This commit is contained in:
@@ -43,6 +43,7 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
StartFrame,
|
||||
SystemFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
@@ -557,9 +558,9 @@ class TTSService(AIService):
|
||||
"""
|
||||
await super().stop(frame)
|
||||
if self._audio_context_task:
|
||||
# Indicate no more audio contexts are available; this will end the
|
||||
# task cleanly after all contexts have been processed.
|
||||
await self._contexts_queue.put(None)
|
||||
# Sentinel None shuts down the serialization queue once all
|
||||
# pending contexts and frames have been processed.
|
||||
await self._serialization_queue.put(None)
|
||||
await self._audio_context_task
|
||||
self._audio_context_task = None
|
||||
if self._stop_frame_task:
|
||||
@@ -791,7 +792,15 @@ class TTSService(AIService):
|
||||
await self._maybe_resume_frame_processing()
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
if direction == FrameDirection.DOWNSTREAM and not isinstance(frame, SystemFrame):
|
||||
# Route non-system downstream frames through the serialization queue so they
|
||||
# are emitted in the same order they arrive relative to any audio contexts that
|
||||
# are already queued (e.g. a FooFrame sent right after a TTSSpeakFrame must
|
||||
# not overtake the TTSStartedFrame / TTSAudioRawFrame / TTSStoppedFrame
|
||||
# sequence from that speak frame).
|
||||
await self._serialization_queue.put(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame downstream with TTS-specific handling.
|
||||
@@ -994,7 +1003,10 @@ class TTSService(AIService):
|
||||
# is spoken, so we set append_to_context to False.
|
||||
src_frame.append_to_context = False
|
||||
src_frame.context_id = context_id
|
||||
await self.push_frame(src_frame)
|
||||
# Route AggregatedTextFrame through the serialization queue so it is emitted
|
||||
# immediately before the TTSStartedFrame of the audio context it describes,
|
||||
# rather than racing ahead of audio frames from a previous context.
|
||||
await self._serialization_queue.put(src_frame)
|
||||
|
||||
# Note: Text transformations are meant to only affect the text sent to the TTS for
|
||||
# TTS-specific purposes. This allows for explicit TTS modifications (e.g., inserting
|
||||
@@ -1203,7 +1215,7 @@ class TTSService(AIService):
|
||||
Args:
|
||||
context_id: Unique identifier for the audio context.
|
||||
"""
|
||||
await self._contexts_queue.put(context_id)
|
||||
await self._serialization_queue.put(context_id)
|
||||
self._audio_contexts[context_id] = asyncio.Queue()
|
||||
logger.trace(f"{self} created audio context {context_id}")
|
||||
|
||||
@@ -1295,7 +1307,14 @@ class TTSService(AIService):
|
||||
|
||||
def _create_audio_context_task(self):
|
||||
if not self._audio_context_task:
|
||||
self._contexts_queue: asyncio.Queue = asyncio.Queue()
|
||||
# Single FIFO queue that serializes everything the TTS service emits downstream.
|
||||
# Items can be:
|
||||
# str – an audio context ID: process the per-context audio queue in full before
|
||||
# moving on (see _handle_audio_context).
|
||||
# Frame – a non-system downstream frame (e.g. AggregatedTextFrame, FooFrame) that
|
||||
# must be emitted in-order relative to surrounding audio contexts.
|
||||
# None – shutdown sentinel (sent by stop()).
|
||||
self._serialization_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._audio_contexts: Dict[str, asyncio.Queue] = {}
|
||||
self._audio_context_task = self.create_task(self._audio_context_task_handler())
|
||||
|
||||
@@ -1305,13 +1324,26 @@ class TTSService(AIService):
|
||||
self._audio_context_task = None
|
||||
|
||||
async def _audio_context_task_handler(self):
|
||||
"""In this task we process audio contexts in order."""
|
||||
"""Drain the serialization queue, preserving downstream frame order.
|
||||
|
||||
The queue carries three kinds of items (see _create_audio_context_task):
|
||||
|
||||
* str – audio context ID: block until all audio for that context has been
|
||||
pushed downstream, then call on_audio_context_completed().
|
||||
* Frame – a non-system downstream frame that must be emitted at this exact
|
||||
position in the output stream (e.g. AggregatedTextFrame preceding
|
||||
its audio, or an arbitrary frame that arrived between two speak frames).
|
||||
* None – shutdown sentinel; exit the loop once reached.
|
||||
"""
|
||||
running = True
|
||||
while running:
|
||||
context_id = await self._contexts_queue.get()
|
||||
self._playing_context_id = context_id
|
||||
context_value = await self._serialization_queue.get()
|
||||
if isinstance(context_value, Frame):
|
||||
await self.push_frame(context_value)
|
||||
elif isinstance(context_value, str):
|
||||
context_id = context_value
|
||||
self._playing_context_id = context_id
|
||||
|
||||
if context_id:
|
||||
# Process the audio context until the context doesn't have more
|
||||
# audio available (i.e. we find None).
|
||||
await self._handle_audio_context(context_id)
|
||||
@@ -1323,7 +1355,7 @@ class TTSService(AIService):
|
||||
else:
|
||||
running = False
|
||||
|
||||
self._contexts_queue.task_done()
|
||||
self._serialization_queue.task_done()
|
||||
|
||||
async def _handle_audio_context(self, context_id: str):
|
||||
"""Process items from an audio context queue until it is exhausted."""
|
||||
|
||||
315
tests/test_tts_frame_ordering.py
Normal file
315
tests/test_tts_frame_ordering.py
Normal file
@@ -0,0 +1,315 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Tests for frame ordering across TTS service types.
|
||||
|
||||
Covers three patterns:
|
||||
- HTTP TTS services (e.g. CartesiaHttpTTSService): yield audio frames synchronously.
|
||||
- WebSocket TTS services without pause (e.g. CartesiaTTSService): deliver audio via
|
||||
append_to_audio_context from a background receive loop, no frame-processing pause.
|
||||
- WebSocket TTS services with pause (e.g. ElevenLabsTTSService): same delivery
|
||||
mechanism, but pause downstream frame processing while audio is in flight.
|
||||
|
||||
For all three patterns we verify:
|
||||
AggregatedTextFrame → TTSStartedFrame → TTSAudioRawFrame (1+) → TTSStoppedFrame → FooFrame
|
||||
|
||||
repeated for each TTSSpeakFrame, with no cross-group contamination.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator, List, Sequence, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
DataFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.tests.utils import run_test
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test-only frame
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FAKE_AUDIO = b"\x00\x01" * 320 # 640 bytes of placeholder audio (a 2-byte pattern repeated 320 times); not literal silence
|
||||
_SAMPLE_RATE = 16000
|
||||
|
||||
|
||||
@dataclass
class FooFrame(DataFrame):
    """Sentinel data frame whose position is compared against the TTS frames.

    The tests in this module send one FooFrame right after each TTSSpeakFrame
    and use ``label`` to tell the sentinels apart.
    """

    # Identifier checked by the ordering assertions; empty unless provided.
    label: str = ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock TTS services
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockHttpTTSService(TTSService):
    """Stand-in for an HTTP-style TTS service (e.g. CartesiaHttpTTSService).

    run_tts() yields its single audio frame directly instead of delivering it
    from a background task, so the audio is produced while the speak frame is
    still being handled.

    NOTE(review): per the original author's note, the base class is expected to
    append the TTSStoppedFrame itself (in on_turn_context_completed(), once it
    sees _is_yielding_frames_synchronously is True) -- confirm against
    TTSService, which is not visible from this file.
    """

    def __init__(self, **kwargs):
        """Configure the base service for this mock; extra kwargs are forwarded."""
        super().__init__(
            sample_rate=_SAMPLE_RATE,
            push_start_frame=True,
            push_stop_frames=True,
            push_text_frames=False,
            **kwargs,
        )

    def can_generate_metrics(self) -> bool:
        """Report that this mock produces no metrics."""
        return False

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Yield one fake audio frame tagged with ``context_id``.

        Args:
            text: Text to synthesize; ignored by the mock.
            context_id: Audio context the yielded frame is attributed to.
        """
        audio_frame = TTSAudioRawFrame(
            audio=_FAKE_AUDIO,
            sample_rate=_SAMPLE_RATE,
            num_channels=1,
            context_id=context_id,
        )
        yield audio_frame
|
||||
|
||||
|
||||
class MockWebSocketTTSService(TTSService):
    """Stand-in for a WebSocket TTS service that does not pause frame processing.

    Modelled on services such as CartesiaTTSService: run_tts() yields nothing
    and instead starts a background task which, after a short delay, appends
    an audio frame and a TTSStoppedFrame to the audio context and then removes
    the context. The service is built with pause_frame_processing=False, so
    frames that follow the speak frame (FooFrame) are not meant to be held back.
    """

    def __init__(self, **kwargs):
        """Configure the base service for this mock; extra kwargs are forwarded."""
        super().__init__(
            sample_rate=_SAMPLE_RATE,
            pause_frame_processing=False,
            push_start_frame=True,
            push_text_frames=False,
            **kwargs,
        )

    def can_generate_metrics(self) -> bool:
        """Report that this mock produces no metrics."""
        return False

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Schedule delayed audio delivery for ``context_id`` and yield nothing.

        Args:
            text: Text to synthesize; ignored by the mock.
            context_id: Audio context that receives the delayed frames.
        """

        async def _feed_context():
            # Short delay to imitate audio arriving later from a receive loop.
            await asyncio.sleep(0.01)
            audio_frame = TTSAudioRawFrame(
                audio=_FAKE_AUDIO,
                sample_rate=_SAMPLE_RATE,
                num_channels=1,
                context_id=context_id,
            )
            await self.append_to_audio_context(context_id, audio_frame)
            stopped_frame = TTSStoppedFrame(context_id=context_id)
            await self.append_to_audio_context(context_id, stopped_frame)
            await self.remove_audio_context(context_id)

        self.create_task(_feed_context(), name=f"mock_ws_deliver_{context_id}")
        return
        # Unreachable on purpose: the presence of ``yield`` is what makes this
        # function an async generator even though it never produces a frame.
        yield
|
||||
|
||||
|
||||
class MockWebSocketPauseTTSService(TTSService):
    """Stand-in for a WebSocket TTS service that pauses frame processing.

    Modelled on services such as ElevenLabsTTSService. Audio is delivered the
    same way as in MockWebSocketTTSService (a delayed background task feeding
    the audio context), but the service is built with
    pause_frame_processing=True and on_audio_context_completed() asks the base
    class to resume processing, so that FooFrame is expected after the
    TTSStoppedFrame.
    """

    def __init__(self, **kwargs):
        """Configure the base service for this mock; extra kwargs are forwarded."""
        super().__init__(
            sample_rate=_SAMPLE_RATE,
            pause_frame_processing=True,
            push_start_frame=True,
            push_text_frames=False,
            **kwargs,
        )

    def can_generate_metrics(self) -> bool:
        """Report that this mock produces no metrics."""
        return False

    async def on_audio_context_completed(self, context_id: str):
        """Resume frame processing once an audio context has completed.

        Args:
            context_id: Identifier of the completed audio context (unused here).
        """
        await self._maybe_resume_frame_processing()

    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Schedule delayed audio delivery for ``context_id`` and yield nothing.

        Args:
            text: Text to synthesize; ignored by the mock.
            context_id: Audio context that receives the delayed frames.
        """

        async def _feed_context():
            # Short delay to imitate audio arriving later from a receive loop.
            await asyncio.sleep(0.01)
            audio_frame = TTSAudioRawFrame(
                audio=_FAKE_AUDIO,
                sample_rate=_SAMPLE_RATE,
                num_channels=1,
                context_id=context_id,
            )
            await self.append_to_audio_context(context_id, audio_frame)
            stopped_frame = TTSStoppedFrame(context_id=context_id)
            await self.append_to_audio_context(context_id, stopped_frame)
            await self.remove_audio_context(context_id)

        self.create_task(_feed_context(), name=f"mock_ws_pause_deliver_{context_id}")
        return
        # Unreachable on purpose: the presence of ``yield`` is what makes this
        # function an async generator even though it never produces a frame.
        yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Assertion helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _assert_group_ordering(
    down_frames: Sequence[Frame],
    expected_groups: List[Tuple[str, str]],
) -> None:
    """Check that each FooFrame closes a well-formed group of TTS frames.

    Only AggregatedTextFrame, TTSStartedFrame, TTSAudioRawFrame, TTSStoppedFrame
    and FooFrame instances are considered. The filtered stream is cut after
    every FooFrame, and each resulting group must:

    * contain an AggregatedTextFrame, a TTSStartedFrame, a TTSAudioRawFrame and
      a TTSStoppedFrame (matched by exact type);
    * have TTSStartedFrame before TTSStoppedFrame before FooFrame;
    * have nothing but TTSAudioRawFrame between the first TTSStartedFrame and
      the first TTSStoppedFrame;
    * end with a FooFrame carrying the expected label.

    The position of AggregatedTextFrame inside the group is not checked, and
    frames that come after the last FooFrame are ignored.

    Args:
        down_frames: All downstream frames received by the test sink.
        expected_groups: List of (tts_text, foo_label) pairs, one per
            TTSSpeakFrame. tts_text is not used by any assertion; it is there
            for readability.

    Raises:
        AssertionError: If the number of FooFrames or any group is not as
            described above.
    """
    tracked_types = (
        AggregatedTextFrame,
        TTSStartedFrame,
        TTSAudioRawFrame,
        TTSStoppedFrame,
        FooFrame,
    )
    relevant = [frame for frame in down_frames if isinstance(frame, tracked_types)]

    # Every FooFrame marks the end of one group.
    foo_indices = [pos for pos, frame in enumerate(relevant) if isinstance(frame, FooFrame)]
    assert len(foo_indices) == len(expected_groups), (
        f"Expected {len(expected_groups)} FooFrames, got {len(foo_indices)}.\n"
        f"Relevant frames: {[type(f).__name__ for f in relevant]}"
    )

    # A group starts right after the previous FooFrame (or at 0) and runs up to
    # and including its own FooFrame.
    group_starts = [0] + [pos + 1 for pos in foo_indices[:-1]]
    groups: List[List[Frame]] = [
        relevant[start : end + 1] for start, end in zip(group_starts, foo_indices)
    ]

    required_types = (AggregatedTextFrame, TTSStartedFrame, TTSAudioRawFrame, TTSStoppedFrame)

    for group, (_tts_text, foo_label) in zip(groups, expected_groups):
        types = [type(frame) for frame in group]
        type_names = [cls.__name__ for cls in types]

        for required in required_types:
            assert required in types, (
                f"Group {foo_label!r}: missing {required.__name__}. Got: {type_names}"
            )

        started_idx = types.index(TTSStartedFrame)
        stopped_idx = types.index(TTSStoppedFrame)
        foo_idx = types.index(FooFrame)

        assert started_idx < stopped_idx, (
            f"Group {foo_label!r}: TTSStartedFrame (pos {started_idx}) must precede "
            f"TTSStoppedFrame (pos {stopped_idx}). Got: {type_names}"
        )
        assert stopped_idx < foo_idx, (
            f"Group {foo_label!r}: TTSStoppedFrame (pos {stopped_idx}) must precede "
            f"FooFrame (pos {foo_idx}). Got: {type_names}"
        )

        # Only audio is allowed between the start and stop markers.
        for between in types[started_idx + 1 : stopped_idx]:
            assert between is TTSAudioRawFrame, (
                f"Group {foo_label!r}: unexpected frame {between.__name__!r} between "
                f"TTSStartedFrame and TTSStoppedFrame. Got: {type_names}"
            )

        # The sentinel closing the group must be the one we expect.
        actual_label = group[foo_idx].label
        assert actual_label == foo_label, (
            f"Expected FooFrame(label={foo_label!r}), got label={actual_label!r}"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_GROUPS = [("test 1", "1"), ("test 2", "2")]
|
||||
|
||||
|
||||
def _make_frames_no_sleep() -> List[Frame]:
    """Return two TTSSpeakFrame+FooFrame pairs to be sent back-to-back.

    The list contains no sleep/delay frames between the pairs. Every test in
    this module feeds this same sequence to its service: the HTTP mock, the
    WebSocket mock without pause, and the WebSocket mock with pause.

    NOTE(review): an earlier version of this docstring stated the sequence is
    only correct for services that pause downstream processing
    (pause_frame_processing=True + on_audio_context_completed). That no longer
    matches how the helper is used here.

    Returns:
        ``[TTSSpeakFrame("test 1"), FooFrame("1"), TTSSpeakFrame("test 2"),
        FooFrame("2")]``, with ``append_to_context=False`` on the speak frames.
    """
    frames: List[Frame] = []
    for text, label in (("test 1", "1"), ("test 2", "2")):
        frames.append(TTSSpeakFrame(text=text, append_to_context=False))
        frames.append(FooFrame(label=label))
    return frames
|
||||
|
||||
|
||||
def _print_frames_received(frames_received) -> None:
    """Print the name of every frame in ``frames_received[0]`` (debug aid).

    Args:
        frames_received: Indexable result whose first element is an iterable of
            frames exposing a ``name`` attribute; the tests pass the value
            returned by run_test().
    """
    first_bucket = frames_received[0]
    print("FRAMES RECEIVED:")
    for received in first_bucket:
        print(received.name)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_http_tts_frame_ordering():
    """Frame groups stay ordered when run_tts() yields its audio directly."""
    service = MockHttpTTSService()
    received = await run_test(service, frames_to_send=_make_frames_no_sleep())

    # Debug aid only: dump the names of the frames that were received.
    _print_frames_received(received)

    # Index 0 holds the frames handed to the ordering check as downstream frames.
    _assert_group_ordering(received[0], _GROUPS)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_websocket_tts_no_pause_frame_ordering():
    """Frame groups stay ordered for the WebSocket mock with pause_frame_processing=False."""
    service = MockWebSocketTTSService()
    received = await run_test(service, frames_to_send=_make_frames_no_sleep())

    # Index 0 holds the frames handed to the ordering check as downstream frames.
    _assert_group_ordering(received[0], _GROUPS)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_websocket_tts_with_pause_frame_ordering():
    """Frame groups stay ordered for the WebSocket mock with pause_frame_processing=True."""
    service = MockWebSocketPauseTTSService()
    received = await run_test(service, frames_to_send=_make_frames_no_sleep())

    # Index 0 holds the frames handed to the ordering check as downstream frames.
    _assert_group_ordering(received[0], _GROUPS)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The tests in this module are pytest coroutine functions marked with
    # pytest.mark.asyncio, not unittest.TestCase classes, so unittest.main()
    # would collect nothing. Hand the file to pytest and propagate its exit code.
    raise SystemExit(pytest.main([__file__]))
|
||||
Reference in New Issue
Block a user