test: add automated tests for word tracking, frame sequencing, and Cartesia TTS
Adds tests for AggregatedFrameSequencer, WordCompletionTracker, and word_timestamp_utils (including CJK language scenarios). Updates existing Cartesia TTS and TTS frame ordering tests to cover the new behaviours.
This commit is contained in:
612
tests/test_aggregated_frame_sequencer.py
Normal file
612
tests/test_aggregated_frame_sequencer.py
Normal file
@@ -0,0 +1,612 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Tests for AggregatedFrameSequencer.
|
||||
|
||||
All methods on the sequencer are synchronous and return lists of frames,
|
||||
so no async machinery is needed here.
|
||||
|
||||
Test groups:
|
||||
- register_skipped: immediate flush vs. blocked by a preceding spoken slot
|
||||
- register_spoken / complete_spoken_slot: push_text_frames=True path
|
||||
- flush: pts propagation, transport_destination, stops at incomplete spoken slot
|
||||
- process_word: normal, completing, passthrough, raw_text propagation
|
||||
- process_word overflow: single token spanning two slot boundaries
|
||||
- process_word force-complete via belongs_here failure
|
||||
- force_complete: remaining text emission, raw_text, corrupt raw discard, slot ordering
|
||||
- clear: resets all state
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import AggregatedTextFrame, AggregationType, TTSTextFrame
|
||||
from pipecat.utils.context.aggregated_frame_sequencer import AggregatedFrameSequencer
|
||||
from pipecat.utils.context.word_completion_tracker import WordCompletionTracker
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _seq() -> AggregatedFrameSequencer:
|
||||
return AggregatedFrameSequencer(name="test")
|
||||
|
||||
|
||||
def _spoken_frame(text: str) -> AggregatedTextFrame:
|
||||
return AggregatedTextFrame(text, AggregationType.SENTENCE)
|
||||
|
||||
|
||||
def _skipped_frame(text: str) -> AggregatedTextFrame:
|
||||
return AggregatedTextFrame(text, "code")
|
||||
|
||||
|
||||
def _tracker(tts_text: str, llm_text: str | None = None) -> WordCompletionTracker:
|
||||
return WordCompletionTracker(tts_text, llm_text=llm_text)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# register_skipped
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegisterSkipped(unittest.TestCase):
|
||||
def test_emits_immediately_with_empty_queue(self):
|
||||
seq = _seq()
|
||||
frame = _skipped_frame("code block")
|
||||
result = seq.register_skipped(frame, "ctx1", None)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIs(result[0], frame)
|
||||
|
||||
def test_sets_append_to_context_true(self):
|
||||
seq = _seq()
|
||||
frame = _skipped_frame("code")
|
||||
seq.register_skipped(frame, "ctx1", None)
|
||||
self.assertTrue(frame.append_to_context)
|
||||
|
||||
def test_sets_context_id_on_frame(self):
|
||||
seq = _seq()
|
||||
frame = _skipped_frame("code")
|
||||
seq.register_skipped(frame, "ctx42", None)
|
||||
self.assertEqual(frame.context_id, "ctx42")
|
||||
|
||||
def test_sets_transport_destination(self):
|
||||
seq = _seq()
|
||||
frame = _skipped_frame("code")
|
||||
result = seq.register_skipped(frame, "ctx1", "dest-A")
|
||||
self.assertEqual(result[0].transport_destination, "dest-A")
|
||||
|
||||
def test_blocked_by_incomplete_spoken_slot(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True)
|
||||
result = seq.register_skipped(_skipped_frame("code"), "ctx2", None)
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def test_emits_immediately_after_already_complete_spoken_slot(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hi"), "ctx1", tracker=None, append_to_context=True)
|
||||
seq.complete_spoken_slot()
|
||||
result = seq.register_skipped(_skipped_frame("code"), "ctx2", None)
|
||||
self.assertEqual(len(result), 1)
|
||||
|
||||
def test_multiple_skipped_before_any_spoken_all_emit(self):
|
||||
seq = _seq()
|
||||
r1 = seq.register_skipped(_skipped_frame("code1"), "ctx1", None)
|
||||
r2 = seq.register_skipped(_skipped_frame("code2"), "ctx2", None)
|
||||
self.assertEqual(len(r1), 1)
|
||||
self.assertEqual(len(r2), 1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# register_spoken / complete_spoken_slot (push_text_frames=True path)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompleteSpokenSlot(unittest.TestCase):
|
||||
def test_noop_with_empty_queue(self):
|
||||
seq = _seq()
|
||||
self.assertEqual(seq.complete_spoken_slot(), [])
|
||||
|
||||
def test_marks_slot_complete_and_flushes_skipped(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx2", None) # blocked
|
||||
|
||||
result = seq.complete_spoken_slot()
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIs(result[0], skipped)
|
||||
self.assertTrue(skipped.append_to_context)
|
||||
|
||||
def test_only_first_pending_slot_is_marked(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("one"), "ctx1", tracker=None, append_to_context=True)
|
||||
seq.register_spoken(_spoken_frame("two"), "ctx2", tracker=None, append_to_context=True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx3", None)
|
||||
|
||||
# ctx2 still blocks the skipped frame
|
||||
result = seq.complete_spoken_slot()
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def test_skipped_flushes_after_all_preceding_spoken_complete(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("one"), "ctx1", tracker=None, append_to_context=True)
|
||||
seq.register_spoken(_spoken_frame("two"), "ctx2", tracker=None, append_to_context=True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx3", None)
|
||||
|
||||
seq.complete_spoken_slot() # completes ctx1
|
||||
result = seq.complete_spoken_slot() # completes ctx2 → flush skipped
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIs(result[0], skipped)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# flush
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlush(unittest.TestCase):
|
||||
def test_empty_queue_returns_empty(self):
|
||||
self.assertEqual(_seq().flush(), [])
|
||||
|
||||
def test_stops_at_incomplete_spoken_slot(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True)
|
||||
seq.register_skipped(_skipped_frame("code"), "ctx2", None)
|
||||
self.assertEqual(seq.flush(), [])
|
||||
|
||||
def test_last_word_pts_assigned_to_skipped_frame(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
|
||||
# process_word("hello") completes the spoken slot and calls flush(last_word_pts=77)
|
||||
result = seq.process_word("hello", pts=77, context_id="ctx1")
|
||||
flushed = [f for f in result if isinstance(f, AggregatedTextFrame) and f.text == "code"]
|
||||
self.assertEqual(len(flushed), 1)
|
||||
self.assertEqual(flushed[0].pts, 77)
|
||||
|
||||
def test_complete_spoken_slots_are_swept(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True)
|
||||
seq.complete_spoken_slot()
|
||||
# Queue should be empty after sweeping the complete spoken slot
|
||||
self.assertEqual(seq._slots, [])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_word — basic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessWordBasic(unittest.TestCase):
|
||||
def _seq_with_spoken(self, text: str, ctx: str = "ctx1", append: bool = True):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame(text), ctx, _tracker(text), append)
|
||||
return seq
|
||||
|
||||
def test_returns_tts_text_frame(self):
|
||||
seq = self._seq_with_spoken("hello")
|
||||
result = seq.process_word("hello", pts=100, context_id="ctx1")
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsInstance(result[0], TTSTextFrame)
|
||||
|
||||
def test_frame_text_and_pts(self):
|
||||
seq = self._seq_with_spoken("hello")
|
||||
result = seq.process_word("hello", pts=100, context_id="ctx1")
|
||||
self.assertEqual(result[0].text, "hello")
|
||||
self.assertEqual(result[0].pts, 100)
|
||||
|
||||
def test_frame_context_id(self):
|
||||
seq = self._seq_with_spoken("hello", ctx="ctx99")
|
||||
result = seq.process_word("hello", pts=1, context_id="ctx99")
|
||||
self.assertEqual(result[0].context_id, "ctx99")
|
||||
|
||||
def test_append_to_context_true(self):
|
||||
seq = self._seq_with_spoken("hello", append=True)
|
||||
result = seq.process_word("hello", pts=1, context_id="ctx1")
|
||||
self.assertTrue(result[0].append_to_context)
|
||||
|
||||
def test_append_to_context_false(self):
|
||||
seq = self._seq_with_spoken("hello", append=False)
|
||||
result = seq.process_word("hello", pts=1, context_id="ctx1")
|
||||
self.assertFalse(result[0].append_to_context)
|
||||
|
||||
def test_non_completing_word_does_not_flush_skipped(self):
|
||||
seq = self._seq_with_spoken("hello world")
|
||||
seq.register_skipped(_skipped_frame("code"), "ctx2", None)
|
||||
result = seq.process_word("hello", pts=10, context_id="ctx1")
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsInstance(result[0], TTSTextFrame)
|
||||
|
||||
def test_completing_word_flushes_blocked_skipped_frame(self):
|
||||
seq = self._seq_with_spoken("hello")
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
result = seq.process_word("hello", pts=50, context_id="ctx1")
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertIsInstance(result[0], TTSTextFrame)
|
||||
self.assertIs(result[1], skipped)
|
||||
|
||||
def test_last_of_multiple_words_flushes_skipped(self):
|
||||
seq = self._seq_with_spoken("hello world")
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
seq.process_word("hello", pts=10, context_id="ctx1")
|
||||
result = seq.process_word("world", pts=20, context_id="ctx1")
|
||||
self.assertTrue(any(f is skipped for f in result))
|
||||
|
||||
def test_no_active_slot_emits_passthrough(self):
|
||||
seq = _seq()
|
||||
result = seq.process_word("hello", pts=1, context_id="ctx-unknown")
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsInstance(result[0], TTSTextFrame)
|
||||
self.assertEqual(result[0].text, "hello")
|
||||
self.assertEqual(result[0].context_id, "ctx-unknown")
|
||||
|
||||
def test_passthrough_uses_default_append_to_context_true(self):
|
||||
seq = _seq()
|
||||
result = seq.process_word("hello", pts=1, context_id="ctx-unknown")
|
||||
self.assertTrue(result[0].append_to_context)
|
||||
|
||||
def test_unrecognised_word_emits_passthrough(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True)
|
||||
# "zzz" doesn't belong to "hello world" and there is no next slot
|
||||
result = seq.process_word("zzz", pts=5, context_id="ctx1")
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].text, "zzz")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_word — raw_text propagation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessWordRawText(unittest.TestCase):
|
||||
def test_raw_text_split_across_word_frames(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(
|
||||
_spoken_frame("4111 1111"),
|
||||
"ctx1",
|
||||
WordCompletionTracker("4111 1111", llm_text="<card>4111 1111</card>"),
|
||||
append_to_context=True,
|
||||
)
|
||||
r1 = seq.process_word("4111", pts=10, context_id="ctx1")
|
||||
r2 = seq.process_word("1111", pts=20, context_id="ctx1")
|
||||
self.assertEqual(r1[0].raw_text, "<card>4111")
|
||||
last_word_frames = [f for f in r2 if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(last_word_frames[0].raw_text, "1111</card>")
|
||||
|
||||
def test_raw_text_none_when_no_llm_text(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
result = seq.process_word("hello", pts=1, context_id="ctx1")
|
||||
self.assertIsNone(result[0].raw_text)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_word — overflow (single token spanning two slots)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessWordOverflow(unittest.TestCase):
|
||||
def test_overflow_produces_two_tts_text_frames(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True)
|
||||
seq.register_spoken(_spoken_frame("def"), "ctx2", _tracker("def"), True)
|
||||
|
||||
result = seq.process_word("abcdef", pts=100, context_id="ctx1")
|
||||
word_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(len(word_frames), 2)
|
||||
self.assertEqual(word_frames[0].text, "abc")
|
||||
self.assertEqual(word_frames[1].text, "def")
|
||||
|
||||
def test_overflow_assigns_correct_context_ids(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True)
|
||||
seq.register_spoken(_spoken_frame("def"), "ctx2", _tracker("def"), True)
|
||||
|
||||
result = seq.process_word("abcdef", pts=100, context_id="ctx1")
|
||||
word_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(word_frames[0].context_id, "ctx1")
|
||||
self.assertEqual(word_frames[1].context_id, "ctx2")
|
||||
|
||||
def test_overflow_completing_next_slot_flushes_skipped(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True)
|
||||
seq.register_spoken(_spoken_frame("def"), "ctx2", _tracker("def"), True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx3", None) # blocked behind ctx2
|
||||
|
||||
result = seq.process_word("abcdef", pts=100, context_id="ctx1")
|
||||
self.assertTrue(any(f is skipped for f in result))
|
||||
|
||||
def test_overflow_not_completing_next_slot_does_not_flush_skipped(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True)
|
||||
seq.register_spoken(_spoken_frame("def ghi"), "ctx2", _tracker("def ghi"), True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx3", None)
|
||||
|
||||
# "abcdef" overflows: "def" goes to ctx2, but ctx2 still expects " ghi"
|
||||
result = seq.process_word("abcdef", pts=100, context_id="ctx1")
|
||||
self.assertFalse(any(f is skipped for f in result))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_word — force-complete via word_belongs_here failure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessWordForcesComplete(unittest.TestCase):
|
||||
def test_word_for_next_slot_force_completes_current(self):
|
||||
"""When a word belongs to the next slot but not the current, the current
|
||||
slot is force-completed and the word is routed to the next slot."""
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
seq.register_spoken(_spoken_frame("world"), "ctx2", _tracker("world"), True)
|
||||
|
||||
# "world" doesn't belong to ctx1 but belongs to ctx2
|
||||
result = seq.process_word("world", pts=50, context_id="ctx2")
|
||||
word_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
texts = {f.text for f in word_frames}
|
||||
self.assertIn("world", texts)
|
||||
|
||||
def test_force_complete_then_overflow_flushes_skipped(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
seq.register_spoken(_spoken_frame("world"), "ctx2", _tracker("world"), True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx3", None)
|
||||
|
||||
# "world" force-completes ctx1 and completes ctx2 via overflow
|
||||
result = seq.process_word("world", pts=50, context_id="ctx2")
|
||||
self.assertTrue(any(f is skipped for f in result))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# force_complete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestForceComplete(unittest.TestCase):
|
||||
def test_emits_remaining_text_when_word_dropped(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True)
|
||||
seq.process_word("hello", pts=10, context_id="ctx1") # "world" never arrives
|
||||
|
||||
result = seq.force_complete(last_word_pts=50)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(len(tts_frames), 1)
|
||||
self.assertEqual(tts_frames[0].text, "world")
|
||||
self.assertEqual(tts_frames[0].pts, 50)
|
||||
|
||||
def test_emits_full_text_when_no_words_arrived(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True)
|
||||
|
||||
result = seq.force_complete(last_word_pts=0)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(len(tts_frames), 1)
|
||||
self.assertEqual(tts_frames[0].text, "hello world")
|
||||
|
||||
def test_already_complete_slot_emits_nothing(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hi"), "ctx1", _tracker("hi"), True)
|
||||
seq.process_word("hi", pts=5, context_id="ctx1") # completes normally
|
||||
|
||||
result = seq.force_complete(last_word_pts=10)
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def test_flushes_skipped_frames_after_completing(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
|
||||
result = seq.force_complete(last_word_pts=20)
|
||||
self.assertTrue(any(f is skipped for f in result))
|
||||
self.assertTrue(skipped.append_to_context)
|
||||
|
||||
def test_propagates_raw_text(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(
|
||||
_spoken_frame("4111 1111"),
|
||||
"ctx1",
|
||||
WordCompletionTracker("4111 1111", llm_text="<card>4111 1111</card>"),
|
||||
append_to_context=True,
|
||||
)
|
||||
seq.process_word("4111", pts=10, context_id="ctx1") # "1111" never arrives
|
||||
|
||||
result = seq.force_complete(last_word_pts=20)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(tts_frames[0].text, "1111")
|
||||
self.assertEqual(tts_frames[0].raw_text, "1111</card>")
|
||||
|
||||
def test_discards_corrupt_raw_remaining(self):
|
||||
"""raw_remaining is discarded when it does not contain remaining_text."""
|
||||
seq = _seq()
|
||||
# "abc" normalized ≠ "xyz" normalized — any remaining won't be in raw_remaining
|
||||
seq.register_spoken(
|
||||
_spoken_frame("abc"),
|
||||
"ctx1",
|
||||
WordCompletionTracker("abc", llm_text="xyz"),
|
||||
append_to_context=True,
|
||||
)
|
||||
result = seq.force_complete(last_word_pts=0)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(len(tts_frames), 1)
|
||||
self.assertEqual(tts_frames[0].text, "abc")
|
||||
self.assertIsNone(tts_frames[0].raw_text) # discarded due to corruption
|
||||
|
||||
def test_slot_without_tracker_just_marks_complete_and_flushes(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
|
||||
result = seq.force_complete(last_word_pts=0)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(tts_frames, []) # no tracker → no word frame
|
||||
self.assertTrue(any(f is skipped for f in result))
|
||||
|
||||
def test_multiple_incomplete_slots_all_emitted(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
seq.register_spoken(_spoken_frame("world"), "ctx2", _tracker("world"), True)
|
||||
|
||||
result = seq.force_complete(last_word_pts=0)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
texts = {f.text for f in tts_frames}
|
||||
self.assertIn("hello", texts)
|
||||
self.assertIn("world", texts)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# clear
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClear(unittest.TestCase):
|
||||
def test_clears_slots(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
seq.register_skipped(_skipped_frame("code"), "ctx2", None)
|
||||
seq.clear()
|
||||
self.assertEqual(seq._slots, [])
|
||||
|
||||
def test_clears_context_map(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
seq.clear()
|
||||
self.assertEqual(seq._context_append_to_context, {})
|
||||
|
||||
def test_after_clear_skipped_emits_immediately(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
seq.clear()
|
||||
frame = _skipped_frame("code")
|
||||
result = seq.register_skipped(frame, "ctx2", None)
|
||||
self.assertEqual(len(result), 1)
|
||||
|
||||
def test_after_clear_process_word_uses_passthrough(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
seq.clear()
|
||||
result = seq.process_word("hello", pts=1, context_id="ctx1")
|
||||
# No active slot after clear → passthrough
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].text, "hello")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CJK languages — Korean, Japanese, Chinese
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCJKLanguages(unittest.TestCase):
|
||||
"""Sequencer behaviour for CJK language scenarios.
|
||||
|
||||
Korean: Cartesia returns each word as a separate timestamp event (one word
|
||||
per process_word call). Japanese/Chinese: Cartesia merges all characters
|
||||
in one timestamp message into a single combined token before calling
|
||||
process_word.
|
||||
"""
|
||||
|
||||
# --- Korean ---
|
||||
|
||||
def test_korean_word_by_word_completes_slot_and_flushes_skipped(self):
|
||||
"""Korean words fed one at a time complete the spoken slot and unblock a skipped frame."""
|
||||
seq = _seq()
|
||||
sentence = "저는 여러분의 AI 어시스턴트입니다."
|
||||
words = ["저는", "여러분의", "AI", "어시스턴트입니다."]
|
||||
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
|
||||
skipped = _skipped_frame("[code]")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
|
||||
# Skipped stays blocked until the last word arrives
|
||||
for word in words[:-1]:
|
||||
partial = seq.process_word(word, pts=100, context_id="ctx1")
|
||||
self.assertFalse(any(f is skipped for f in partial))
|
||||
|
||||
result = seq.process_word(words[-1], pts=200, context_id="ctx1")
|
||||
self.assertTrue(any(f is skipped for f in result))
|
||||
|
||||
def test_korean_force_complete_emits_correct_remaining_text(self):
|
||||
"""After one Korean word, force_complete emits the correct unspoken suffix."""
|
||||
seq = _seq()
|
||||
sentence = "저는 여러분의 AI 어시스턴트입니다."
|
||||
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
|
||||
seq.process_word("저는", pts=10, context_id="ctx1")
|
||||
|
||||
result = seq.force_complete(last_word_pts=50)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(len(tts_frames), 1)
|
||||
self.assertEqual(tts_frames[0].text, "여러분의 AI 어시스턴트입니다.")
|
||||
self.assertEqual(tts_frames[0].pts, 50)
|
||||
|
||||
# --- Japanese ---
|
||||
|
||||
def test_japanese_combined_groups_complete_spoken_slot(self):
|
||||
"""Two Cartesia-style combined Japanese groups complete the slot and flush skipped."""
|
||||
seq = _seq()
|
||||
sentence = "こんにちは、私はあなたの"
|
||||
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
|
||||
skipped = _skipped_frame("[skipped]")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
|
||||
r1 = seq.process_word("こんにちは、私", pts=100, context_id="ctx1")
|
||||
self.assertFalse(any(f is skipped for f in r1))
|
||||
|
||||
r2 = seq.process_word("はあなたの", pts=200, context_id="ctx1")
|
||||
self.assertTrue(any(f is skipped for f in r2))
|
||||
|
||||
def test_japanese_force_complete_emits_remaining_chars(self):
|
||||
"""After the first Japanese combined group, force_complete emits the rest."""
|
||||
seq = _seq()
|
||||
sentence = "こんにちは、私はあなたの"
|
||||
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
|
||||
seq.process_word("こんにちは、私", pts=10, context_id="ctx1")
|
||||
|
||||
result = seq.force_complete(last_word_pts=50)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(len(tts_frames), 1)
|
||||
self.assertEqual(tts_frames[0].text, "はあなたの")
|
||||
|
||||
# --- Chinese ---
|
||||
|
||||
def test_chinese_combined_groups_complete_spoken_slot(self):
|
||||
"""Two Cartesia-style combined Chinese groups complete the slot and flush skipped."""
|
||||
seq = _seq()
|
||||
sentence = "你好,我是你的智能"
|
||||
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
|
||||
skipped = _skipped_frame("[skipped]")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
|
||||
r1 = seq.process_word("你好,我是", pts=100, context_id="ctx1")
|
||||
self.assertFalse(any(f is skipped for f in r1))
|
||||
|
||||
r2 = seq.process_word("你的智能", pts=200, context_id="ctx1")
|
||||
self.assertTrue(any(f is skipped for f in r2))
|
||||
|
||||
def test_chinese_force_complete_emits_remaining_chars(self):
|
||||
"""After the first Chinese combined group, force_complete emits the rest."""
|
||||
seq = _seq()
|
||||
sentence = "你好,我是你的智能"
|
||||
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
|
||||
seq.process_word("你好,我是", pts=10, context_id="ctx1")
|
||||
|
||||
result = seq.force_complete(last_word_pts=50)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(len(tts_frames), 1)
|
||||
self.assertEqual(tts_frames[0].text, "你的智能")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -18,7 +18,7 @@ def _service(language: str) -> CartesiaTTSService:
|
||||
def _process_word_timestamps(
|
||||
words: list[str], starts: list[float], language: str
|
||||
) -> list[tuple[str, float]]:
|
||||
return _service(language)._process_word_timestamps_for_language(words, starts)
|
||||
return _service(language)._normalize_word_timestamps(words, starts)
|
||||
|
||||
|
||||
def _concatenate_processed_timestamps(
|
||||
@@ -27,7 +27,7 @@ def _concatenate_processed_timestamps(
|
||||
service = _service(language)
|
||||
text_parts = []
|
||||
for words, starts in timestamp_groups:
|
||||
processed_timestamps = service._process_word_timestamps_for_language(words, starts)
|
||||
processed_timestamps = service._normalize_word_timestamps(words, starts)
|
||||
includes_inter_frame_spaces = service._word_timestamps_include_inter_frame_spaces()
|
||||
text_parts.extend(
|
||||
TextPartForConcatenation(
|
||||
|
||||
@@ -21,6 +21,13 @@ repeated for each TTSSpeakFrame, with no cross-group contamination.
|
||||
Also covers LLM response flow with push_text_frames=True (non-word-timestamp TTS):
|
||||
verifies TTSTextFrame ordering relative to LLMFullResponseEndFrame.
|
||||
|
||||
Also covers smart-text / WordCompletionTracker features:
|
||||
- Skipped frames (skip_aggregator_types) held until preceding spoken slots complete.
|
||||
- raw_text on AggregatedTextFrame propagated as spans to TTSTextFrames.
|
||||
- Overflow: a single TTS word straddling two AggregatedTextFrame boundaries produces
|
||||
two correctly-attributed TTSTextFrames.
|
||||
- Force-complete safety net: skipped frames flush even when TTS drops word timestamps.
|
||||
|
||||
Also covers the interruption-during-pause deadlock scenario (see test_no_deadlock_on_interrupt_*).
|
||||
"""
|
||||
|
||||
@@ -50,6 +57,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
from pipecat.utils.text.base_text_aggregator import AggregationType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test-only frame
|
||||
@@ -422,7 +430,7 @@ def _assert_group_ordering(
|
||||
# All frames between TTSStartedFrame and TTSStoppedFrame must be audio.
|
||||
mid_types = types[started_idx + 1 : stopped_idx]
|
||||
for t in mid_types:
|
||||
assert t is TTSAudioRawFrame, (
|
||||
assert t in (TTSAudioRawFrame, TTSTextFrame), (
|
||||
f"Group {foo_label!r}: unexpected frame {t.__name__!r} between "
|
||||
f"TTSStartedFrame and TTSStoppedFrame. Got: {type_names}"
|
||||
)
|
||||
@@ -551,7 +559,7 @@ async def test_http_push_text_llm_response_end_after_tts_text():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_word_timestamps_verbatim_tokens():
|
||||
"""HTTP path: text, PTS order, flag, and text-before-audio are all verified.
|
||||
"""HTTP path: text, PTS order, and text-before-audio are all verified.
|
||||
|
||||
Word timestamps arrive in the audio context queue before the audio frame.
|
||||
_handle_audio_context caches them, then flushes when the first audio frame
|
||||
@@ -572,7 +580,6 @@ async def test_http_word_timestamps_verbatim_tokens():
|
||||
audio_frames = [f for f in down if isinstance(f, TTSAudioRawFrame)]
|
||||
|
||||
assert [f.text for f in tts_text_frames] == ["hello", "world"]
|
||||
assert all(f.includes_inter_frame_spaces is True for f in tts_text_frames)
|
||||
|
||||
pts_values = [f.pts for f in tts_text_frames]
|
||||
assert pts_values == sorted(pts_values) and len(set(pts_values)) == len(pts_values), (
|
||||
@@ -590,15 +597,14 @@ async def test_http_word_timestamps_verbatim_tokens():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_word_timestamps_punctuation_tokens():
|
||||
"""Verbatim punctuation tokens are preserved with flag=True; default flag is False.
|
||||
"""Punct-only tokens are merged into the preceding word when includes_inter_frame_spaces=True.
|
||||
|
||||
Models the Inworld API scenario: the TTS returns tokens exactly as sent.
|
||||
Space placement rule:
|
||||
- word-follows-word: space is the leading char of the next word (e.g. " world")
|
||||
- word-follows-punctuation: space is the trailing char of the punctuation token
|
||||
(e.g. "! "), so the following word token carries no leading space.
|
||||
The flag must reach every frame and the text must not be modified.
|
||||
Also acts as a regression guard that flag=False is the default.
|
||||
Models the Inworld API scenario: the TTS returns separate space and punctuation
|
||||
tokens. add_word_timestamps calls merge_punct_tokens when includes_inter_frame_spaces
|
||||
is True, collapsing those tokens into the preceding word before the tracker sees them.
|
||||
|
||||
With flag=False (default) tokens are forwarded as-is; the tracker strips leading/
|
||||
trailing whitespace from each frame word via get_word_for_frame().
|
||||
"""
|
||||
verbatim_tokens = [
|
||||
("hello", 0.0),
|
||||
@@ -609,9 +615,9 @@ async def test_http_word_timestamps_punctuation_tokens():
|
||||
(" you", 0.75),
|
||||
("?", 0.9),
|
||||
]
|
||||
expected_texts = ["hello", " world", "! ", "How", " are", " you", "?"]
|
||||
|
||||
# With flag=True: all tokens verbatim, all frames carry the flag.
|
||||
# With flag=True: punct-only tokens ("! " and "?") are merged into the preceding
|
||||
# words (" world" → " world! " and " you" → " you?"), then stripped by the tracker.
|
||||
tts_ifs = _MockWordTimestampHttpTTSService(
|
||||
includes_inter_frame_spaces=True,
|
||||
word_times=verbatim_tokens,
|
||||
@@ -621,12 +627,11 @@ async def test_http_word_timestamps_punctuation_tokens():
|
||||
frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)],
|
||||
)
|
||||
text_frames_ifs = [f for f in frames_ifs[0] if isinstance(f, TTSTextFrame)]
|
||||
assert [f.text for f in text_frames_ifs] == expected_texts, (
|
||||
"Verbatim tokens must not be modified"
|
||||
assert [f.text for f in text_frames_ifs] == ["hello", "world!", "How", "are", "you?"], (
|
||||
"Punct-only tokens must be merged into the preceding word"
|
||||
)
|
||||
assert all(f.includes_inter_frame_spaces is True for f in text_frames_ifs)
|
||||
|
||||
# With flag=False (default): same tokens, flag must be False on every frame.
|
||||
# With flag=False (default): no merging; tracker strips leading/trailing spaces.
|
||||
tts_plain = _MockWordTimestampHttpTTSService(
|
||||
word_times=verbatim_tokens,
|
||||
)
|
||||
@@ -635,13 +640,12 @@ async def test_http_word_timestamps_punctuation_tokens():
|
||||
frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)],
|
||||
)
|
||||
text_frames_plain = [f for f in frames_plain[0] if isinstance(f, TTSTextFrame)]
|
||||
assert [f.text for f in text_frames_plain] == expected_texts
|
||||
assert all(f.includes_inter_frame_spaces is False for f in text_frames_plain)
|
||||
assert [f.text for f in text_frames_plain] == ["hello", "world", "!", "How", "are", "you", "?"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_word_timestamps_verbatim_tokens():
|
||||
"""WebSocket path: _WordTimestampEntry carries verbatim text, PTS, and flag.
|
||||
"""WebSocket path: text, PTS order, and text-before-audio are all verified.
|
||||
|
||||
Unlike the HTTP path the word timestamps are sent asynchronously from a
|
||||
background task. They arrive before the audio frame and are cached until
|
||||
@@ -662,7 +666,6 @@ async def test_websocket_word_timestamps_verbatim_tokens():
|
||||
audio_frames = [f for f in down if isinstance(f, TTSAudioRawFrame)]
|
||||
|
||||
assert [f.text for f in tts_text_frames] == ["hello", "world"]
|
||||
assert all(f.includes_inter_frame_spaces is True for f in tts_text_frames)
|
||||
|
||||
pts_values = [f.pts for f in tts_text_frames]
|
||||
assert pts_values == sorted(pts_values) and len(set(pts_values)) == len(pts_values), (
|
||||
@@ -678,7 +681,7 @@ async def test_websocket_word_timestamps_verbatim_tokens():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_word_timestamps_punctuation_tokens():
|
||||
"""WebSocket path: verbatim punctuation tokens reach TTSTextFrame unchanged."""
|
||||
"""WebSocket path: punct-only tokens are merged into the preceding word."""
|
||||
verbatim_tokens = [
|
||||
("hello", 0.0),
|
||||
(" world", 0.15),
|
||||
@@ -697,10 +700,443 @@ async def test_websocket_word_timestamps_punctuation_tokens():
|
||||
frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)],
|
||||
)
|
||||
text_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
assert [f.text for f in text_frames] == ["hello", " world", "! ", "How", " are", " you", "?"], (
|
||||
"Verbatim tokens must not be modified"
|
||||
assert [f.text for f in text_frames] == ["hello", "world!", "How", "are", "you?"], (
|
||||
"Punct-only tokens must be merged into the preceding word"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-call word-timestamp mock (for overflow tests)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _MockPerCallWordTimestampHttpTTSService(TTSService):
|
||||
"""HTTP-style TTS where each run_tts() call consumes its own word-time list.
|
||||
|
||||
Designed for tests that need different word tokens per sentence. The
|
||||
``word_times_per_call`` list is consumed in order; an empty inner list means
|
||||
no word-timestamp events are emitted for that call.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
word_times_per_call: list[list[tuple[str, float]]],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
push_text_frames=False,
|
||||
sample_rate=_SAMPLE_RATE,
|
||||
**kwargs,
|
||||
)
|
||||
self._word_times_queue = list(word_times_per_call)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return False
|
||||
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
word_times = self._word_times_queue.pop(0) if self._word_times_queue else []
|
||||
if word_times:
|
||||
await self.add_word_timestamps(word_times, context_id=context_id)
|
||||
yield TTSAudioRawFrame(
|
||||
audio=_FAKE_AUDIO,
|
||||
sample_rate=_SAMPLE_RATE,
|
||||
num_channels=1,
|
||||
context_id=context_id,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: skipped frame ordering (skip_aggregator_types)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_skipped_frame_waits_for_spoken_words():
|
||||
"""Skipped frames are held until the preceding spoken slot's word timestamps
|
||||
are all processed, then flushed in order (HTTP / synchronous audio path).
|
||||
|
||||
Sequence sent:
|
||||
AggregatedTextFrame("hello world", SENTENCE) — spoken; yields 2 TTSTextFrames
|
||||
AggregatedTextFrame("some code", "code") — in skip_aggregator_types; must wait
|
||||
|
||||
Expected downstream order:
|
||||
TTSTextFrame("hello")
|
||||
TTSTextFrame("world")
|
||||
AggregatedTextFrame("some code", append_to_context=True)
|
||||
"""
|
||||
tts = _MockWordTimestampHttpTTSService(skip_aggregator_types=["code"])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[
|
||||
AggregatedTextFrame("hello world", AggregationType.SENTENCE),
|
||||
AggregatedTextFrame("some code", "code"),
|
||||
],
|
||||
)
|
||||
down = frames_received[0]
|
||||
|
||||
word_frames = [f for f in down if isinstance(f, TTSTextFrame)]
|
||||
skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"]
|
||||
|
||||
assert [f.text for f in word_frames] == ["hello", "world"]
|
||||
assert len(skipped) == 1
|
||||
assert skipped[0].append_to_context is True
|
||||
|
||||
last_word_idx = max(down.index(f) for f in word_frames)
|
||||
skipped_idx = down.index(skipped[0])
|
||||
assert skipped_idx > last_word_idx, (
|
||||
f"Skipped frame (pos {skipped_idx}) must appear after last word frame (pos {last_word_idx})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_skipped_frame_waits_for_spoken_words():
|
||||
"""Same ordering guarantee on the WebSocket / async audio delivery path.
|
||||
|
||||
Because audio is delivered from a background task after asyncio.sleep(), the
|
||||
skipped frame arrives at _push_frame_respecting_previous_aggregated_frame
|
||||
*before* the spoken slot's word timestamps have been processed, directly
|
||||
exercising the hold-and-flush path.
|
||||
"""
|
||||
tts = _MockWordTimestampWSTTSService(skip_aggregator_types=["code"])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[
|
||||
AggregatedTextFrame("hello world", AggregationType.SENTENCE),
|
||||
AggregatedTextFrame("some code", "code"),
|
||||
],
|
||||
)
|
||||
down = frames_received[0]
|
||||
|
||||
word_frames = [f for f in down if isinstance(f, TTSTextFrame)]
|
||||
skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"]
|
||||
|
||||
assert [f.text for f in word_frames] == ["hello", "world"]
|
||||
assert len(skipped) == 1
|
||||
assert skipped[0].append_to_context is True
|
||||
|
||||
last_word_idx = max(down.index(f) for f in word_frames)
|
||||
skipped_idx = down.index(skipped[0])
|
||||
assert skipped_idx > last_word_idx, (
|
||||
f"Skipped frame (pos {skipped_idx}) must appear after last word frame (pos {last_word_idx})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skipped_frame_before_spoken_emits_immediately():
|
||||
"""A skipped frame with no preceding spoken slot is emitted right away.
|
||||
|
||||
Sequence:
|
||||
AggregatedTextFrame("some code", "code") — no spoken slot before it → emits now
|
||||
AggregatedTextFrame("hello world", SENTENCE) — spoken; TTSTextFrames follow
|
||||
|
||||
Expected: AggregatedTextFrame("some code") appears *before* TTSTextFrame("hello").
|
||||
"""
|
||||
tts = _MockWordTimestampHttpTTSService(skip_aggregator_types=["code"])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[
|
||||
AggregatedTextFrame("some code", "code"),
|
||||
AggregatedTextFrame("hello world", AggregationType.SENTENCE),
|
||||
],
|
||||
)
|
||||
down = frames_received[0]
|
||||
|
||||
word_frames = [f for f in down if isinstance(f, TTSTextFrame)]
|
||||
skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"]
|
||||
|
||||
assert len(skipped) == 1
|
||||
assert skipped[0].append_to_context is True
|
||||
assert len(word_frames) >= 1
|
||||
|
||||
skipped_idx = down.index(skipped[0])
|
||||
first_word_idx = down.index(word_frames[0])
|
||||
assert skipped_idx < first_word_idx, (
|
||||
f"Skipped frame (pos {skipped_idx}) must appear before first word frame (pos {first_word_idx})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skipped_frame_flushed_when_word_timestamps_incomplete():
|
||||
"""Force-complete path: skipped frame still emits when the TTS drops word timestamps.
|
||||
|
||||
Only one of the two expected tokens ("hello") is returned. The spoken slot never
|
||||
reaches its expected character count through the normal path. When
|
||||
on_audio_context_done fires it force-completes any remaining spoken slots and
|
||||
flushes the waiting skipped frame.
|
||||
"""
|
||||
tts = _MockWordTimestampHttpTTSService(
|
||||
word_times=[("hello", 0.0)], # "world" is never sent
|
||||
skip_aggregator_types=["code"],
|
||||
)
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[
|
||||
AggregatedTextFrame("hello world", AggregationType.SENTENCE),
|
||||
AggregatedTextFrame("some code", "code"),
|
||||
],
|
||||
)
|
||||
down = frames_received[0]
|
||||
|
||||
skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"]
|
||||
assert len(skipped) == 1, "Skipped frame must be flushed via force-complete safety net"
|
||||
assert skipped[0].append_to_context is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: raw_text propagation through WordCompletionTracker
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_text_propagated_to_tts_text_frames():
|
||||
"""raw_text on AggregatedTextFrame is split across TTSTextFrames by the tracker.
|
||||
|
||||
The frame carries raw_text="<card>4111 1111</card>" while the TTS-prepared
|
||||
text is "4111 1111". The WordCompletionTracker advances a cursor through the
|
||||
raw text in step with incoming word tokens, so each TTSTextFrame receives the
|
||||
exact raw span it represents.
|
||||
|
||||
Expected (trailing whitespace stripped because includes_inter_frame_spaces=False):
|
||||
TTSTextFrame("4111").raw_text == "<card>4111"
|
||||
TTSTextFrame("1111").raw_text == "1111</card>"
|
||||
"""
|
||||
tts = _MockWordTimestampHttpTTSService()
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[
|
||||
AggregatedTextFrame(
|
||||
"4111 1111", AggregationType.SENTENCE, raw_text="<card>4111 1111</card>"
|
||||
)
|
||||
],
|
||||
)
|
||||
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
|
||||
assert [f.text for f in word_frames] == ["4111", "1111"]
|
||||
# get_raw_consumed() strips trailing whitespace when includes_inter_frame_spaces=False
|
||||
assert word_frames[0].raw_text == "<card>4111"
|
||||
assert word_frames[1].raw_text == "1111</card>"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: overflow — TTS word spanning two AggregatedTextFrame boundaries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_overflow_word_spanning_two_aggregated_frames():
|
||||
"""A single TTS token straddling two AggregatedTextFrame boundaries produces
|
||||
two correctly-attributed TTSTextFrames.
|
||||
|
||||
Setup:
|
||||
Frame 1: AggregatedTextFrame("abc", SENTENCE)
|
||||
Frame 2: AggregatedTextFrame("def", SENTENCE)
|
||||
|
||||
The TTS for frame 1 returns the single token "abcdef", which overshoots
|
||||
frame 1 by three characters. _emit_overflow_word splits it:
|
||||
TTSTextFrame("abc") — frame 1's portion (context_id = ctx1)
|
||||
TTSTextFrame("def") — overflow attributed to frame 2 (context_id = ctx2)
|
||||
|
||||
Frame 2 receives no word-timestamp events because the overflow already
|
||||
consumed its expected text.
|
||||
"""
|
||||
tts = _MockPerCallWordTimestampHttpTTSService(
|
||||
word_times_per_call=[
|
||||
[("abcdef", 0.0)], # frame 1: single token spanning both frames
|
||||
[], # frame 2: no word timestamps (overflow already covered it)
|
||||
]
|
||||
)
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[
|
||||
AggregatedTextFrame("abc", AggregationType.SENTENCE),
|
||||
AggregatedTextFrame("def", AggregationType.SENTENCE),
|
||||
],
|
||||
)
|
||||
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
|
||||
assert [f.text for f in word_frames] == ["abc", "def"], (
|
||||
f"Expected ['abc', 'def'] but got {[f.text for f in word_frames]}"
|
||||
)
|
||||
assert word_frames[0].context_id != word_frames[1].context_id, (
|
||||
"Overflow TTSTextFrame must carry frame 2's context_id, not frame 1's"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-call word-timestamp mock for WebSocket path (for force-complete tests)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _MockPerCallWordTimestampWSTTSService(TTSService):
|
||||
"""WebSocket-style TTS where each run_tts() call consumes its own word-time list.
|
||||
|
||||
Mirrors _MockPerCallWordTimestampHttpTTSService but uses the async audio-context
|
||||
delivery pattern so it exercises _handle_audio_context (the WebSocket path).
|
||||
An empty inner list means no word-timestamp events are emitted for that call.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
word_times_per_call: list[list[tuple[str, float]]],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
push_start_frame=True,
|
||||
push_text_frames=False,
|
||||
pause_frame_processing=False,
|
||||
sample_rate=_SAMPLE_RATE,
|
||||
**kwargs,
|
||||
)
|
||||
self._word_times_queue = list(word_times_per_call)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return False
|
||||
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
word_times = self._word_times_queue.pop(0) if self._word_times_queue else []
|
||||
|
||||
async def _deliver():
|
||||
await asyncio.sleep(0.01)
|
||||
if word_times:
|
||||
await self.add_word_timestamps(word_times, context_id=context_id)
|
||||
await self.append_to_audio_context(
|
||||
context_id,
|
||||
TTSAudioRawFrame(
|
||||
audio=_FAKE_AUDIO,
|
||||
sample_rate=_SAMPLE_RATE,
|
||||
num_channels=1,
|
||||
context_id=context_id,
|
||||
),
|
||||
)
|
||||
await self.append_to_audio_context(context_id, TTSStoppedFrame(context_id=context_id))
|
||||
await self.remove_audio_context(context_id)
|
||||
|
||||
self.create_task(_deliver(), name=f"mock_ws_per_call_deliver_{context_id}")
|
||||
if False:
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _force_complete_spoken_slots — TTSTextFrame emission for dropped timestamps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_force_complete_partial_timestamps_emits_remaining_text():
|
||||
"""_force_complete_spoken_slots emits a TTSTextFrame for the unspoken word suffix.
|
||||
|
||||
Only the first token ("hello") is delivered as a word-timestamp event; "world"
|
||||
is never sent. When the audio context ends _force_complete_spoken_slots fires,
|
||||
reads get_remaining_text() from the tracker, and emits TTSTextFrame("world").
|
||||
|
||||
Expected TTSTextFrames in order: ["hello", "world"].
|
||||
"""
|
||||
tts = _MockWordTimestampHttpTTSService(word_times=[("hello", 0.0)])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)],
|
||||
)
|
||||
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
|
||||
assert [f.text for f in word_frames] == ["hello", "world"], (
|
||||
f"Expected ['hello', 'world'] but got {[f.text for f in word_frames]}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_force_complete_no_timestamps_emits_full_text():
|
||||
"""_force_complete_spoken_slots emits the full text when no word timestamps arrive.
|
||||
|
||||
No word-timestamp events are sent for "hello world". The slot remains incomplete
|
||||
when the audio context ends; force-complete reads the full remaining text from the
|
||||
tracker and emits TTSTextFrame("hello world").
|
||||
"""
|
||||
tts = _MockPerCallWordTimestampHttpTTSService(word_times_per_call=[[]])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)],
|
||||
)
|
||||
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
|
||||
assert len(word_frames) == 1, (
|
||||
f"Expected exactly 1 TTSTextFrame, got {len(word_frames)}: {[f.text for f in word_frames]}"
|
||||
)
|
||||
assert word_frames[0].text == "hello world", (
|
||||
f"Expected TTSTextFrame('hello world'), got {word_frames[0].text!r}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_force_complete_raw_text_propagated():
|
||||
"""force-complete carries the correct raw_text span on the emitted TTSTextFrame.
|
||||
|
||||
AggregatedTextFrame carries raw_text="<card>4111 1111</card>". Only "4111" arrives
|
||||
as a word-timestamp; "1111" is force-completed.
|
||||
|
||||
Expected:
|
||||
TTSTextFrame("4111").raw_text == "<card>4111" — from normal word path
|
||||
TTSTextFrame("1111").raw_text == "1111</card>" — from force-complete path
|
||||
"""
|
||||
tts = _MockPerCallWordTimestampHttpTTSService(word_times_per_call=[[("4111", 0.0)]])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[
|
||||
AggregatedTextFrame(
|
||||
"4111 1111", AggregationType.SENTENCE, raw_text="<card>4111 1111</card>"
|
||||
)
|
||||
],
|
||||
)
|
||||
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
|
||||
assert [f.text for f in word_frames] == ["4111", "1111"], (
|
||||
f"Expected ['4111', '1111'] but got {[f.text for f in word_frames]}"
|
||||
)
|
||||
assert word_frames[0].raw_text == "<card>4111", (
|
||||
f"Expected raw_text '<card>4111' on first frame, got {word_frames[0].raw_text!r}"
|
||||
)
|
||||
assert word_frames[1].raw_text == "1111</card>", (
|
||||
f"Expected raw_text '1111</card>' on force-complete frame, got {word_frames[1].raw_text!r}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_force_complete_partial_timestamps_emits_remaining_text():
|
||||
"""WebSocket path: _force_complete_spoken_slots emits TTSTextFrame for dropped token.
|
||||
|
||||
Mirrors test_http_force_complete_partial_timestamps_emits_remaining_text on the
|
||||
async audio delivery path to confirm force-complete fires correctly from
|
||||
_handle_audio_context when TTSStoppedFrame arrives before all word timestamps.
|
||||
"""
|
||||
tts = _MockWordTimestampWSTTSService(word_times=[("hello", 0.0)])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)],
|
||||
)
|
||||
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
|
||||
assert [f.text for f in word_frames] == ["hello", "world"], (
|
||||
f"Expected ['hello', 'world'] but got {[f.text for f in word_frames]}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_force_complete_no_timestamps_emits_full_text():
|
||||
"""WebSocket path: full text emitted as single TTSTextFrame when no timestamps arrive."""
|
||||
tts = _MockPerCallWordTimestampWSTTSService(word_times_per_call=[[]])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)],
|
||||
)
|
||||
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
|
||||
assert len(word_frames) == 1, (
|
||||
f"Expected exactly 1 TTSTextFrame, got {len(word_frames)}: {[f.text for f in word_frames]}"
|
||||
)
|
||||
assert word_frames[0].text == "hello world", (
|
||||
f"Expected TTSTextFrame('hello world'), got {word_frames[0].text!r}"
|
||||
)
|
||||
assert all(f.includes_inter_frame_spaces is True for f in text_frames)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
1602
tests/test_word_completion_tracker.py
Normal file
1602
tests/test_word_completion_tracker.py
Normal file
File diff suppressed because it is too large
Load Diff
90
tests/test_word_timestamp_utils.py
Normal file
90
tests/test_word_timestamp_utils.py
Normal file
@@ -0,0 +1,90 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.utils.text.word_timestamp_utils import merge_punct_tokens
|
||||
|
||||
|
||||
class TestMergePunctTokens(unittest.TestCase):
|
||||
def test_empty_list(self):
|
||||
self.assertEqual(merge_punct_tokens([]), [])
|
||||
|
||||
def test_all_alnum_words_pass_through(self):
|
||||
input = [("hello", 0.0), ("world", 1.0)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("hello", 0.0), ("world", 1.0)])
|
||||
|
||||
def test_trailing_space_merged_and_stripped(self):
|
||||
input = [("I", 0.0), (" ", 0.2)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("I", 0.0)])
|
||||
|
||||
def test_comma_space_merged_and_stripped(self):
|
||||
input = [("questions", 1.0), (", ", 1.2), ("explain", 1.4)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("questions,", 1.0), ("explain", 1.4)])
|
||||
|
||||
def test_leading_space_with_no_preceding_word_discarded(self):
|
||||
input = [(" ", 0.0), ("hello", 0.5)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("hello", 0.5)])
|
||||
|
||||
def test_leading_empty_string_discarded(self):
|
||||
input = [("", 0.0), ("hello", 0.5)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("hello", 0.5)])
|
||||
|
||||
def test_multiple_consecutive_punct_tokens_merged_and_stripped(self):
|
||||
input = [("word", 0.0), (",", 0.1), (" ", 0.2), ("next", 0.3)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("word,", 0.0), ("next", 0.3)])
|
||||
|
||||
def test_timestamp_of_preceding_word_is_kept(self):
|
||||
"""Merged punct tokens adopt the preceding word's timestamp."""
|
||||
input = [("hello", 2.5), (",", 2.7)]
|
||||
result = merge_punct_tokens(input)
|
||||
self.assertEqual(result, [("hello,", 2.5)])
|
||||
|
||||
def test_xml_tag_only_token_is_treated_as_punct(self):
|
||||
"""A token that is only an XML tag (no alnum chars) merges into the preceding word."""
|
||||
input = [("word", 0.0), ("<break/>", 0.1), ("next", 0.3)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("word<break/>", 0.0), ("next", 0.3)])
|
||||
|
||||
def test_xml_tag_with_alnum_content_passes_through(self):
|
||||
"""A token like '<spell>123</spell>' has alnum chars after stripping tags."""
|
||||
input = [("<spell>123</spell>", 0.0), ("and", 0.5)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("<spell>123</spell>", 0.0), ("and", 0.5)])
|
||||
|
||||
def test_inworld_style_full_stream(self):
|
||||
"""Full Inworld-style raw stream produces expected merged and stripped output."""
|
||||
raw = [
|
||||
("", 0.0),
|
||||
("I", 0.1),
|
||||
(" ", 0.2),
|
||||
("can", 0.3),
|
||||
(" ", 0.4),
|
||||
("answer", 0.5),
|
||||
(" ", 0.6),
|
||||
("questions", 0.7),
|
||||
(", ", 0.8),
|
||||
("explain", 0.9),
|
||||
(" ", 1.0),
|
||||
("things", 1.1),
|
||||
(".", 1.2),
|
||||
]
|
||||
expected = [
|
||||
("I", 0.1),
|
||||
("can", 0.3),
|
||||
("answer", 0.5),
|
||||
("questions,", 0.7),
|
||||
("explain", 0.9),
|
||||
("things.", 1.1),
|
||||
]
|
||||
self.assertEqual(merge_punct_tokens(raw), expected)
|
||||
|
||||
def test_only_punct_tokens_returns_empty(self):
|
||||
"""A list containing only punct/space tokens produces an empty result."""
|
||||
input = [(" ", 0.0), (",", 0.1), (".", 0.2)]
|
||||
self.assertEqual(merge_punct_tokens(input), [])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user