diff --git a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index f93c13d79..303072373 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -43,6 +43,7 @@ from pipecat.frames.frames import ( LLMFullResponseEndFrame, LLMFullResponseStartFrame, StartFrame, + SystemFrame, TextFrame, TranscriptionFrame, TTSAudioRawFrame, @@ -557,9 +558,9 @@ class TTSService(AIService): """ await super().stop(frame) if self._audio_context_task: - # Indicate no more audio contexts are available; this will end the - # task cleanly after all contexts have been processed. - await self._contexts_queue.put(None) + # Sentinel None shuts down the serialization queue once all + # pending contexts and frames have been processed. + await self._serialization_queue.put(None) await self._audio_context_task self._audio_context_task = None if self._stop_frame_task: @@ -791,7 +792,15 @@ class TTSService(AIService): await self._maybe_resume_frame_processing() await self.push_frame(frame, direction) else: - await self.push_frame(frame, direction) + if direction == FrameDirection.DOWNSTREAM and not isinstance(frame, SystemFrame): + # Route non-system downstream frames through the serialization queue so they + # are emitted in the same order they arrive relative to any audio contexts that + # are already queued (e.g. a FooFrame sent right after a TTSSpeakFrame must + # not overtake the TTSStartedFrame / TTSAudioRawFrame / TTSStoppedFrame + # sequence from that speak frame). + await self._serialization_queue.put(frame) + else: + await self.push_frame(frame, direction) async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): """Push a frame downstream with TTS-specific handling. @@ -994,7 +1003,10 @@ class TTSService(AIService): # is spoken, so we set append_to_context to False. src_frame.append_to_context = False src_frame.context_id = context_id - await self.push_frame(src_frame) + # Route AggregatedTextFrame through the serialization queue so it is emitted + # immediately before the TTSStartedFrame of the audio context it describes, + # rather than racing ahead of audio frames from a previous context. + await self._serialization_queue.put(src_frame) # Note: Text transformations are meant to only affect the text sent to the TTS for # TTS-specific purposes. This allows for explicit TTS modifications (e.g., inserting @@ -1203,7 +1215,7 @@ class TTSService(AIService): Args: context_id: Unique identifier for the audio context. """ - await self._contexts_queue.put(context_id) + await self._serialization_queue.put(context_id) self._audio_contexts[context_id] = asyncio.Queue() logger.trace(f"{self} created audio context {context_id}") @@ -1295,7 +1307,14 @@ class TTSService(AIService): def _create_audio_context_task(self): if not self._audio_context_task: - self._contexts_queue: asyncio.Queue = asyncio.Queue() + # Single FIFO queue that serializes everything the TTS service emits downstream. + # Items can be: + # str – an audio context ID: process the per-context audio queue in full before + # moving on (see _handle_audio_context). + # Frame – a non-system downstream frame (e.g. AggregatedTextFrame, FooFrame) that + # must be emitted in-order relative to surrounding audio contexts. + # None – shutdown sentinel (sent by stop()). + self._serialization_queue: asyncio.Queue = asyncio.Queue() self._audio_contexts: Dict[str, asyncio.Queue] = {} self._audio_context_task = self.create_task(self._audio_context_task_handler()) @@ -1305,13 +1324,26 @@ class TTSService(AIService): self._audio_context_task = None async def _audio_context_task_handler(self): - """In this task we process audio contexts in order.""" + """Drain the serialization queue, preserving downstream frame order. + + The queue carries three kinds of items (see _create_audio_context_task): + + * str – audio context ID: block until all audio for that context has been + pushed downstream, then call on_audio_context_completed(). + * Frame – a non-system downstream frame that must be emitted at this exact + position in the output stream (e.g. AggregatedTextFrame preceding + its audio, or an arbitrary frame that arrived between two speak frames). + * None – shutdown sentinel; exit the loop once reached. + """ running = True while running: - context_id = await self._contexts_queue.get() - self._playing_context_id = context_id + context_value = await self._serialization_queue.get() + if isinstance(context_value, Frame): + await self.push_frame(context_value) + elif isinstance(context_value, str): + context_id = context_value + self._playing_context_id = context_id - if context_id: # Process the audio context until the context doesn't have more # audio available (i.e. we find None). await self._handle_audio_context(context_id) @@ -1323,7 +1355,7 @@ class TTSService(AIService): else: running = False - self._contexts_queue.task_done() + self._serialization_queue.task_done() async def _handle_audio_context(self, context_id: str): """Process items from an audio context queue until it is exhausted.""" diff --git a/tests/test_tts_frame_ordering.py b/tests/test_tts_frame_ordering.py new file mode 100644 index 000000000..52d3df4e7 --- /dev/null +++ b/tests/test_tts_frame_ordering.py @@ -0,0 +1,315 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Tests for frame ordering across TTS service types. + +Covers three patterns: +- HTTP TTS services (e.g. CartesiaHttpTTSService): yield audio frames synchronously. +- WebSocket TTS services without pause (e.g. CartesiaTTSService): deliver audio via + append_to_audio_context from a background receive loop, no frame-processing pause. +- WebSocket TTS services with pause (e.g. ElevenLabsTTSService): same delivery + mechanism, but pause downstream frame processing while audio is in flight. + +For all three patterns we verify: + AggregatedTextFrame → TTSStartedFrame → TTSAudioRawFrame (1+) → TTSStoppedFrame → FooFrame + +repeated for each TTSSpeakFrame, with no cross-group contamination. +""" + +import asyncio +import unittest +from dataclasses import dataclass +from typing import AsyncGenerator, List, Sequence, Tuple + +import pytest + +from pipecat.frames.frames import ( + AggregatedTextFrame, + DataFrame, + Frame, + TTSAudioRawFrame, + TTSSpeakFrame, + TTSStartedFrame, + TTSStoppedFrame, +) +from pipecat.services.tts_service import TTSService +from pipecat.tests.utils import run_test + +# --------------------------------------------------------------------------- +# Test-only frame +# --------------------------------------------------------------------------- + +_FAKE_AUDIO = b"\x00\x01" * 320 # 320 bytes of silence +_SAMPLE_RATE = 16000 + + +@dataclass +class FooFrame(DataFrame): + """Marker frame used to verify relative ordering against TTS audio frames.""" + + label: str = "" + + +# --------------------------------------------------------------------------- +# Mock TTS services +# --------------------------------------------------------------------------- + + +class MockHttpTTSService(TTSService): + """Simulates an HTTP TTS service (e.g. CartesiaHttpTTSService). + + Audio frames are yielded synchronously from run_tts(), so the audio context + is fully populated before the next downstream frame is processed. + TTSStoppedFrame is appended by the base class in on_turn_context_completed() + once it detects _is_yielding_frames_synchronously is True. + """ + + def __init__(self, **kwargs): + super().__init__( + push_start_frame=True, + push_stop_frames=True, + push_text_frames=False, + sample_rate=_SAMPLE_RATE, + **kwargs, + ) + + def can_generate_metrics(self) -> bool: + return False + + async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]: + yield TTSAudioRawFrame( + audio=_FAKE_AUDIO, + sample_rate=_SAMPLE_RATE, + num_channels=1, + context_id=context_id, + ) + + +class MockWebSocketTTSService(TTSService): + """Simulates a WebSocket TTS service without frame-processing pause (e.g. CartesiaTTSService). + + run_tts() is an empty async generator (signals async delivery). A background + task appends audio frames and the TTSStoppedFrame to the audio context after a + short delay, mimicking real WebSocket receive-loop behaviour. + pause_frame_processing=False means downstream frames (FooFrame) are NOT held. + """ + + def __init__(self, **kwargs): + super().__init__( + push_start_frame=True, + push_text_frames=False, + pause_frame_processing=False, + sample_rate=_SAMPLE_RATE, + **kwargs, + ) + + def can_generate_metrics(self) -> bool: + return False + + async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]: + async def _deliver_audio(): + await asyncio.sleep(0.01) + await self.append_to_audio_context( + context_id, + TTSAudioRawFrame( + audio=_FAKE_AUDIO, + sample_rate=_SAMPLE_RATE, + num_channels=1, + context_id=context_id, + ), + ) + await self.append_to_audio_context(context_id, TTSStoppedFrame(context_id=context_id)) + await self.remove_audio_context(context_id) + + self.create_task(_deliver_audio(), name=f"mock_ws_deliver_{context_id}") + if False: + yield # make this an async generator without yielding anything + + +class MockWebSocketPauseTTSService(TTSService): + """Simulates a WebSocket TTS service WITH frame-processing pause (e.g. ElevenLabsTTSService). + + Identical to MockWebSocketTTSService except pause_frame_processing=True. + on_audio_context_completed() resumes downstream processing once the full + audio context has been pushed, guaranteeing FooFrame arrives after TTSStoppedFrame. + """ + + def __init__(self, **kwargs): + super().__init__( + push_start_frame=True, + push_text_frames=False, + pause_frame_processing=True, + sample_rate=_SAMPLE_RATE, + **kwargs, + ) + + def can_generate_metrics(self) -> bool: + return False + + async def on_audio_context_completed(self, context_id: str): + # Resume frame processing after the audio context is fully played out. + await self._maybe_resume_frame_processing() + + async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]: + async def _deliver_audio(): + await asyncio.sleep(0.01) + await self.append_to_audio_context( + context_id, + TTSAudioRawFrame( + audio=_FAKE_AUDIO, + sample_rate=_SAMPLE_RATE, + num_channels=1, + context_id=context_id, + ), + ) + await self.append_to_audio_context(context_id, TTSStoppedFrame(context_id=context_id)) + await self.remove_audio_context(context_id) + + self.create_task(_deliver_audio(), name=f"mock_ws_pause_deliver_{context_id}") + if False: + yield + + +# --------------------------------------------------------------------------- +# Assertion helper +# --------------------------------------------------------------------------- + + +def _assert_group_ordering( + down_frames: Sequence[Frame], + expected_groups: List[Tuple[str, str]], +) -> None: + """Assert two (or more) TTS+FooFrame groups are in strict order. + + Args: + down_frames: All downstream frames received by the test sink. + expected_groups: List of (tts_text, foo_label) pairs, one per TTSSpeakFrame. + tts_text is unused in assertions today but included for readability. + """ + relevant = [ + f + for f in down_frames + if isinstance( + f, (AggregatedTextFrame, TTSStartedFrame, TTSAudioRawFrame, TTSStoppedFrame, FooFrame) + ) + ] + + # Locate the FooFrames that delimit groups. + foo_indices = [i for i, f in enumerate(relevant) if isinstance(f, FooFrame)] + assert len(foo_indices) == len(expected_groups), ( + f"Expected {len(expected_groups)} FooFrames, got {len(foo_indices)}.\n" + f"Relevant frames: {[type(f).__name__ for f in relevant]}" + ) + + # Build groups: everything up to and including each FooFrame. + groups: List[List[Frame]] = [] + prev = 0 + for idx in foo_indices: + groups.append(relevant[prev : idx + 1]) + prev = idx + 1 + + for group, (_, foo_label) in zip(groups, expected_groups): + types = [type(f) for f in group] + type_names = [t.__name__ for t in types] + + assert AggregatedTextFrame in types, ( + f"Group {foo_label!r}: missing AggregatedTextFrame. Got: {type_names}" + ) + assert TTSStartedFrame in types, ( + f"Group {foo_label!r}: missing TTSStartedFrame. Got: {type_names}" + ) + assert TTSAudioRawFrame in types, ( + f"Group {foo_label!r}: missing TTSAudioRawFrame. Got: {type_names}" + ) + assert TTSStoppedFrame in types, ( + f"Group {foo_label!r}: missing TTSStoppedFrame. Got: {type_names}" + ) + + started_idx = types.index(TTSStartedFrame) + stopped_idx = types.index(TTSStoppedFrame) + foo_idx = types.index(FooFrame) + + assert started_idx < stopped_idx, ( + f"Group {foo_label!r}: TTSStartedFrame (pos {started_idx}) must precede " + f"TTSStoppedFrame (pos {stopped_idx}). Got: {type_names}" + ) + assert stopped_idx < foo_idx, ( + f"Group {foo_label!r}: TTSStoppedFrame (pos {stopped_idx}) must precede " + f"FooFrame (pos {foo_idx}). Got: {type_names}" + ) + + # All frames between TTSStartedFrame and TTSStoppedFrame must be audio. + mid_types = types[started_idx + 1 : stopped_idx] + for t in mid_types: + assert t is TTSAudioRawFrame, ( + f"Group {foo_label!r}: unexpected frame {t.__name__!r} between " + f"TTSStartedFrame and TTSStoppedFrame. Got: {type_names}" + ) + + # Check the FooFrame label. + actual_label = group[foo_idx].label + assert actual_label == foo_label, ( + f"Expected FooFrame(label={foo_label!r}), got label={actual_label!r}" + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +_GROUPS = [("test 1", "1"), ("test 2", "2")] + + +def _make_frames_no_sleep() -> List[Frame]: + """Return two TTSSpeakFrame+FooFrame pairs sent back-to-back. + + Only correct for services that pause downstream processing until the audio + context is fully consumed (pause_frame_processing=True + on_audio_context_completed). + """ + return [ + TTSSpeakFrame(text="test 1", append_to_context=False), + FooFrame(label="1"), + TTSSpeakFrame(text="test 2", append_to_context=False), + FooFrame(label="2"), + ] + + +def _print_frames_received(frames_received) -> None: + print("FRAMES RECEIVED:") + for frame in frames_received[0]: + print(frame.name) + + +@pytest.mark.asyncio +async def test_http_tts_frame_ordering(): + """HTTP TTS services yield audio synchronously.""" + tts = MockHttpTTSService() + frames_received = await run_test(tts, frames_to_send=_make_frames_no_sleep()) + + # only for debugging + _print_frames_received(frames_received) + + _assert_group_ordering(frames_received[0], _GROUPS) + + +@pytest.mark.asyncio +async def test_websocket_tts_no_pause_frame_ordering(): + """WebSocket TTS services without pause_frame_processing.""" + tts = MockWebSocketTTSService() + frames_received = await run_test(tts, frames_to_send=_make_frames_no_sleep()) + _assert_group_ordering(frames_received[0], _GROUPS) + + +@pytest.mark.asyncio +async def test_websocket_tts_with_pause_frame_ordering(): + """WebSocket TTS services with pause_frame_processing=True.""" + tts = MockWebSocketPauseTTSService() + frames_received = await run_test(tts, frames_to_send=_make_frames_no_sleep()) + _assert_group_ordering(frames_received[0], _GROUPS) + + +if __name__ == "__main__": + unittest.main()