From 81bb81c1d079771c958f2b74c65706fe1d68740a Mon Sep 17 00:00:00 2001 From: filipi87 Date: Wed, 20 May 2026 10:03:26 -0300 Subject: [PATCH] test: add automated tests for word tracking, frame sequencing, and Cartesia TTS Adds tests for AggregatedFrameSequencer, WordCompletionTracker, and word_timestamp_utils (including CJK language scenarios). Updates existing Cartesia TTS and TTS frame ordering tests to cover the new behaviours. --- tests/test_aggregated_frame_sequencer.py | 612 +++++++++ tests/test_cartesia_tts.py | 4 +- tests/test_tts_frame_ordering.py | 486 ++++++- tests/test_word_completion_tracker.py | 1602 ++++++++++++++++++++++ tests/test_word_timestamp_utils.py | 90 ++ 5 files changed, 2767 insertions(+), 27 deletions(-) create mode 100644 tests/test_aggregated_frame_sequencer.py create mode 100644 tests/test_word_completion_tracker.py create mode 100644 tests/test_word_timestamp_utils.py diff --git a/tests/test_aggregated_frame_sequencer.py b/tests/test_aggregated_frame_sequencer.py new file mode 100644 index 000000000..7d12567e4 --- /dev/null +++ b/tests/test_aggregated_frame_sequencer.py @@ -0,0 +1,612 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Tests for AggregatedFrameSequencer. + +All methods on the sequencer are synchronous and return lists of frames, +so no async machinery is needed here. + +Test groups: +- register_skipped: immediate flush vs. blocked by a preceding spoken slot +- register_spoken / complete_spoken_slot: push_text_frames=True path +- flush: pts propagation, transport_destination, stops at incomplete spoken slot +- process_word: normal, completing, passthrough, raw_text propagation +- process_word overflow: single token spanning two slot boundaries +- process_word force-complete via belongs_here failure +- force_complete: remaining text emission, raw_text, corrupt raw discard, slot ordering +- clear: resets all state +""" + +import unittest + +from pipecat.frames.frames import AggregatedTextFrame, AggregationType, TTSTextFrame +from pipecat.utils.context.aggregated_frame_sequencer import AggregatedFrameSequencer +from pipecat.utils.context.word_completion_tracker import WordCompletionTracker + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _seq() -> AggregatedFrameSequencer: + return AggregatedFrameSequencer(name="test") + + +def _spoken_frame(text: str) -> AggregatedTextFrame: + return AggregatedTextFrame(text, AggregationType.SENTENCE) + + +def _skipped_frame(text: str) -> AggregatedTextFrame: + return AggregatedTextFrame(text, "code") + + +def _tracker(tts_text: str, llm_text: str | None = None) -> WordCompletionTracker: + return WordCompletionTracker(tts_text, llm_text=llm_text) + + +# --------------------------------------------------------------------------- +# register_skipped +# --------------------------------------------------------------------------- + + +class TestRegisterSkipped(unittest.TestCase): + def test_emits_immediately_with_empty_queue(self): + seq = _seq() + frame = _skipped_frame("code block") + result = seq.register_skipped(frame, "ctx1", None) + self.assertEqual(len(result), 1) + self.assertIs(result[0], frame) + + def test_sets_append_to_context_true(self): + seq = _seq() + frame = _skipped_frame("code") + seq.register_skipped(frame, "ctx1", None) + self.assertTrue(frame.append_to_context) + + def test_sets_context_id_on_frame(self): + seq = _seq() + frame = _skipped_frame("code") + seq.register_skipped(frame, "ctx42", None) + self.assertEqual(frame.context_id, "ctx42") + + def test_sets_transport_destination(self): + seq = _seq() + frame = _skipped_frame("code") + result = seq.register_skipped(frame, "ctx1", "dest-A") + self.assertEqual(result[0].transport_destination, "dest-A") + + def test_blocked_by_incomplete_spoken_slot(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True) + result = seq.register_skipped(_skipped_frame("code"), "ctx2", None) + self.assertEqual(result, []) + + def test_emits_immediately_after_already_complete_spoken_slot(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hi"), "ctx1", tracker=None, append_to_context=True) + seq.complete_spoken_slot() + result = seq.register_skipped(_skipped_frame("code"), "ctx2", None) + self.assertEqual(len(result), 1) + + def test_multiple_skipped_before_any_spoken_all_emit(self): + seq = _seq() + r1 = seq.register_skipped(_skipped_frame("code1"), "ctx1", None) + r2 = seq.register_skipped(_skipped_frame("code2"), "ctx2", None) + self.assertEqual(len(r1), 1) + self.assertEqual(len(r2), 1) + + +# --------------------------------------------------------------------------- +# register_spoken / complete_spoken_slot (push_text_frames=True path) +# --------------------------------------------------------------------------- + + +class TestCompleteSpokenSlot(unittest.TestCase): + def test_noop_with_empty_queue(self): + seq = _seq() + self.assertEqual(seq.complete_spoken_slot(), []) + + def test_marks_slot_complete_and_flushes_skipped(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True) + skipped = _skipped_frame("code") + seq.register_skipped(skipped, "ctx2", None) # blocked + + result = seq.complete_spoken_slot() + self.assertEqual(len(result), 1) + self.assertIs(result[0], skipped) + self.assertTrue(skipped.append_to_context) + + def test_only_first_pending_slot_is_marked(self): + seq = _seq() + seq.register_spoken(_spoken_frame("one"), "ctx1", tracker=None, append_to_context=True) + seq.register_spoken(_spoken_frame("two"), "ctx2", tracker=None, append_to_context=True) + skipped = _skipped_frame("code") + seq.register_skipped(skipped, "ctx3", None) + + # ctx2 still blocks the skipped frame + result = seq.complete_spoken_slot() + self.assertEqual(result, []) + + def test_skipped_flushes_after_all_preceding_spoken_complete(self): + seq = _seq() + seq.register_spoken(_spoken_frame("one"), "ctx1", tracker=None, append_to_context=True) + seq.register_spoken(_spoken_frame("two"), "ctx2", tracker=None, append_to_context=True) + skipped = _skipped_frame("code") + seq.register_skipped(skipped, "ctx3", None) + + seq.complete_spoken_slot() # completes ctx1 + result = seq.complete_spoken_slot() # completes ctx2 → flush skipped + self.assertEqual(len(result), 1) + self.assertIs(result[0], skipped) + + +# --------------------------------------------------------------------------- +# flush +# --------------------------------------------------------------------------- + + +class TestFlush(unittest.TestCase): + def test_empty_queue_returns_empty(self): + self.assertEqual(_seq().flush(), []) + + def test_stops_at_incomplete_spoken_slot(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True) + seq.register_skipped(_skipped_frame("code"), "ctx2", None) + self.assertEqual(seq.flush(), []) + + def test_last_word_pts_assigned_to_skipped_frame(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True) + skipped = _skipped_frame("code") + seq.register_skipped(skipped, "ctx2", None) + + # process_word("hello") completes the spoken slot and calls flush(last_word_pts=77) + result = seq.process_word("hello", pts=77, context_id="ctx1") + flushed = [f for f in result if isinstance(f, AggregatedTextFrame) and f.text == "code"] + self.assertEqual(len(flushed), 1) + self.assertEqual(flushed[0].pts, 77) + + def test_complete_spoken_slots_are_swept(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True) + seq.complete_spoken_slot() + # Queue should be empty after sweeping the complete spoken slot + self.assertEqual(seq._slots, []) + + +# --------------------------------------------------------------------------- +# process_word — basic +# --------------------------------------------------------------------------- + + +class TestProcessWordBasic(unittest.TestCase): + def _seq_with_spoken(self, text: str, ctx: str = "ctx1", append: bool = True): + seq = _seq() + seq.register_spoken(_spoken_frame(text), ctx, _tracker(text), append) + return seq + + def test_returns_tts_text_frame(self): + seq = self._seq_with_spoken("hello") + result = seq.process_word("hello", pts=100, context_id="ctx1") + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0], TTSTextFrame) + + def test_frame_text_and_pts(self): + seq = self._seq_with_spoken("hello") + result = seq.process_word("hello", pts=100, context_id="ctx1") + self.assertEqual(result[0].text, "hello") + self.assertEqual(result[0].pts, 100) + + def test_frame_context_id(self): + seq = self._seq_with_spoken("hello", ctx="ctx99") + result = seq.process_word("hello", pts=1, context_id="ctx99") + self.assertEqual(result[0].context_id, "ctx99") + + def test_append_to_context_true(self): + seq = self._seq_with_spoken("hello", append=True) + result = seq.process_word("hello", pts=1, context_id="ctx1") + self.assertTrue(result[0].append_to_context) + + def test_append_to_context_false(self): + seq = self._seq_with_spoken("hello", append=False) + result = seq.process_word("hello", pts=1, context_id="ctx1") + self.assertFalse(result[0].append_to_context) + + def test_non_completing_word_does_not_flush_skipped(self): + seq = self._seq_with_spoken("hello world") + seq.register_skipped(_skipped_frame("code"), "ctx2", None) + result = seq.process_word("hello", pts=10, context_id="ctx1") + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0], TTSTextFrame) + + def test_completing_word_flushes_blocked_skipped_frame(self): + seq = self._seq_with_spoken("hello") + skipped = _skipped_frame("code") + seq.register_skipped(skipped, "ctx2", None) + result = seq.process_word("hello", pts=50, context_id="ctx1") + self.assertEqual(len(result), 2) + self.assertIsInstance(result[0], TTSTextFrame) + self.assertIs(result[1], skipped) + + def test_last_of_multiple_words_flushes_skipped(self): + seq = self._seq_with_spoken("hello world") + skipped = _skipped_frame("code") + seq.register_skipped(skipped, "ctx2", None) + seq.process_word("hello", pts=10, context_id="ctx1") + result = seq.process_word("world", pts=20, context_id="ctx1") + self.assertTrue(any(f is skipped for f in result)) + + def test_no_active_slot_emits_passthrough(self): + seq = _seq() + result = seq.process_word("hello", pts=1, context_id="ctx-unknown") + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0], TTSTextFrame) + self.assertEqual(result[0].text, "hello") + self.assertEqual(result[0].context_id, "ctx-unknown") + + def test_passthrough_uses_default_append_to_context_true(self): + seq = _seq() + result = seq.process_word("hello", pts=1, context_id="ctx-unknown") + self.assertTrue(result[0].append_to_context) + + def test_unrecognised_word_emits_passthrough(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True) + # "zzz" doesn't belong to "hello world" and there is no next slot + result = seq.process_word("zzz", pts=5, context_id="ctx1") + self.assertEqual(len(result), 1) + self.assertEqual(result[0].text, "zzz") + + +# --------------------------------------------------------------------------- +# process_word — raw_text propagation +# --------------------------------------------------------------------------- + + +class TestProcessWordRawText(unittest.TestCase): + def test_raw_text_split_across_word_frames(self): + seq = _seq() + seq.register_spoken( + _spoken_frame("4111 1111"), + "ctx1", + WordCompletionTracker("4111 1111", llm_text="4111 1111"), + append_to_context=True, + ) + r1 = seq.process_word("4111", pts=10, context_id="ctx1") + r2 = seq.process_word("1111", pts=20, context_id="ctx1") + self.assertEqual(r1[0].raw_text, "4111") + last_word_frames = [f for f in r2 if isinstance(f, TTSTextFrame)] + self.assertEqual(last_word_frames[0].raw_text, "1111") + + def test_raw_text_none_when_no_llm_text(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True) + result = seq.process_word("hello", pts=1, context_id="ctx1") + self.assertIsNone(result[0].raw_text) + + +# --------------------------------------------------------------------------- +# process_word — overflow (single token spanning two slots) +# --------------------------------------------------------------------------- + + +class TestProcessWordOverflow(unittest.TestCase): + def test_overflow_produces_two_tts_text_frames(self): + seq = _seq() + seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True) + seq.register_spoken(_spoken_frame("def"), "ctx2", _tracker("def"), True) + + result = seq.process_word("abcdef", pts=100, context_id="ctx1") + word_frames = [f for f in result if isinstance(f, TTSTextFrame)] + self.assertEqual(len(word_frames), 2) + self.assertEqual(word_frames[0].text, "abc") + self.assertEqual(word_frames[1].text, "def") + + def test_overflow_assigns_correct_context_ids(self): + seq = _seq() + seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True) + seq.register_spoken(_spoken_frame("def"), "ctx2", _tracker("def"), True) + + result = seq.process_word("abcdef", pts=100, context_id="ctx1") + word_frames = [f for f in result if isinstance(f, TTSTextFrame)] + self.assertEqual(word_frames[0].context_id, "ctx1") + self.assertEqual(word_frames[1].context_id, "ctx2") + + def test_overflow_completing_next_slot_flushes_skipped(self): + seq = _seq() + seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True) + seq.register_spoken(_spoken_frame("def"), "ctx2", _tracker("def"), True) + skipped = _skipped_frame("code") + seq.register_skipped(skipped, "ctx3", None) # blocked behind ctx2 + + result = seq.process_word("abcdef", pts=100, context_id="ctx1") + self.assertTrue(any(f is skipped for f in result)) + + def test_overflow_not_completing_next_slot_does_not_flush_skipped(self): + seq = _seq() + seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True) + seq.register_spoken(_spoken_frame("def ghi"), "ctx2", _tracker("def ghi"), True) + skipped = _skipped_frame("code") + seq.register_skipped(skipped, "ctx3", None) + + # "abcdef" overflows: "def" goes to ctx2, but ctx2 still expects " ghi" + result = seq.process_word("abcdef", pts=100, context_id="ctx1") + self.assertFalse(any(f is skipped for f in result)) + + +# --------------------------------------------------------------------------- +# process_word — force-complete via word_belongs_here failure +# --------------------------------------------------------------------------- + + +class TestProcessWordForcesComplete(unittest.TestCase): + def test_word_for_next_slot_force_completes_current(self): + """When a word belongs to the next slot but not the current, the current + slot is force-completed and the word is routed to the next slot.""" + seq = _seq() + seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True) + seq.register_spoken(_spoken_frame("world"), "ctx2", _tracker("world"), True) + + # "world" doesn't belong to ctx1 but belongs to ctx2 + result = seq.process_word("world", pts=50, context_id="ctx2") + word_frames = [f for f in result if isinstance(f, TTSTextFrame)] + texts = {f.text for f in word_frames} + self.assertIn("world", texts) + + def test_force_complete_then_overflow_flushes_skipped(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True) + seq.register_spoken(_spoken_frame("world"), "ctx2", _tracker("world"), True) + skipped = _skipped_frame("code") + seq.register_skipped(skipped, "ctx3", None) + + # "world" force-completes ctx1 and completes ctx2 via overflow + result = seq.process_word("world", pts=50, context_id="ctx2") + self.assertTrue(any(f is skipped for f in result)) + + +# --------------------------------------------------------------------------- +# force_complete +# --------------------------------------------------------------------------- + + +class TestForceComplete(unittest.TestCase): + def test_emits_remaining_text_when_word_dropped(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True) + seq.process_word("hello", pts=10, context_id="ctx1") # "world" never arrives + + result = seq.force_complete(last_word_pts=50) + tts_frames = [f for f in result if isinstance(f, TTSTextFrame)] + self.assertEqual(len(tts_frames), 1) + self.assertEqual(tts_frames[0].text, "world") + self.assertEqual(tts_frames[0].pts, 50) + + def test_emits_full_text_when_no_words_arrived(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True) + + result = seq.force_complete(last_word_pts=0) + tts_frames = [f for f in result if isinstance(f, TTSTextFrame)] + self.assertEqual(len(tts_frames), 1) + self.assertEqual(tts_frames[0].text, "hello world") + + def test_already_complete_slot_emits_nothing(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hi"), "ctx1", _tracker("hi"), True) + seq.process_word("hi", pts=5, context_id="ctx1") # completes normally + + result = seq.force_complete(last_word_pts=10) + self.assertEqual(result, []) + + def test_flushes_skipped_frames_after_completing(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True) + skipped = _skipped_frame("code") + seq.register_skipped(skipped, "ctx2", None) + + result = seq.force_complete(last_word_pts=20) + self.assertTrue(any(f is skipped for f in result)) + self.assertTrue(skipped.append_to_context) + + def test_propagates_raw_text(self): + seq = _seq() + seq.register_spoken( + _spoken_frame("4111 1111"), + "ctx1", + WordCompletionTracker("4111 1111", llm_text="4111 1111"), + append_to_context=True, + ) + seq.process_word("4111", pts=10, context_id="ctx1") # "1111" never arrives + + result = seq.force_complete(last_word_pts=20) + tts_frames = [f for f in result if isinstance(f, TTSTextFrame)] + self.assertEqual(tts_frames[0].text, "1111") + self.assertEqual(tts_frames[0].raw_text, "1111") + + def test_discards_corrupt_raw_remaining(self): + """raw_remaining is discarded when it does not contain remaining_text.""" + seq = _seq() + # "abc" normalized ≠ "xyz" normalized — any remaining won't be in raw_remaining + seq.register_spoken( + _spoken_frame("abc"), + "ctx1", + WordCompletionTracker("abc", llm_text="xyz"), + append_to_context=True, + ) + result = seq.force_complete(last_word_pts=0) + tts_frames = [f for f in result if isinstance(f, TTSTextFrame)] + self.assertEqual(len(tts_frames), 1) + self.assertEqual(tts_frames[0].text, "abc") + self.assertIsNone(tts_frames[0].raw_text) # discarded due to corruption + + def test_slot_without_tracker_just_marks_complete_and_flushes(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True) + skipped = _skipped_frame("code") + seq.register_skipped(skipped, "ctx2", None) + + result = seq.force_complete(last_word_pts=0) + tts_frames = [f for f in result if isinstance(f, TTSTextFrame)] + self.assertEqual(tts_frames, []) # no tracker → no word frame + self.assertTrue(any(f is skipped for f in result)) + + def test_multiple_incomplete_slots_all_emitted(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True) + seq.register_spoken(_spoken_frame("world"), "ctx2", _tracker("world"), True) + + result = seq.force_complete(last_word_pts=0) + tts_frames = [f for f in result if isinstance(f, TTSTextFrame)] + texts = {f.text for f in tts_frames} + self.assertIn("hello", texts) + self.assertIn("world", texts) + + +# --------------------------------------------------------------------------- +# clear +# --------------------------------------------------------------------------- + + +class TestClear(unittest.TestCase): + def test_clears_slots(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True) + seq.register_skipped(_skipped_frame("code"), "ctx2", None) + seq.clear() + self.assertEqual(seq._slots, []) + + def test_clears_context_map(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True) + seq.clear() + self.assertEqual(seq._context_append_to_context, {}) + + def test_after_clear_skipped_emits_immediately(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True) + seq.clear() + frame = _skipped_frame("code") + result = seq.register_skipped(frame, "ctx2", None) + self.assertEqual(len(result), 1) + + def test_after_clear_process_word_uses_passthrough(self): + seq = _seq() + seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True) + seq.clear() + result = seq.process_word("hello", pts=1, context_id="ctx1") + # No active slot after clear → passthrough + self.assertEqual(len(result), 1) + self.assertEqual(result[0].text, "hello") + + +# --------------------------------------------------------------------------- +# CJK languages — Korean, Japanese, Chinese +# --------------------------------------------------------------------------- + + +class TestCJKLanguages(unittest.TestCase): + """Sequencer behaviour for CJK language scenarios. + + Korean: Cartesia returns each word as a separate timestamp event (one word + per process_word call). Japanese/Chinese: Cartesia merges all characters + in one timestamp message into a single combined token before calling + process_word. + """ + + # --- Korean --- + + def test_korean_word_by_word_completes_slot_and_flushes_skipped(self): + """Korean words fed one at a time complete the spoken slot and unblock a skipped frame.""" + seq = _seq() + sentence = "저는 여러분의 AI 어시스턴트입니다." + words = ["저는", "여러분의", "AI", "어시스턴트입니다."] + seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True) + skipped = _skipped_frame("[code]") + seq.register_skipped(skipped, "ctx2", None) + + # Skipped stays blocked until the last word arrives + for word in words[:-1]: + partial = seq.process_word(word, pts=100, context_id="ctx1") + self.assertFalse(any(f is skipped for f in partial)) + + result = seq.process_word(words[-1], pts=200, context_id="ctx1") + self.assertTrue(any(f is skipped for f in result)) + + def test_korean_force_complete_emits_correct_remaining_text(self): + """After one Korean word, force_complete emits the correct unspoken suffix.""" + seq = _seq() + sentence = "저는 여러분의 AI 어시스턴트입니다." + seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True) + seq.process_word("저는", pts=10, context_id="ctx1") + + result = seq.force_complete(last_word_pts=50) + tts_frames = [f for f in result if isinstance(f, TTSTextFrame)] + self.assertEqual(len(tts_frames), 1) + self.assertEqual(tts_frames[0].text, "여러분의 AI 어시스턴트입니다.") + self.assertEqual(tts_frames[0].pts, 50) + + # --- Japanese --- + + def test_japanese_combined_groups_complete_spoken_slot(self): + """Two Cartesia-style combined Japanese groups complete the slot and flush skipped.""" + seq = _seq() + sentence = "こんにちは、私はあなたの" + seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True) + skipped = _skipped_frame("[skipped]") + seq.register_skipped(skipped, "ctx2", None) + + r1 = seq.process_word("こんにちは、私", pts=100, context_id="ctx1") + self.assertFalse(any(f is skipped for f in r1)) + + r2 = seq.process_word("はあなたの", pts=200, context_id="ctx1") + self.assertTrue(any(f is skipped for f in r2)) + + def test_japanese_force_complete_emits_remaining_chars(self): + """After the first Japanese combined group, force_complete emits the rest.""" + seq = _seq() + sentence = "こんにちは、私はあなたの" + seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True) + seq.process_word("こんにちは、私", pts=10, context_id="ctx1") + + result = seq.force_complete(last_word_pts=50) + tts_frames = [f for f in result if isinstance(f, TTSTextFrame)] + self.assertEqual(len(tts_frames), 1) + self.assertEqual(tts_frames[0].text, "はあなたの") + + # --- Chinese --- + + def test_chinese_combined_groups_complete_spoken_slot(self): + """Two Cartesia-style combined Chinese groups complete the slot and flush skipped.""" + seq = _seq() + sentence = "你好,我是你的智能" + seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True) + skipped = _skipped_frame("[skipped]") + seq.register_skipped(skipped, "ctx2", None) + + r1 = seq.process_word("你好,我是", pts=100, context_id="ctx1") + self.assertFalse(any(f is skipped for f in r1)) + + r2 = seq.process_word("你的智能", pts=200, context_id="ctx1") + self.assertTrue(any(f is skipped for f in r2)) + + def test_chinese_force_complete_emits_remaining_chars(self): + """After the first Chinese combined group, force_complete emits the rest.""" + seq = _seq() + sentence = "你好,我是你的智能" + seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True) + seq.process_word("你好,我是", pts=10, context_id="ctx1") + + result = seq.force_complete(last_word_pts=50) + tts_frames = [f for f in result if isinstance(f, TTSTextFrame)] + self.assertEqual(len(tts_frames), 1) + self.assertEqual(tts_frames[0].text, "你的智能") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cartesia_tts.py b/tests/test_cartesia_tts.py index c14dc7409..c83088481 100644 --- a/tests/test_cartesia_tts.py +++ b/tests/test_cartesia_tts.py @@ -18,7 +18,7 @@ def _service(language: str) -> CartesiaTTSService: def _process_word_timestamps( words: list[str], starts: list[float], language: str ) -> list[tuple[str, float]]: - return _service(language)._process_word_timestamps_for_language(words, starts) + return _service(language)._normalize_word_timestamps(words, starts) def _concatenate_processed_timestamps( @@ -27,7 +27,7 @@ def _concatenate_processed_timestamps( service = _service(language) text_parts = [] for words, starts in timestamp_groups: - processed_timestamps = service._process_word_timestamps_for_language(words, starts) + processed_timestamps = service._normalize_word_timestamps(words, starts) includes_inter_frame_spaces = service._word_timestamps_include_inter_frame_spaces() text_parts.extend( TextPartForConcatenation( diff --git a/tests/test_tts_frame_ordering.py b/tests/test_tts_frame_ordering.py index 445a40539..cc4abd221 100644 --- a/tests/test_tts_frame_ordering.py +++ b/tests/test_tts_frame_ordering.py @@ -21,6 +21,13 @@ repeated for each TTSSpeakFrame, with no cross-group contamination. Also covers LLM response flow with push_text_frames=True (non-word-timestamp TTS): verifies TTSTextFrame ordering relative to LLMFullResponseEndFrame. +Also covers smart-text / WordCompletionTracker features: +- Skipped frames (skip_aggregator_types) held until preceding spoken slots complete. +- raw_text on AggregatedTextFrame propagated as spans to TTSTextFrames. +- Overflow: a single TTS word straddling two AggregatedTextFrame boundaries produces + two correctly-attributed TTSTextFrames. +- Force-complete safety net: skipped frames flush even when TTS drops word timestamps. + Also covers the interruption-during-pause deadlock scenario (see test_no_deadlock_on_interrupt_*). """ @@ -50,6 +57,7 @@ from pipecat.frames.frames import ( ) from pipecat.services.tts_service import TTSService from pipecat.tests.utils import SleepFrame, run_test +from pipecat.utils.text.base_text_aggregator import AggregationType # --------------------------------------------------------------------------- # Test-only frame @@ -422,7 +430,7 @@ def _assert_group_ordering( # All frames between TTSStartedFrame and TTSStoppedFrame must be audio. mid_types = types[started_idx + 1 : stopped_idx] for t in mid_types: - assert t is TTSAudioRawFrame, ( + assert t in (TTSAudioRawFrame, TTSTextFrame), ( f"Group {foo_label!r}: unexpected frame {t.__name__!r} between " f"TTSStartedFrame and TTSStoppedFrame. Got: {type_names}" ) @@ -551,7 +559,7 @@ async def test_http_push_text_llm_response_end_after_tts_text(): @pytest.mark.asyncio async def test_http_word_timestamps_verbatim_tokens(): - """HTTP path: text, PTS order, flag, and text-before-audio are all verified. + """HTTP path: text, PTS order, and text-before-audio are all verified. Word timestamps arrive in the audio context queue before the audio frame. _handle_audio_context caches them, then flushes when the first audio frame @@ -572,7 +580,6 @@ async def test_http_word_timestamps_verbatim_tokens(): audio_frames = [f for f in down if isinstance(f, TTSAudioRawFrame)] assert [f.text for f in tts_text_frames] == ["hello", "world"] - assert all(f.includes_inter_frame_spaces is True for f in tts_text_frames) pts_values = [f.pts for f in tts_text_frames] assert pts_values == sorted(pts_values) and len(set(pts_values)) == len(pts_values), ( @@ -590,15 +597,14 @@ async def test_http_word_timestamps_verbatim_tokens(): @pytest.mark.asyncio async def test_http_word_timestamps_punctuation_tokens(): - """Verbatim punctuation tokens are preserved with flag=True; default flag is False. + """Punct-only tokens are merged into the preceding word when includes_inter_frame_spaces=True. - Models the Inworld API scenario: the TTS returns tokens exactly as sent. - Space placement rule: - - word-follows-word: space is the leading char of the next word (e.g. " world") - - word-follows-punctuation: space is the trailing char of the punctuation token - (e.g. "! "), so the following word token carries no leading space. - The flag must reach every frame and the text must not be modified. - Also acts as a regression guard that flag=False is the default. + Models the Inworld API scenario: the TTS returns separate space and punctuation + tokens. add_word_timestamps calls merge_punct_tokens when includes_inter_frame_spaces + is True, collapsing those tokens into the preceding word before the tracker sees them. + + With flag=False (default) tokens are forwarded as-is; the tracker strips leading/ + trailing whitespace from each frame word via get_word_for_frame(). """ verbatim_tokens = [ ("hello", 0.0), @@ -609,9 +615,9 @@ async def test_http_word_timestamps_punctuation_tokens(): (" you", 0.75), ("?", 0.9), ] - expected_texts = ["hello", " world", "! ", "How", " are", " you", "?"] - # With flag=True: all tokens verbatim, all frames carry the flag. + # With flag=True: punct-only tokens ("! " and "?") are merged into the preceding + # words (" world" → " world! " and " you" → " you?"), then stripped by the tracker. tts_ifs = _MockWordTimestampHttpTTSService( includes_inter_frame_spaces=True, word_times=verbatim_tokens, @@ -621,12 +627,11 @@ async def test_http_word_timestamps_punctuation_tokens(): frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)], ) text_frames_ifs = [f for f in frames_ifs[0] if isinstance(f, TTSTextFrame)] - assert [f.text for f in text_frames_ifs] == expected_texts, ( - "Verbatim tokens must not be modified" + assert [f.text for f in text_frames_ifs] == ["hello", "world!", "How", "are", "you?"], ( + "Punct-only tokens must be merged into the preceding word" ) - assert all(f.includes_inter_frame_spaces is True for f in text_frames_ifs) - # With flag=False (default): same tokens, flag must be False on every frame. + # With flag=False (default): no merging; tracker strips leading/trailing spaces. tts_plain = _MockWordTimestampHttpTTSService( word_times=verbatim_tokens, ) @@ -635,13 +640,12 @@ async def test_http_word_timestamps_punctuation_tokens(): frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)], ) text_frames_plain = [f for f in frames_plain[0] if isinstance(f, TTSTextFrame)] - assert [f.text for f in text_frames_plain] == expected_texts - assert all(f.includes_inter_frame_spaces is False for f in text_frames_plain) + assert [f.text for f in text_frames_plain] == ["hello", "world", "!", "How", "are", "you", "?"] @pytest.mark.asyncio async def test_websocket_word_timestamps_verbatim_tokens(): - """WebSocket path: _WordTimestampEntry carries verbatim text, PTS, and flag. + """WebSocket path: text, PTS order, and text-before-audio are all verified. Unlike the HTTP path the word timestamps are sent asynchronously from a background task. They arrive before the audio frame and are cached until @@ -662,7 +666,6 @@ async def test_websocket_word_timestamps_verbatim_tokens(): audio_frames = [f for f in down if isinstance(f, TTSAudioRawFrame)] assert [f.text for f in tts_text_frames] == ["hello", "world"] - assert all(f.includes_inter_frame_spaces is True for f in tts_text_frames) pts_values = [f.pts for f in tts_text_frames] assert pts_values == sorted(pts_values) and len(set(pts_values)) == len(pts_values), ( @@ -678,7 +681,7 @@ async def test_websocket_word_timestamps_verbatim_tokens(): @pytest.mark.asyncio async def test_websocket_word_timestamps_punctuation_tokens(): - """WebSocket path: verbatim punctuation tokens reach TTSTextFrame unchanged.""" + """WebSocket path: punct-only tokens are merged into the preceding word.""" verbatim_tokens = [ ("hello", 0.0), (" world", 0.15), @@ -697,10 +700,443 @@ async def test_websocket_word_timestamps_punctuation_tokens(): frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)], ) text_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)] - assert [f.text for f in text_frames] == ["hello", " world", "! ", "How", " are", " you", "?"], ( - "Verbatim tokens must not be modified" + assert [f.text for f in text_frames] == ["hello", "world!", "How", "are", "you?"], ( + "Punct-only tokens must be merged into the preceding word" + ) + + +# --------------------------------------------------------------------------- +# Per-call word-timestamp mock (for overflow tests) +# --------------------------------------------------------------------------- + + +class _MockPerCallWordTimestampHttpTTSService(TTSService): + """HTTP-style TTS where each run_tts() call consumes its own word-time list. + + Designed for tests that need different word tokens per sentence. The + ``word_times_per_call`` list is consumed in order; an empty inner list means + no word-timestamp events are emitted for that call. + """ + + def __init__( + self, + word_times_per_call: list[list[tuple[str, float]]], + **kwargs, + ): + super().__init__( + push_start_frame=True, + push_stop_frames=True, + push_text_frames=False, + sample_rate=_SAMPLE_RATE, + **kwargs, + ) + self._word_times_queue = list(word_times_per_call) + + def can_generate_metrics(self) -> bool: + return False + + async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]: + word_times = self._word_times_queue.pop(0) if self._word_times_queue else [] + if word_times: + await self.add_word_timestamps(word_times, context_id=context_id) + yield TTSAudioRawFrame( + audio=_FAKE_AUDIO, + sample_rate=_SAMPLE_RATE, + num_channels=1, + context_id=context_id, + ) + + +# --------------------------------------------------------------------------- +# Tests: skipped frame ordering (skip_aggregator_types) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_http_skipped_frame_waits_for_spoken_words(): + """Skipped frames are held until the preceding spoken slot's word timestamps + are all processed, then flushed in order (HTTP / synchronous audio path). + + Sequence sent: + AggregatedTextFrame("hello world", SENTENCE) — spoken; yields 2 TTSTextFrames + AggregatedTextFrame("some code", "code") — in skip_aggregator_types; must wait + + Expected downstream order: + TTSTextFrame("hello") + TTSTextFrame("world") + AggregatedTextFrame("some code", append_to_context=True) + """ + tts = _MockWordTimestampHttpTTSService(skip_aggregator_types=["code"]) + frames_received = await run_test( + tts, + frames_to_send=[ + AggregatedTextFrame("hello world", AggregationType.SENTENCE), + AggregatedTextFrame("some code", "code"), + ], + ) + down = frames_received[0] + + word_frames = [f for f in down if isinstance(f, TTSTextFrame)] + skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"] + + assert [f.text for f in word_frames] == ["hello", "world"] + assert len(skipped) == 1 + assert skipped[0].append_to_context is True + + last_word_idx = max(down.index(f) for f in word_frames) + skipped_idx = down.index(skipped[0]) + assert skipped_idx > last_word_idx, ( + f"Skipped frame (pos {skipped_idx}) must appear after last word frame (pos {last_word_idx})" + ) + + +@pytest.mark.asyncio +async def test_ws_skipped_frame_waits_for_spoken_words(): + """Same ordering guarantee on the WebSocket / async audio delivery path. + + Because audio is delivered from a background task after asyncio.sleep(), the + skipped frame arrives at _push_frame_respecting_previous_aggregated_frame + *before* the spoken slot's word timestamps have been processed, directly + exercising the hold-and-flush path. + """ + tts = _MockWordTimestampWSTTSService(skip_aggregator_types=["code"]) + frames_received = await run_test( + tts, + frames_to_send=[ + AggregatedTextFrame("hello world", AggregationType.SENTENCE), + AggregatedTextFrame("some code", "code"), + ], + ) + down = frames_received[0] + + word_frames = [f for f in down if isinstance(f, TTSTextFrame)] + skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"] + + assert [f.text for f in word_frames] == ["hello", "world"] + assert len(skipped) == 1 + assert skipped[0].append_to_context is True + + last_word_idx = max(down.index(f) for f in word_frames) + skipped_idx = down.index(skipped[0]) + assert skipped_idx > last_word_idx, ( + f"Skipped frame (pos {skipped_idx}) must appear after last word frame (pos {last_word_idx})" + ) + + +@pytest.mark.asyncio +async def test_skipped_frame_before_spoken_emits_immediately(): + """A skipped frame with no preceding spoken slot is emitted right away. + + Sequence: + AggregatedTextFrame("some code", "code") — no spoken slot before it → emits now + AggregatedTextFrame("hello world", SENTENCE) — spoken; TTSTextFrames follow + + Expected: AggregatedTextFrame("some code") appears *before* TTSTextFrame("hello"). + """ + tts = _MockWordTimestampHttpTTSService(skip_aggregator_types=["code"]) + frames_received = await run_test( + tts, + frames_to_send=[ + AggregatedTextFrame("some code", "code"), + AggregatedTextFrame("hello world", AggregationType.SENTENCE), + ], + ) + down = frames_received[0] + + word_frames = [f for f in down if isinstance(f, TTSTextFrame)] + skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"] + + assert len(skipped) == 1 + assert skipped[0].append_to_context is True + assert len(word_frames) >= 1 + + skipped_idx = down.index(skipped[0]) + first_word_idx = down.index(word_frames[0]) + assert skipped_idx < first_word_idx, ( + f"Skipped frame (pos {skipped_idx}) must appear before first word frame (pos {first_word_idx})" + ) + + +@pytest.mark.asyncio +async def test_skipped_frame_flushed_when_word_timestamps_incomplete(): + """Force-complete path: skipped frame still emits when the TTS drops word timestamps. + + Only one of the two expected tokens ("hello") is returned. The spoken slot never + reaches its expected character count through the normal path. When + on_audio_context_done fires it force-completes any remaining spoken slots and + flushes the waiting skipped frame. + """ + tts = _MockWordTimestampHttpTTSService( + word_times=[("hello", 0.0)], # "world" is never sent + skip_aggregator_types=["code"], + ) + frames_received = await run_test( + tts, + frames_to_send=[ + AggregatedTextFrame("hello world", AggregationType.SENTENCE), + AggregatedTextFrame("some code", "code"), + ], + ) + down = frames_received[0] + + skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"] + assert len(skipped) == 1, "Skipped frame must be flushed via force-complete safety net" + assert skipped[0].append_to_context is True + + +# --------------------------------------------------------------------------- +# Tests: raw_text propagation through WordCompletionTracker +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_raw_text_propagated_to_tts_text_frames(): + """raw_text on AggregatedTextFrame is split across TTSTextFrames by the tracker. + + The frame carries raw_text="4111 1111" while the TTS-prepared + text is "4111 1111". The WordCompletionTracker advances a cursor through the + raw text in step with incoming word tokens, so each TTSTextFrame receives the + exact raw span it represents. + + Expected (trailing whitespace stripped because includes_inter_frame_spaces=False): + TTSTextFrame("4111").raw_text == "4111" + TTSTextFrame("1111").raw_text == "1111" + """ + tts = _MockWordTimestampHttpTTSService() + frames_received = await run_test( + tts, + frames_to_send=[ + AggregatedTextFrame( + "4111 1111", AggregationType.SENTENCE, raw_text="4111 1111" + ) + ], + ) + word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)] + + assert [f.text for f in word_frames] == ["4111", "1111"] + # get_raw_consumed() strips trailing whitespace when includes_inter_frame_spaces=False + assert word_frames[0].raw_text == "4111" + assert word_frames[1].raw_text == "1111" + + +# --------------------------------------------------------------------------- +# Tests: overflow — TTS word spanning two AggregatedTextFrame boundaries +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_overflow_word_spanning_two_aggregated_frames(): + """A single TTS token straddling two AggregatedTextFrame boundaries produces + two correctly-attributed TTSTextFrames. + + Setup: + Frame 1: AggregatedTextFrame("abc", SENTENCE) + Frame 2: AggregatedTextFrame("def", SENTENCE) + + The TTS for frame 1 returns the single token "abcdef", which overshoots + frame 1 by three characters. _emit_overflow_word splits it: + TTSTextFrame("abc") — frame 1's portion (context_id = ctx1) + TTSTextFrame("def") — overflow attributed to frame 2 (context_id = ctx2) + + Frame 2 receives no word-timestamp events because the overflow already + consumed its expected text. + """ + tts = _MockPerCallWordTimestampHttpTTSService( + word_times_per_call=[ + [("abcdef", 0.0)], # frame 1: single token spanning both frames + [], # frame 2: no word timestamps (overflow already covered it) + ] + ) + frames_received = await run_test( + tts, + frames_to_send=[ + AggregatedTextFrame("abc", AggregationType.SENTENCE), + AggregatedTextFrame("def", AggregationType.SENTENCE), + ], + ) + word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)] + + assert [f.text for f in word_frames] == ["abc", "def"], ( + f"Expected ['abc', 'def'] but got {[f.text for f in word_frames]}" + ) + assert word_frames[0].context_id != word_frames[1].context_id, ( + "Overflow TTSTextFrame must carry frame 2's context_id, not frame 1's" + ) + + +# --------------------------------------------------------------------------- +# Per-call word-timestamp mock for WebSocket path (for force-complete tests) +# --------------------------------------------------------------------------- + + +class _MockPerCallWordTimestampWSTTSService(TTSService): + """WebSocket-style TTS where each run_tts() call consumes its own word-time list. + + Mirrors _MockPerCallWordTimestampHttpTTSService but uses the async audio-context + delivery pattern so it exercises _handle_audio_context (the WebSocket path). + An empty inner list means no word-timestamp events are emitted for that call. + """ + + def __init__( + self, + word_times_per_call: list[list[tuple[str, float]]], + **kwargs, + ): + super().__init__( + push_start_frame=True, + push_text_frames=False, + pause_frame_processing=False, + sample_rate=_SAMPLE_RATE, + **kwargs, + ) + self._word_times_queue = list(word_times_per_call) + + def can_generate_metrics(self) -> bool: + return False + + async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]: + word_times = self._word_times_queue.pop(0) if self._word_times_queue else [] + + async def _deliver(): + await asyncio.sleep(0.01) + if word_times: + await self.add_word_timestamps(word_times, context_id=context_id) + await self.append_to_audio_context( + context_id, + TTSAudioRawFrame( + audio=_FAKE_AUDIO, + sample_rate=_SAMPLE_RATE, + num_channels=1, + context_id=context_id, + ), + ) + await self.append_to_audio_context(context_id, TTSStoppedFrame(context_id=context_id)) + await self.remove_audio_context(context_id) + + self.create_task(_deliver(), name=f"mock_ws_per_call_deliver_{context_id}") + if False: + yield + + +# --------------------------------------------------------------------------- +# Tests: _force_complete_spoken_slots — TTSTextFrame emission for dropped timestamps +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_http_force_complete_partial_timestamps_emits_remaining_text(): + """_force_complete_spoken_slots emits a TTSTextFrame for the unspoken word suffix. + + Only the first token ("hello") is delivered as a word-timestamp event; "world" + is never sent. When the audio context ends _force_complete_spoken_slots fires, + reads get_remaining_text() from the tracker, and emits TTSTextFrame("world"). + + Expected TTSTextFrames in order: ["hello", "world"]. + """ + tts = _MockWordTimestampHttpTTSService(word_times=[("hello", 0.0)]) + frames_received = await run_test( + tts, + frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)], + ) + word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)] + + assert [f.text for f in word_frames] == ["hello", "world"], ( + f"Expected ['hello', 'world'] but got {[f.text for f in word_frames]}" + ) + + +@pytest.mark.asyncio +async def test_http_force_complete_no_timestamps_emits_full_text(): + """_force_complete_spoken_slots emits the full text when no word timestamps arrive. + + No word-timestamp events are sent for "hello world". The slot remains incomplete + when the audio context ends; force-complete reads the full remaining text from the + tracker and emits TTSTextFrame("hello world"). + """ + tts = _MockPerCallWordTimestampHttpTTSService(word_times_per_call=[[]]) + frames_received = await run_test( + tts, + frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)], + ) + word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)] + + assert len(word_frames) == 1, ( + f"Expected exactly 1 TTSTextFrame, got {len(word_frames)}: {[f.text for f in word_frames]}" + ) + assert word_frames[0].text == "hello world", ( + f"Expected TTSTextFrame('hello world'), got {word_frames[0].text!r}" + ) + + +@pytest.mark.asyncio +async def test_http_force_complete_raw_text_propagated(): + """force-complete carries the correct raw_text span on the emitted TTSTextFrame. + + AggregatedTextFrame carries raw_text="4111 1111". Only "4111" arrives + as a word-timestamp; "1111" is force-completed. + + Expected: + TTSTextFrame("4111").raw_text == "4111" — from normal word path + TTSTextFrame("1111").raw_text == "1111" — from force-complete path + """ + tts = _MockPerCallWordTimestampHttpTTSService(word_times_per_call=[[("4111", 0.0)]]) + frames_received = await run_test( + tts, + frames_to_send=[ + AggregatedTextFrame( + "4111 1111", AggregationType.SENTENCE, raw_text="4111 1111" + ) + ], + ) + word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)] + + assert [f.text for f in word_frames] == ["4111", "1111"], ( + f"Expected ['4111', '1111'] but got {[f.text for f in word_frames]}" + ) + assert word_frames[0].raw_text == "4111", ( + f"Expected raw_text '4111' on first frame, got {word_frames[0].raw_text!r}" + ) + assert word_frames[1].raw_text == "1111", ( + f"Expected raw_text '1111' on force-complete frame, got {word_frames[1].raw_text!r}" + ) + + +@pytest.mark.asyncio +async def test_ws_force_complete_partial_timestamps_emits_remaining_text(): + """WebSocket path: _force_complete_spoken_slots emits TTSTextFrame for dropped token. + + Mirrors test_http_force_complete_partial_timestamps_emits_remaining_text on the + async audio delivery path to confirm force-complete fires correctly from + _handle_audio_context when TTSStoppedFrame arrives before all word timestamps. + """ + tts = _MockWordTimestampWSTTSService(word_times=[("hello", 0.0)]) + frames_received = await run_test( + tts, + frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)], + ) + word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)] + + assert [f.text for f in word_frames] == ["hello", "world"], ( + f"Expected ['hello', 'world'] but got {[f.text for f in word_frames]}" + ) + + +@pytest.mark.asyncio +async def test_ws_force_complete_no_timestamps_emits_full_text(): + """WebSocket path: full text emitted as single TTSTextFrame when no timestamps arrive.""" + tts = _MockPerCallWordTimestampWSTTSService(word_times_per_call=[[]]) + frames_received = await run_test( + tts, + frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)], + ) + word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)] + + assert len(word_frames) == 1, ( + f"Expected exactly 1 TTSTextFrame, got {len(word_frames)}: {[f.text for f in word_frames]}" + ) + assert word_frames[0].text == "hello world", ( + f"Expected TTSTextFrame('hello world'), got {word_frames[0].text!r}" ) - assert all(f.includes_inter_frame_spaces is True for f in text_frames) @pytest.mark.asyncio diff --git a/tests/test_word_completion_tracker.py b/tests/test_word_completion_tracker.py new file mode 100644 index 000000000..42e75970e --- /dev/null +++ b/tests/test_word_completion_tracker.py @@ -0,0 +1,1602 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import unittest + +from pipecat.utils.context.word_completion_tracker import WordCompletionTracker + + +class TestWordCompletionTrackerBasic(unittest.TestCase): + def test_not_complete_before_any_words(self): + tracker = WordCompletionTracker("Hello world") + self.assertFalse(tracker.is_complete) + + def test_complete_after_all_words(self): + tracker = WordCompletionTracker("Hello world") + tracker.add_word_and_check_complete("Hello") + self.assertFalse(tracker.is_complete) + tracker.add_word_and_check_complete("world") + self.assertTrue(tracker.is_complete) + + def test_add_word_and_check_complete_return_value(self): + """add_word_and_check_complete returns True exactly when the tracker becomes complete.""" + tracker = WordCompletionTracker("Hi there") + self.assertFalse(tracker.add_word_and_check_complete("Hi")) + self.assertTrue(tracker.add_word_and_check_complete("there")) + + def test_single_word(self): + tracker = WordCompletionTracker("Hello") + self.assertTrue(tracker.add_word_and_check_complete("Hello")) + + def test_complete_stays_true_after_extra_words(self): + """Extra words after completion force-complete (no-op remaining) and stay complete.""" + tracker = WordCompletionTracker("Hi") + tracker.add_word_and_check_complete("Hi") + result = tracker.add_word_and_check_complete("extra") + self.assertTrue(result) + self.assertTrue(tracker.is_complete) + + +class TestWordCompletionTrackerNormalization(unittest.TestCase): + def test_punctuation_ignored_in_expected(self): + """Punctuation in the source text is stripped before comparison.""" + tracker = WordCompletionTracker("Hello, world!") + tracker.add_word_and_check_complete("Hello") + tracker.add_word_and_check_complete("world") + self.assertTrue(tracker.is_complete) + + def test_punctuation_ignored_in_words(self): + """Punctuation attached to TTS word tokens is also stripped.""" + tracker = WordCompletionTracker("Hello world") + tracker.add_word_and_check_complete("Hello,") + tracker.add_word_and_check_complete("world!") + self.assertTrue(tracker.is_complete) + + def test_case_insensitive(self): + tracker = WordCompletionTracker("HELLO WORLD") + tracker.add_word_and_check_complete("hello") + tracker.add_word_and_check_complete("world") + self.assertTrue(tracker.is_complete) + + def test_spaces_ignored(self): + """Spaces in expected or received text do not count towards char totals.""" + tracker = WordCompletionTracker("a b c") + # Normalized expected: "abc" (3 chars) + tracker.add_word_and_check_complete("a") + tracker.add_word_and_check_complete("b") + self.assertFalse(tracker.is_complete) + tracker.add_word_and_check_complete("c") + self.assertTrue(tracker.is_complete) + + def test_numbers_kept(self): + """Digits are treated as regular alphanumeric characters.""" + tracker = WordCompletionTracker("Room 42") + tracker.add_word_and_check_complete("Room") + self.assertFalse(tracker.is_complete) + tracker.add_word_and_check_complete("42") + self.assertTrue(tracker.is_complete) + + def test_mixed_punctuation_and_numbers(self): + tracker = WordCompletionTracker("It costs $9.99!") + # Normalized expected: "itcosts999" + tracker.add_word_and_check_complete("It") + tracker.add_word_and_check_complete("costs") + self.assertFalse(tracker.is_complete) + tracker.add_word_and_check_complete("$9.99!") + self.assertTrue(tracker.is_complete) + + def test_special_characters_only_in_expected(self): + """If the expected text normalizes to empty, the tracker is immediately complete.""" + tracker = WordCompletionTracker("...") + self.assertTrue(tracker.is_complete) + + def test_emoji_only_expected_accepts_emoji_word_without_warning(self): + """Emoji-only frame is already complete (normalizes to ''), but a TTS word event + for the emoji must still be accepted gracefully and return True without a warning.""" + tracker = WordCompletionTracker("😊") + self.assertTrue(tracker.is_complete) + result = tracker.add_word_and_check_complete("😊") + self.assertTrue(result) + self.assertTrue(tracker.is_complete) + + def test_special_characters_only_in_word(self): + """A word token that normalizes to empty and is absent from the expected text + triggers force-complete; one that is present in the expected text is consumed + normally and contributes nothing to the alnum count.""" + # "---" is not in "hello" → force-complete + tracker = WordCompletionTracker("hello") + tracker.add_word_and_check_complete("---") + self.assertTrue(tracker.is_complete) + self.assertEqual(tracker.get_overflow_word(), "---") + + # "..." IS in "hello..." so it belongs here and contributes 0 alnum chars + tracker2 = WordCompletionTracker("hello...") + tracker2.add_word_and_check_complete("...") # belongs, but no alnum chars → incomplete + self.assertFalse(tracker2.is_complete) + tracker2.add_word_and_check_complete("hello") + self.assertTrue(tracker2.is_complete) + + def test_ssml_tags_stripped_in_word(self): + """SSML tags like ... in TTS word tokens are stripped before comparison.""" + tracker = WordCompletionTracker("1234-5678") + # Cartesia returns something like "1234-5678" as a word token + self.assertTrue(tracker.add_word_and_check_complete("1234-5678")) + + def test_ssml_tags_do_not_inflate_char_count(self): + """Tag names must not count as alphanumeric chars, preventing premature completion.""" + tracker = WordCompletionTracker("ab") + # Without tag stripping "spell" (5 chars) would push received count over threshold. + tracker.add_word_and_check_complete("a") + self.assertFalse(tracker.is_complete) + tracker.add_word_and_check_complete("b") + self.assertTrue(tracker.is_complete) + + def test_ssml_tags_in_expected_text(self): + """Tags in the expected text are also stripped.""" + tracker = WordCompletionTracker("hello") + self.assertTrue(tracker.add_word_and_check_complete("hello")) + + def test_curly_apostrophe_in_llm_text_matches_straight_apostrophe_in_tts_word(self): + """LLM curly apostrophe must not trigger the safeguard when TTS uses straight.""" + llm = "you’re welcome" # LLM: RIGHT SINGLE QUOTATION MARK + tracker = WordCompletionTracker(llm, llm_text=llm) + tracker.add_word_and_check_complete("you’re") # TTS: straight apostrophe + self.assertIsNotNone(tracker.get_llm_consumed()) + tracker.add_word_and_check_complete("welcome") + self.assertTrue(tracker.is_complete) + + def test_curly_apostrophe_in_tts_word_matches_straight_apostrophe_in_llm_text(self): + """TTS curly apostrophe must not trigger the safeguard when LLM uses straight.""" + llm = "you’re welcome" # LLM: straight apostrophe + tracker = WordCompletionTracker(llm, llm_text=llm) + tracker.add_word_and_check_complete("you’re") # TTS: RIGHT SINGLE QUOTATION MARK + self.assertIsNotNone(tracker.get_llm_consumed()) + tracker.add_word_and_check_complete("welcome") + self.assertTrue(tracker.is_complete) + + +class TestWordCompletionTrackerReset(unittest.TestCase): + def test_reset_clears_progress(self): + tracker = WordCompletionTracker("Hello world") + tracker.add_word_and_check_complete("Hello") + tracker.add_word_and_check_complete("world") + self.assertTrue(tracker.is_complete) + tracker.reset() + self.assertFalse(tracker.is_complete) + + def test_reset_allows_reuse(self): + tracker = WordCompletionTracker("Hello world") + tracker.add_word_and_check_complete("Hello") + tracker.add_word_and_check_complete("world") + tracker.reset() + tracker.add_word_and_check_complete("Hello") + self.assertFalse(tracker.is_complete) + tracker.add_word_and_check_complete("world") + self.assertTrue(tracker.is_complete) + + def test_reset_preserves_expected_text(self): + """reset() only clears received chars; the expected text is unchanged.""" + tracker = WordCompletionTracker("Hi") + tracker.add_word_and_check_complete("Hi") + tracker.reset() + # Re-adding the same word should complete again + self.assertTrue(tracker.add_word_and_check_complete("Hi")) + + def test_reset_clears_raw_pos_cursor(self): + """reset() resets the llm_text cursor so raw_consumed is correct after replay.""" + raw = "4111" + tracker = WordCompletionTracker("4111", llm_text=raw) + tracker.add_word_and_check_complete("4111") + self.assertEqual(tracker.get_llm_consumed(), "4111") + tracker.reset() + tracker.add_word_and_check_complete("4111") + # Cursor restarts from position 0 after reset. + self.assertEqual(tracker.get_llm_consumed(), "4111") + + def test_reset_clears_expected_raw_pos_cursor(self): + """reset() resets the expected_raw cursor so force-complete uses full text again.""" + tracker = WordCompletionTracker("number is") + tracker.add_word_and_check_complete("number") + # Partially advanced: remaining = " is" + tracker.add_word_and_check_complete("4111") # force-complete + self.assertEqual(tracker.get_word_for_frame(), "is") + + tracker.reset() + # After reset the cursor is back at 0, so force-complete sees the full text. + tracker.add_word_and_check_complete("4111") + self.assertEqual(tracker.get_word_for_frame(), "number is") + + +class TestWordCompletionTrackerEdgeCases(unittest.TestCase): + def test_empty_expected_text(self): + """An empty expected string is complete from the start.""" + tracker = WordCompletionTracker("") + self.assertTrue(tracker.is_complete) + + def test_empty_word_adds_nothing(self): + tracker = WordCompletionTracker("hello") + tracker.add_word_and_check_complete("") + self.assertFalse(tracker.is_complete) + + def test_partial_word_completion(self): + """Chars accumulate; completion happens mid-add_word_and_check_complete if count is reached.""" + tracker = WordCompletionTracker("ab") + # "ab" normalizes to 2 chars; one word covering both at once + self.assertTrue(tracker.add_word_and_check_complete("ab")) + + def test_word_with_extra_chars_completes(self): + """A single verbose token can satisfy a longer expected text.""" + tracker = WordCompletionTracker("Hi") + self.assertTrue(tracker.add_word_and_check_complete("Hieveryone")) + + +class TestWordCompletionTrackerRealisticSentences(unittest.TestCase): + # Sentence as it would appear in an AggregatedTextFrame from the LLM. + SENTENCE = "You're welcome! If you have any more questions or need further examples, feel free to ask. Have a great day! 😊" + + # Words as a TTS word-timestamp service typically returns them: + # punctuation attached to the adjacent word, emoji absent (unspeakable). + TTS_WORDS = [ + "You're", + "welcome!", + "If", + "you", + "have", + "any", + "more", + "questions", + "or", + "need", + "further", + "examples,", + "feel", + "free", + "to", + "ask.", + "Have", + "a", + "great", + "day!", + ] + + def test_completes_after_all_tts_words(self): + """Feeding every TTS word in order must complete the tracker.""" + tracker = WordCompletionTracker(self.SENTENCE) + results = [tracker.add_word_and_check_complete(w) for w in self.TTS_WORDS] + self.assertTrue(results[-1], "tracker should be complete after the last word") + + def test_not_complete_before_last_word(self): + """Tracker must not report complete before the final word is added.""" + tracker = WordCompletionTracker(self.SENTENCE) + for word in self.TTS_WORDS[:-1]: + self.assertFalse(tracker.is_complete, f"should not be complete after '{word}'") + + def test_last_word_triggers_completion(self): + """Only the last add_word_and_check_complete call should return True.""" + tracker = WordCompletionTracker(self.SENTENCE) + intermediate = [tracker.add_word_and_check_complete(w) for w in self.TTS_WORDS[:-1]] + self.assertFalse(any(intermediate)) + self.assertTrue(tracker.add_word_and_check_complete(self.TTS_WORDS[-1])) + + def test_emoji_in_expected_does_not_block_completion(self): + """The 😊 at the end normalizes to '' so it adds no required chars.""" + tracker = WordCompletionTracker(self.SENTENCE) + for word in self.TTS_WORDS: + tracker.add_word_and_check_complete(word) + self.assertTrue(tracker.is_complete) + + def test_chunked_delivery_two_words_per_call(self): + """Some TTS providers return multiple words in one timestamp event.""" + tracker = WordCompletionTracker(self.SENTENCE) + pairs = [ + self.TTS_WORDS[i] + " " + self.TTS_WORDS[i + 1] + for i in range(0, len(self.TTS_WORDS) - 1, 2) + ] + # If word count is odd the last word is left alone. + if len(self.TTS_WORDS) % 2: + pairs.append(self.TTS_WORDS[-1]) + for chunk in pairs: + tracker.add_word_and_check_complete(chunk) + self.assertTrue(tracker.is_complete) + + def test_reset_and_replay(self): + """After reset the same tracker can complete the same sentence again.""" + tracker = WordCompletionTracker(self.SENTENCE) + for word in self.TTS_WORDS: + tracker.add_word_and_check_complete(word) + self.assertTrue(tracker.is_complete) + + tracker.reset() + self.assertFalse(tracker.is_complete) + + for word in self.TTS_WORDS: + tracker.add_word_and_check_complete(word) + self.assertTrue(tracker.is_complete) + + def test_mid_sentence_is_not_complete(self): + """Spot-check several points through the sentence.""" + tracker = WordCompletionTracker(self.SENTENCE) + checkpoints = {4, 9, 14} # after 5th, 10th, 15th word (0-indexed) + for i, word in enumerate(self.TTS_WORDS[:-1]): + tracker.add_word_and_check_complete(word) + if i in checkpoints: + self.assertFalse( + tracker.is_complete, + f"should not be complete after {i + 1} words", + ) + + def test_short_sentence_word_by_word(self): + """Smaller realistic sentence: each word added individually.""" + sentence = "Of course! Here's a simple Python example." + words = ["Of", "course!", "Here's", "a", "simple", "Python", "example."] + tracker = WordCompletionTracker(sentence) + for word in words[:-1]: + self.assertFalse(tracker.add_word_and_check_complete(word)) + self.assertTrue(tracker.add_word_and_check_complete(words[-1])) + + def test_sentence_with_numbers(self): + """Numbers are alphanumeric and must be counted in both directions.""" + sentence = "There are 3 options available for you." + words = ["There", "are", "3", "options", "available", "for", "you."] + tracker = WordCompletionTracker(sentence) + for word in words[:-1]: + self.assertFalse(tracker.add_word_and_check_complete(word)) + self.assertTrue(tracker.add_word_and_check_complete(words[-1])) + + def test_credit_card_sentence_with_raw_consumed(self): + """Test that get_raw_consumed returns 'code:' after completing the sentence and all words.""" + sentence = "Here is a sample credit card number and a simple Python code:" + words = [ + "Here", + "is", + "a", + "sample", + "credit", + "card", + "number", + "and", + "a", + "simple", + "Python", + "code:", + ] + + # Provide llm_text parameter so get_llm_consumed() works + tracker = WordCompletionTracker(sentence, llm_text=sentence) + + # Add all words except the last one + for word in words[:-1]: + result = tracker.add_word_and_check_complete(word) + self.assertFalse(result, f"Should not be complete after adding '{word}'") + + # Add the final word - should complete the tracker + result = tracker.add_word_and_check_complete(words[-1]) + self.assertTrue(result, "Should be complete after adding the last word") + self.assertTrue(tracker.is_complete) + + # get_raw_consumed should return "code:" for the last word + self.assertEqual(tracker.get_llm_consumed(), "code:") + + def test_maori_culture_sentence(self): + """Test completion with Māori culture sentence - last word should be 'culture.'""" + sentence = ( + "The indigenous Māori people are a significant part of the population and culture." + ) + words = [ + "The", + "indigenous", + "Māori", + "people", + "are", + "a", + "significant", + "part", + "of", + "the", + "population", + "and", + "culture.", + ] + tracker = WordCompletionTracker(sentence, llm_text=sentence) + + # Add all words except the last one - should not complete + for word in words[:-1]: + result = tracker.add_word_and_check_complete(word) + self.assertFalse(result, f"Should not be complete after adding '{word}'") + + # Add the final word "culture." - should complete the tracker + result = tracker.add_word_and_check_complete(words[-1]) + self.assertTrue(result, "Should be complete after adding 'culture.'") + self.assertTrue(tracker.is_complete) + + # get_raw_consumed should return "culture." for the last word + self.assertEqual(tracker.get_word_for_frame(), "culture.") + self.assertEqual(tracker.get_llm_consumed(), "culture.") + + def test_geography_sentence_frame_word_and_raw_consumed_validation(self): + """Test geography sentence word by word, validating _frame_word and _raw_consumed match expected values. + + Sentence: 'Here are some key facts: **Geography:** - It consists mainly of two large islands: + the North Island and the South Island, as well as many smaller islands.' + + This test validates that the received word, _frame_word, and _raw_consumed are exactly + what we expect for each word addition, especially in special cases with punctuation. + """ + sentence = "Here are some key facts: **Geography:** - It consists mainly of two large islands: the North Island and the South Island, as well as many smaller islands." + llm_text = f"{sentence}" + + words = [ + "Here", + "are", + "some", + "key", + "facts:", + "**Geography:**", + "-", + "It", + "consists", + "mainly", + "of", + "two", + "large", + "islands:", + "the", + "North", + "Island", + "and", + "the", + "South", + "Island,", + "as", + "well", + "as", + "many", + "smaller", + "islands.", + ] + + # Expected _frame_word for each word (should match the input word exactly) + expected_frame_words = [ + "Here", + "are", + "some", + "key", + "facts:", + "**Geography:**", + "-", + "It", + "consists", + "mainly", + "of", + "two", + "large", + "islands:", + "the", + "North", + "Island", + "and", + "the", + "South", + "Island,", + "as", + "well", + "as", + "many", + "smaller", + "islands.", + ] + + # Expected _llm_consumed for each word (spans from llm_text) + expected_raw_consumed = [ + "Here", + "are", + "some", + "key", + "facts:", + "**Geography:**", + "-", + "It", + "consists", + "mainly", + "of", + "two", + "large", + "islands:", + "the", + "North", + "Island", + "and", + "the", + "South", + "Island,", + "as", + "well", + "as", + "many", + "smaller", + "islands.", + ] + + tracker = WordCompletionTracker(sentence, llm_text=llm_text) + + for i, word in enumerate(words): + is_complete = tracker.add_word_and_check_complete(word) + + # Test 1: Validate _frame_word matches expected + actual_frame_word = tracker.get_word_for_frame() + expected_frame_word = expected_frame_words[i] + self.assertEqual( + actual_frame_word, + expected_frame_word, + f"Word {i + 1} '{word}': expected _frame_word '{expected_frame_word}', got '{actual_frame_word}'", + ) + + # Test 2: Validate _raw_consumed matches expected + actual_raw_consumed = tracker.get_llm_consumed() + expected_raw = expected_raw_consumed[i] + self.assertEqual( + actual_raw_consumed, + expected_raw, + f"Word {i + 1} '{word}': expected _raw_consumed '{expected_raw}', got '{actual_raw_consumed}'", + ) + + # Test 3: Validate completion status + if i == len(words) - 1: + self.assertTrue(is_complete, f"Should be complete after final word '{word}'") + else: + self.assertFalse( + is_complete, f"Should not be complete after word '{word}' (position {i + 1})" + ) + + # Test special punctuation cases individually + special_cases = [ + ("**Geography:**", "**Geography:**"), + ("facts:", "facts:"), + ("islands:", "islands:"), + ("Island,", "Island,"), + ("islands.", "islands."), + ] + + for word, expected_frame in special_cases: + tracker_special = WordCompletionTracker(word, llm_text=f"{word}") + tracker_special.add_word_and_check_complete(word) + + actual_frame = tracker_special.get_word_for_frame() + self.assertEqual( + actual_frame, + expected_frame, + f"Special case '{word}': expected _frame_word '{expected_frame}', got '{actual_frame}'", + ) + + actual_raw = tracker_special.get_llm_consumed() + expected_raw_special = f"{word}" + self.assertEqual( + actual_raw, + expected_raw_special, + f"Special case '{word}': expected _raw_consumed '{expected_raw_special}', got '{actual_raw}'", + ) + + +class TestWordCompletionTrackerWordBelongsHere(unittest.TestCase): + def test_belongs_when_word_is_prefix_of_remaining(self): + """word_belongs_here returns True when word alnum chars match the start of remaining.""" + tracker = WordCompletionTracker("4111 1111 1111 1111") + self.assertTrue(tracker.word_belongs_here("4111")) + + def test_does_not_belong_when_word_mismatches_remaining(self): + """word_belongs_here returns False when the word is clearly from a different slot.""" + tracker = WordCompletionTracker("number is") + # '4' does not match 'n' (start of "numberis") + self.assertFalse(tracker.word_belongs_here("4111")) + + def test_punctuation_present_in_expected_belongs(self): + """Punctuation that appears in the remaining expected text belongs here.""" + tracker = WordCompletionTracker("hello, world") + tracker.add_word_and_check_complete("hello") + # remaining expected = ", world" — comma is present, so it belongs + self.assertTrue(tracker.word_belongs_here(",")) + + def test_punctuation_absent_from_expected_does_not_belong(self): + """Punctuation absent from the remaining expected text does not belong here.""" + tracker = WordCompletionTracker("hello") + # "..." and "," are not in "hello", so they don't belong + self.assertFalse(tracker.word_belongs_here("...")) + self.assertFalse(tracker.word_belongs_here(",")) + + def test_empty_word_always_belongs(self): + """An empty word token always belongs (empty string is a substring of any text).""" + tracker = WordCompletionTracker("hello") + self.assertTrue(tracker.word_belongs_here("")) + + def test_emoji_in_expected_belongs(self): + """An emoji present in the remaining expected text belongs to this frame.""" + tracker = WordCompletionTracker("Hello 😊") + tracker.add_word_and_check_complete("Hello") + self.assertTrue(tracker.word_belongs_here("😊")) + + def test_emoji_absent_from_expected_does_not_belong(self): + """An emoji not in the remaining expected text is routed to the next frame.""" + tracker = WordCompletionTracker("Hello world") + tracker.add_word_and_check_complete("Hello") + # 😊 is not in "Hello world", so it doesn't belong here + self.assertFalse(tracker.word_belongs_here("😊")) + + def test_belongs_when_partial_prefix_matches(self): + """Partial token (fewer chars than remaining) still passes if it is a prefix.""" + tracker = WordCompletionTracker("hello world") + # remaining = "helloworld", word = "hel" — prefix matches + self.assertTrue(tracker.word_belongs_here("hel")) + + def test_does_not_belong_when_no_remaining(self): + """word_belongs_here returns False when the tracker is already complete.""" + tracker = WordCompletionTracker("hi") + tracker.add_word_and_check_complete("hi") + self.assertFalse(tracker.word_belongs_here("extra")) + + def test_belongs_advances_correctly_after_partial_consumption(self): + """Remaining expected text shifts as words are consumed.""" + tracker = WordCompletionTracker("hello world") + tracker.add_word_and_check_complete("hello") + # remaining is now "world" + self.assertTrue(tracker.word_belongs_here("world")) + self.assertFalse(tracker.word_belongs_here("hello")) + + +class TestWordCompletionTrackerOverflow(unittest.TestCase): + """Words that span the boundary of two AggregatedTextFrames.""" + + def test_word_straddles_frame_boundary(self): + """A word token that spans two frames splits at the alnum boundary.""" + tracker = WordCompletionTracker("hello") + result = tracker.add_word_and_check_complete("helloworld") + self.assertTrue(result) + self.assertEqual(tracker.get_word_for_frame(), "hello") + self.assertEqual(tracker.get_overflow_word(), "world") + + def test_no_overflow_when_word_fits_exactly(self): + """A word that exactly fills remaining slots produces no overflow.""" + tracker = WordCompletionTracker("hello") + tracker.add_word_and_check_complete("hello") + self.assertIsNone(tracker.get_overflow_word()) + self.assertEqual(tracker.get_word_for_frame(), "hello") + + def test_overflow_split_preserves_non_alnum_suffix(self): + """The raw overflow word retains non-alnum chars after the split point.""" + tracker = WordCompletionTracker("1111") + # "1111And" — 4 alnum chars for this frame, "And" as raw overflow + result = tracker.add_word_and_check_complete("1111And") + self.assertTrue(result) + self.assertEqual(tracker.get_word_for_frame(), "1111") + self.assertEqual(tracker.get_overflow_word(), "And") + + def test_overflow_with_digits_splits_at_correct_position(self): + """Split position is computed by alnum count, not byte offset.""" + tracker = WordCompletionTracker("4111") # 4 alnum chars + tracker.add_word_and_check_complete("41111111") + self.assertEqual(tracker.get_word_for_frame(), "4111") + self.assertEqual(tracker.get_overflow_word(), "1111") + + def test_overflow_flows_into_next_tracker(self): + """Overflow word fed into the next tracker completes it correctly.""" + tracker1 = WordCompletionTracker("hello") + tracker2 = WordCompletionTracker("world") + + tracker1.add_word_and_check_complete("helloworld") + overflow = tracker1.get_overflow_word() + self.assertEqual(overflow, "world") + + result = tracker2.add_word_and_check_complete(overflow) + self.assertTrue(result) + self.assertEqual(tracker2.get_word_for_frame(), "world") + self.assertIsNone(tracker2.get_overflow_word()) + + def test_overflow_with_digits_flows_into_next_tracker(self): + """Realistic card-number overflow: last token spans two frame slots.""" + tracker1 = WordCompletionTracker("4111 1111 1111 1111") + tracker2 = WordCompletionTracker("And your") + + for word in ["4111", "1111", "1111"]: + tracker1.add_word_and_check_complete(word) + + # "1111And" straddles the frame boundary + result = tracker1.add_word_and_check_complete("1111And") + self.assertTrue(result) + self.assertEqual(tracker1.get_word_for_frame(), "1111") + self.assertEqual(tracker1.get_overflow_word(), "And") + + # Feed overflow into tracker2 + result = tracker2.add_word_and_check_complete(tracker1.get_overflow_word()) + self.assertFalse(result) + self.assertEqual(tracker2.get_word_for_frame(), "And") + + result = tracker2.add_word_and_check_complete("your") + self.assertTrue(result) + + def test_no_overflow_state_on_normal_word(self): + """After a normal (non-straddling) word, overflow getters return None.""" + tracker = WordCompletionTracker("hello world") + tracker.add_word_and_check_complete("hello") + self.assertIsNone(tracker.get_overflow_word()) + + +class TestWordCompletionTrackerMissingWord(unittest.TestCase): + """Force-complete: TTS provider drops a word-timestamp event.""" + + def test_force_complete_when_word_does_not_belong(self): + """When word doesn't belong, the slot is force-completed and word becomes overflow.""" + tracker = WordCompletionTracker("number is") + result = tracker.add_word_and_check_complete("4111") + self.assertTrue(result) + self.assertEqual(tracker.get_overflow_word(), "4111") + + def test_force_complete_frame_word_is_full_remaining_expected(self): + """Force-complete with no prior progress: frame_word is the entire expected text.""" + tracker = WordCompletionTracker("number is") + tracker.add_word_and_check_complete("4111") + self.assertEqual(tracker.get_word_for_frame(), "number is") + + def test_force_complete_frame_word_is_partial_remaining_expected(self): + """Force-complete after partial progress: frame_word is only the unspoken suffix.""" + tracker = WordCompletionTracker("number is") + tracker.add_word_and_check_complete("number") # consumes "number" (6 chars) + # Now remaining = " is"; "4111" doesn't belong + result = tracker.add_word_and_check_complete("4111") + self.assertTrue(result) + self.assertEqual(tracker.get_word_for_frame(), "is") + self.assertEqual(tracker.get_overflow_word(), "4111") + + def test_force_complete_overflow_routes_to_next_tracker(self): + """Overflow from a force-completed slot feeds the next tracker correctly.""" + tracker1 = WordCompletionTracker("number is") + tracker2 = WordCompletionTracker("4111 1111") + + tracker1.add_word_and_check_complete("4111") # force-completes tracker1 + overflow = tracker1.get_overflow_word() + self.assertEqual(overflow, "4111") + + self.assertFalse(tracker2.add_word_and_check_complete(overflow)) + self.assertTrue(tracker2.add_word_and_check_complete("1111")) + + def test_force_complete_after_several_normal_words(self): + """Partial consumption followed by a mismatched word force-completes correctly.""" + tracker = WordCompletionTracker("Your credit card number is") + for word in ["Your", "credit", "card"]: + tracker.add_word_and_check_complete(word) + # "4111" doesn't belong (remaining starts with "number...") + result = tracker.add_word_and_check_complete("4111") + self.assertTrue(result) + # frame_word should be the unspoken suffix including leading space + self.assertEqual(tracker.get_word_for_frame(), "number is") + self.assertEqual(tracker.get_overflow_word(), "4111") + + def test_force_complete_marks_tracker_complete(self): + """After force-complete, is_complete is True regardless of how many chars were spoken.""" + tracker = WordCompletionTracker("hello world") + tracker.add_word_and_check_complete("hello") + self.assertFalse(tracker.is_complete) + tracker.add_word_and_check_complete("4111") # force-complete + self.assertTrue(tracker.is_complete) + + def test_force_complete_does_not_affect_overflow_getters_on_normal_words(self): + """overflow state is fresh per call; a normal word clears any prior force-complete state.""" + tracker1 = WordCompletionTracker("ab") + tracker2 = WordCompletionTracker("cd") + + # Force-complete tracker1 with a wrong word + tracker1.add_word_and_check_complete("xyz") + self.assertIsNotNone(tracker1.get_overflow_word()) + + # tracker2 receives "cd" normally — no overflow + tracker2.add_word_and_check_complete("cd") + self.assertIsNone(tracker2.get_overflow_word()) + + +class TestWordCompletionTrackerLLMText(unittest.TestCase): + """llm_text cursor tracking: each word maps back to its span in the original LLM text.""" + + def test_llm_text_none_when_no_llm_text_provided(self): + """get_raw_consumed returns None when llm_text was not given.""" + tracker = WordCompletionTracker("hello world") + tracker.add_word_and_check_complete("hello") + self.assertIsNone(tracker.get_llm_consumed()) + + def test_llm_text_single_word_no_tags(self): + """Simple case: llm_text matches tts_text, no tags.""" + tracker = WordCompletionTracker("hello world", llm_text="hello world") + tracker.add_word_and_check_complete("hello") + self.assertEqual(tracker.get_llm_consumed(), "hello") + tracker.add_word_and_check_complete("world") + self.assertEqual(tracker.get_llm_consumed(), "world") + + def test_llm_text_opening_tag_included_in_first_word(self): + """The opening tag preceding content is consumed with the first word.""" + raw = "4111 1111 1111 1111" + tracker = WordCompletionTracker("4111 1111 1111 1111", llm_text=raw) + tracker.add_word_and_check_complete("4111") + self.assertEqual(tracker.get_llm_consumed(), "4111") + + def test_llm_text_tag_chars_not_counted_as_alnum(self): + """Tag chars (c,a,r,d) inside must not burn the alnum budget.""" + # If tag chars were counted, "4111" would only consume "4111" should be consumed (last word → consume all). + self.assertEqual(tracker.get_llm_consumed(), "4111") + + def test_llm_text_four_words_with_card_tags(self): + """Full card-number scenario: each word maps to its correct raw span.""" + raw = "4111 1111 1111 1111" + tracker = WordCompletionTracker("4111 1111 1111 1111", llm_text=raw) + + tracker.add_word_and_check_complete("4111") + self.assertEqual(tracker.get_llm_consumed(), "4111") + + tracker.add_word_and_check_complete("1111") + self.assertEqual(tracker.get_llm_consumed(), "1111") + + tracker.add_word_and_check_complete("1111") + self.assertEqual(tracker.get_llm_consumed(), "1111") + + tracker.add_word_and_check_complete("1111") + # Last word: consume all remaining raw text including closing tag. + self.assertEqual(tracker.get_llm_consumed(), "1111") + self.assertTrue(tracker.is_complete) + + def test_llm_text_closing_tag_consumed_with_last_word(self): + """Closing tag is swept into the last word's raw_consumed, not lost.""" + raw = "hello" + tracker = WordCompletionTracker("hello", llm_text=raw) + tracker.add_word_and_check_complete("hello") + self.assertEqual(tracker.get_llm_consumed(), "hello") + + def test_llm_text_mid_frame_word_does_not_consume_closing_tag(self): + """Non-final words stop before the closing tag; only the last sweeps it up.""" + raw = "hello world" + tracker = WordCompletionTracker("hello world", llm_text=raw) + tracker.add_word_and_check_complete("hello") + self.assertEqual(tracker.get_llm_consumed(), "hello") + tracker.add_word_and_check_complete("world") + self.assertEqual(tracker.get_llm_consumed(), "world") + + def test_llm_text_force_complete_consumes_all_remaining(self): + """When force-complete fires, all remaining llm_text is consumed at once.""" + raw = "4111 1111 1111 1111" + tracker = WordCompletionTracker("4111 1111 1111 1111", llm_text=raw) + tracker.add_word_and_check_complete("4111") # advances cursor to pos 10 + # "WRONG" doesn't belong → force-complete + tracker.add_word_and_check_complete("WRONG") + self.assertEqual(tracker.get_llm_consumed(), "1111 1111 1111") + self.assertTrue(tracker.is_complete) + + def test_llm_text_overflow_word_last_raw_consumed_sweeps_tag(self): + """When a straddling word completes the frame, closing tag is consumed with it.""" + raw = "4111 1111" + tracker = WordCompletionTracker("4111 1111", llm_text=raw) + tracker.add_word_and_check_complete("4111") + self.assertEqual(tracker.get_llm_consumed(), "4111") + + # "1111And" straddles into the next frame + result = tracker.add_word_and_check_complete("1111And") + self.assertTrue(result) + self.assertEqual(tracker.get_word_for_frame(), "1111") + self.assertEqual(tracker.get_overflow_word(), "And") + # is_complete → consume all remaining raw including + self.assertEqual(tracker.get_llm_consumed(), "1111") + + def test_llm_text_with_ssml_tags_in_expected(self): + """expected_text with SSML tags: raw cursor still advances by content alnum count.""" + # Cartesia might receive "4111 1111" as expected_text + # while 4111 1111" + raw = "4111 1111" + tracker = WordCompletionTracker("4111 1111", llm_text=raw) + # normalized expected = "41111111" (8 chars); both texts share the same alnum count. + tracker.add_word_and_check_complete("4111") + self.assertEqual(tracker.get_llm_consumed(), "4111") + tracker.add_word_and_check_complete("1111") + self.assertEqual(tracker.get_llm_consumed(), "1111") + self.assertTrue(tracker.is_complete) + + +class TestWordCompletionTrackerMultiFrameSimulation(unittest.TestCase): + """End-to-end simulations of multiple AggregatedTextFrame slots.""" + + def test_two_plain_frames_sequential_words(self): + """Normal two-frame flow: all words arrive in order with no drops or overflow.""" + tracker1 = WordCompletionTracker("Your credit card") + tracker2 = WordCompletionTracker("number is 42") + + for word in ["Your", "credit", "card"]: + tracker1.add_word_and_check_complete(word) + self.assertTrue(tracker1.is_complete) + + for word in ["number", "is"]: + self.assertFalse(tracker2.add_word_and_check_complete(word)) + self.assertTrue(tracker2.add_word_and_check_complete("42")) + + def test_credit_card_full_flow_with_llm_text(self): + """Full card-number scenario with llm_text tracking across two frames.""" + tracker1 = WordCompletionTracker("Your credit card number is") + tracker2 = WordCompletionTracker( + "4111 1111 1111 1111", + llm_text="4111 1111 1111 1111", + ) + + for word in ["Your", "credit", "card", "number", "is"]: + tracker1.add_word_and_check_complete(word) + self.assertTrue(tracker1.is_complete) + + tracker2.add_word_and_check_complete("4111") + self.assertEqual(tracker2.get_llm_consumed(), "4111") + + tracker2.add_word_and_check_complete("1111") + self.assertEqual(tracker2.get_llm_consumed(), "1111") + + tracker2.add_word_and_check_complete("1111") + self.assertEqual(tracker2.get_llm_consumed(), "1111") + + result = tracker2.add_word_and_check_complete("1111") + self.assertTrue(result) + self.assertEqual(tracker2.get_llm_consumed(), "1111") + + def test_missing_word_force_complete_then_next_frame(self): + """Dropped word-timestamp: force-complete slot 1, route word to slot 2.""" + tracker1 = WordCompletionTracker("Your credit card number is") + tracker2 = WordCompletionTracker( + "4111 1111 1111 1111", + llm_text="4111 1111 1111 1111", + ) + + # "Your", "credit", "card" arrive; then "number" and "is" are dropped. + for word in ["Your", "credit", "card"]: + tracker1.add_word_and_check_complete(word) + + # "4111" arrives but belongs to tracker2 — force-completes tracker1. + result = tracker1.add_word_and_check_complete("4111") + self.assertTrue(result) + # frame_word carries the unspoken remainder so a TTSTextFrame can be emitted. + self.assertEqual(tracker1.get_word_for_frame(), "number is") + overflow = tracker1.get_overflow_word() + self.assertEqual(overflow, "4111") + + # Route overflow into tracker2. + tracker2.add_word_and_check_complete(overflow) + self.assertEqual(tracker2.get_llm_consumed(), "4111") + + tracker2.add_word_and_check_complete("1111") + tracker2.add_word_and_check_complete("1111") + result = tracker2.add_word_and_check_complete("1111") + self.assertTrue(result) + self.assertEqual(tracker2.get_llm_consumed(), "1111") + + def test_overflow_word_spans_two_frames(self): + """A word token that straddles two frame boundaries splits and routes correctly.""" + tracker1 = WordCompletionTracker("4111 1111 1111 1111") + tracker2 = WordCompletionTracker("And your") + + for word in ["4111", "1111", "1111"]: + tracker1.add_word_and_check_complete(word) + + # "1111And" spans the frame boundary. + result = tracker1.add_word_and_check_complete("1111And") + self.assertTrue(result) + self.assertEqual(tracker1.get_word_for_frame(), "1111") + self.assertEqual(tracker1.get_overflow_word(), "And") + + # Feed overflow into tracker2. + result = tracker2.add_word_and_check_complete(tracker1.get_overflow_word()) + self.assertFalse(result) + self.assertEqual(tracker2.get_word_for_frame(), "And") + + self.assertTrue(tracker2.add_word_and_check_complete("your")) + + def test_overflow_with_llm_text_across_frames(self): + """Raw text cursor is correct when the last straddling word completes the frame.""" + tracker1 = WordCompletionTracker( + "4111 1111", + llm_text="4111 1111", + ) + tracker2 = WordCompletionTracker("And") + + tracker1.add_word_and_check_complete("4111") + self.assertEqual(tracker1.get_llm_consumed(), "4111") + + # "1111And" completes tracker1; "And" overflows to tracker2. + result = tracker1.add_word_and_check_complete("1111And") + self.assertTrue(result) + self.assertEqual(tracker1.get_llm_consumed(), "1111") + self.assertEqual(tracker1.get_overflow_word(), "And") + + result = tracker2.add_word_and_check_complete(tracker1.get_overflow_word()) + self.assertTrue(result) + self.assertEqual(tracker2.get_word_for_frame(), "And") + + def test_multiple_missing_words_single_force_complete(self): + """Even if several consecutive words are dropped, one force-complete handles them all.""" + # Frame expects "one two three"; TTS skips straight to "four" (next frame's word). + tracker1 = WordCompletionTracker("one two three") + tracker2 = WordCompletionTracker("four five") + + # No words from tracker1 ever arrive; "four" is the first word seen. + result = tracker1.add_word_and_check_complete("four") + self.assertTrue(result) + self.assertEqual(tracker1.get_word_for_frame(), "one two three") + self.assertEqual(tracker1.get_overflow_word(), "four") + + self.assertFalse(tracker2.add_word_and_check_complete("four")) + self.assertTrue(tracker2.add_word_and_check_complete("five")) + + +class TestWordCompletionTrackerEmojiInSentence(unittest.TestCase): + """Word-by-word validation of frame_word and raw_consumed for sentences + containing emojis and special characters. + + Covers: + - Emoji in the middle of a sentence whose llm_text has surrounding XML tags. + - Emoji adjacent to currency symbols and punctuation. + - Multiple emojis scattered through a sentence with no XML tags. + - Emoji that does NOT appear in llm_text (safeguard must return None). + """ + + def _assert_word(self, tracker, word, expected_frame_word, expected_raw_consumed, idx): + msg = f"word {idx + 1} {repr(word)}" + self.assertEqual(tracker.get_word_for_frame(), expected_frame_word, f"{msg}: frame_word") + self.assertEqual(tracker.get_llm_consumed(), expected_raw_consumed, f"{msg}: raw_consumed") + + def test_emoji_in_middle_with_raw_tags(self): + """Sentence: 'Great job! 🎉 Well done.' wrapped in tags. + + The emoji word gets its own raw_consumed span ('🎉') because the + chars_for_frame==0 branch finds it at the correct offset in llm_text. + Words that follow the emoji pick up immediately after it. + """ + sentence = "Great job! 🎉 Well done." + llm_text = f"{sentence}" + words = ["Great", "job!", "🎉", "Well", "done."] + + expected_frame_words = ["Great", "job!", "🎉", "Well", "done."] + expected_raw_consumed = [ + "Great", # opening tag consumed with first alnum word + "job!", # " job!" stripped + "🎉", # emoji found directly in llm_text + "Well", # " Well" stripped; emoji already consumed + "done.", # last word sweeps closing tag + ] + + tracker = WordCompletionTracker(sentence, llm_text=llm_text) + for i, word in enumerate(words): + is_complete = tracker.add_word_and_check_complete(word) + self._assert_word(tracker, word, expected_frame_words[i], expected_raw_consumed[i], i) + if i == len(words) - 1: + self.assertTrue(is_complete, f"should be complete after final word {repr(word)}") + else: + self.assertFalse(is_complete, f"should not be complete after {repr(word)}") + + def test_emoji_with_currency_and_raw_tags(self): + """Sentence: 'Pay $50 😊 today!' wrapped in tags. + + Validates that currency symbols ($) are treated as non-alnum punctuation + and that the emoji token is correctly assigned its own raw span. + """ + sentence = "Pay $50 😊 today!" + llm_text = f"{sentence}" + words = ["Pay", "$50", "😊", "today!"] + + expected_frame_words = ["Pay", "$50", "😊", "today!"] + expected_raw_consumed = [ + "Pay", # opening tag consumed with first word + "$50", # " $50" stripped + "😊", # emoji found directly in llm_text + "today!", # last word sweeps closing tag + ] + + tracker = WordCompletionTracker(sentence, llm_text=llm_text) + for i, word in enumerate(words): + is_complete = tracker.add_word_and_check_complete(word) + self._assert_word(tracker, word, expected_frame_words[i], expected_raw_consumed[i], i) + if i == len(words) - 1: + self.assertTrue(is_complete, f"should be complete after final word {repr(word)}") + else: + self.assertFalse(is_complete, f"should not be complete after {repr(word)}") + + def test_multiple_emojis_no_tags(self): + """Sentence: 'Hello 😊 world 🎉 there!' with llm_text equal to the sentence. + + Two emojis at different positions in the sentence; each gets its own + raw_consumed span and neither disrupts the alnum cursor for the words + that follow. + """ + sentence = "Hello 😊 world 🎉 there!" + words = ["Hello", "😊", "world", "🎉", "there!"] + + expected_frame_words = ["Hello", "😊", "world", "🎉", "there!"] + expected_raw_consumed = [ + "Hello", # no leading tag + "😊", # emoji found directly + "world", # " world" stripped + "🎉", # second emoji found directly + "there!", # " there!" stripped; last word + ] + + tracker = WordCompletionTracker(sentence, llm_text=sentence) + for i, word in enumerate(words): + is_complete = tracker.add_word_and_check_complete(word) + self._assert_word(tracker, word, expected_frame_words[i], expected_raw_consumed[i], i) + if i == len(words) - 1: + self.assertTrue(is_complete, f"should be complete after final word {repr(word)}") + else: + self.assertFalse(is_complete, f"should not be complete after {repr(word)}") + + def test_emoji_absent_from_llm_text_returns_none(self): + """Sentence: 'See you soon 😊' but llm_text omits the emoji. + + The emoji is the last token the TTS returns for this frame. The alnum + content ('seeyousoon') is already complete after 'soon', so llm_text is + exhausted. The safeguard must set raw_consumed to None because the swept + span is empty and does not contain '😊'. + """ + sentence = "See you soon 😊" + llm_text = "See you soon" + words = ["See", "you", "soon", "😊"] + + expected_frame_words = ["See", "you", "soon", "😊"] + expected_raw_consumed = [ + "See", # opening tag consumed with first word + "you", # " you" stripped + "soon", # last alnum word sweeps closing tag + None, # safeguard: emoji not in exhausted llm_text + ] + + tracker = WordCompletionTracker(sentence, llm_text=llm_text) + for i, word in enumerate(words): + tracker.add_word_and_check_complete(word) + self._assert_word(tracker, word, expected_frame_words[i], expected_raw_consumed[i], i) + + self.assertTrue(tracker.is_complete) + + +class TestWordCompletionTrackerRemainingText(unittest.TestCase): + """Tests for get_remaining_tts_text() and get_remaining_llm_text(). + + These methods expose the unspoken suffix of tts_text / llm_text so that + _force_complete_spoken_slots() can emit a TTSTextFrame for any words that the + TTS provider dropped without sending a word-timestamp event. + """ + + # ------------------------------------------------------------------ # + # get_remaining_text # + # ------------------------------------------------------------------ # + + def test_remaining_text_before_any_words(self): + """Before any words arrive, the full expected text is remaining.""" + tracker = WordCompletionTracker("hello world") + self.assertEqual(tracker.get_remaining_tts_text(), "hello world") + + def test_remaining_text_after_partial_consumption(self): + """After one word is consumed, only the unspoken suffix is returned.""" + tracker = WordCompletionTracker("hello world") + tracker.add_word_and_check_complete("hello") + self.assertEqual(tracker.get_remaining_tts_text(), "world") + + def test_remaining_text_after_completion(self): + """After full completion get_remaining_text returns an empty string.""" + tracker = WordCompletionTracker("hello world") + tracker.add_word_and_check_complete("hello") + tracker.add_word_and_check_complete("world") + self.assertEqual(tracker.get_remaining_tts_text(), "") + + def test_remaining_text_strips_whitespace(self): + """Leading/trailing whitespace in the remaining portion is stripped.""" + tracker = WordCompletionTracker("hello world") + tracker.add_word_and_check_complete("hello") + # The tts_text suffix before stripping is " world". + result = tracker.get_remaining_tts_text() + self.assertEqual(result, "world") + self.assertFalse(result.startswith(" ")) + + def test_remaining_text_preserves_punctuation(self): + """Punctuation inside the remaining text is not removed.""" + tracker = WordCompletionTracker("Hello, world!") + tracker.add_word_and_check_complete("Hello,") + self.assertEqual(tracker.get_remaining_tts_text(), "world!") + + def test_remaining_text_after_force_complete(self): + """Force-completing the slot exhausts expected_raw_pos; remaining becomes empty.""" + tracker = WordCompletionTracker("number is") + tracker.add_word_and_check_complete("4111") # force-complete + self.assertTrue(tracker.is_complete) + self.assertEqual(tracker.get_remaining_tts_text(), "number is") + + def test_remaining_text_resets_after_reset(self): + """reset() resets expected_raw_pos; get_remaining_text returns the full text again.""" + tracker = WordCompletionTracker("hello world") + tracker.add_word_and_check_complete("hello") + tracker.add_word_and_check_complete("world") + self.assertEqual(tracker.get_remaining_tts_text(), "") + tracker.reset() + self.assertEqual(tracker.get_remaining_tts_text(), "hello world") + + # ------------------------------------------------------------------ # + # get_remaining_llm_text # + # ------------------------------------------------------------------ # + + def test_remaining_llm_text_none_when_no_llm_text(self): + """Without llm_text, get_remaining_llm_text always returns None.""" + tracker = WordCompletionTracker("hello world") + self.assertIsNone(tracker.get_remaining_llm_text()) + tracker.add_word_and_check_complete("hello") + self.assertIsNone(tracker.get_remaining_llm_text()) + + def test_remaining_llm_text_before_any_words(self): + """Before any words, the entire llm_text is remaining.""" + raw = "4111 1111" + tracker = WordCompletionTracker("4111 1111", llm_text=raw) + self.assertEqual(tracker.get_remaining_llm_text(), raw) + + def test_remaining_llm_text_after_partial_consumption(self): + """After one word, the raw cursor is past its span; only the tail is remaining.""" + raw = "4111 1111" + tracker = WordCompletionTracker("4111 1111", llm_text=raw) + tracker.add_word_and_check_complete("4111") + # raw_pos moves past "4111" (10 chars); remaining = " 1111" → strip → "1111" + self.assertEqual(tracker.get_remaining_llm_text(), "1111") + + def test_remaining_llm_text_none_after_completion(self): + """After full completion all llm_text is consumed; remaining is None.""" + raw = "4111 1111" + tracker = WordCompletionTracker("4111 1111", llm_text=raw) + tracker.add_word_and_check_complete("4111") + tracker.add_word_and_check_complete("1111") + self.assertIsNone(tracker.get_remaining_llm_text()) + + def test_remaining_llm_text_resets_after_reset(self): + """reset() resets the raw cursor; get_remaining_llm_text returns the full llm_text again.""" + raw = "hello world" + tracker = WordCompletionTracker("hello world", llm_text=raw) + tracker.add_word_and_check_complete("hello") + tracker.add_word_and_check_complete("world") + self.assertIsNone(tracker.get_remaining_llm_text()) + tracker.reset() + self.assertEqual(tracker.get_remaining_llm_text(), raw) + + def test_remaining_llm_text_strips_whitespace(self): + """Leading/trailing whitespace in the remaining raw portion is stripped.""" + raw = "hello world" + tracker = WordCompletionTracker("hello world", llm_text=raw) + tracker.add_word_and_check_complete("hello") + # raw_pos is past "hello" (11 chars); raw suffix = " world" → strip + result = tracker.get_remaining_llm_text() + self.assertIsNotNone(result) + self.assertFalse(result.startswith(" ")) + self.assertEqual(result, "world") + + def test_remaining_llm_text_after_force_complete(self): + """Force-complete sweeps all llm_text; remaining is None afterwards.""" + raw = "4111 1111 1111 1111" + tracker = WordCompletionTracker("4111 1111 1111 1111", llm_text=raw) + tracker.add_word_and_check_complete("4111") + tracker.add_word_and_check_complete("WRONG") # force-complete + self.assertTrue(tracker.is_complete) + self.assertIsNone(tracker.get_remaining_llm_text()) + + +class TestWordCompletionTrackerAccumulatedText(unittest.TestCase): + """Tests for get_accumulated_tts_text() and get_accumulated_llm_text(). + + These methods return the prefix of tts_text / llm_text that has already been + consumed by word-timestamp events — the complement of get_remaining_tts_text() + and get_remaining_llm_text() which return the unspoken suffix. + """ + + # ------------------------------------------------------------------ # + # get_accumulated_tts_text # + # ------------------------------------------------------------------ # + + def test_accumulated_tts_text_before_any_words(self): + """Before any words arrive, nothing has been consumed.""" + tracker = WordCompletionTracker("hello world") + self.assertEqual(tracker.get_accumulated_tts_text(), "") + + def test_accumulated_tts_text_after_partial_consumption(self): + """After one word the accumulated prefix covers exactly that word.""" + tracker = WordCompletionTracker("hello world") + tracker.add_word_and_check_complete("hello") + self.assertEqual(tracker.get_accumulated_tts_text(), "hello") + + def test_accumulated_tts_text_after_completion(self): + """After all words, the entire tts_text is accumulated.""" + tracker = WordCompletionTracker("hello world") + tracker.add_word_and_check_complete("hello") + tracker.add_word_and_check_complete("world") + self.assertEqual(tracker.get_accumulated_tts_text(), "hello world") + + def test_accumulated_tts_text_includes_trailing_punctuation(self): + """The TTS cursor consumes trailing punctuation, so it appears in the accumulated span.""" + tracker = WordCompletionTracker("Hello, world!") + tracker.add_word_and_check_complete("Hello,") + self.assertEqual(tracker.get_accumulated_tts_text(), "Hello,") + + def test_accumulated_tts_text_complements_remaining(self): + """accumulated + stripped-whitespace + remaining reconstructs tts_text.""" + tts_text = "hello world" + tracker = WordCompletionTracker(tts_text) + tracker.add_word_and_check_complete("hello") + accumulated = tracker.get_accumulated_tts_text() + remaining = tracker.get_remaining_tts_text() + # The space between tokens is stripped from remaining, so re-join with one. + self.assertEqual(accumulated + " " + remaining, tts_text) + + def test_accumulated_tts_text_after_force_complete(self): + """Force-complete does not advance the TTS cursor; accumulated stays at the + point reached before the mismatched word arrived.""" + tracker = WordCompletionTracker("number is") + tracker.add_word_and_check_complete("number") + tracker.add_word_and_check_complete("4111") # force-complete + self.assertEqual(tracker.get_accumulated_tts_text(), "number") + + def test_accumulated_tts_text_resets_after_reset(self): + """reset() resets the TTS cursor; accumulated returns empty string again.""" + tracker = WordCompletionTracker("hello world") + tracker.add_word_and_check_complete("hello") + tracker.add_word_and_check_complete("world") + self.assertEqual(tracker.get_accumulated_tts_text(), "hello world") + tracker.reset() + self.assertEqual(tracker.get_accumulated_tts_text(), "") + + # ------------------------------------------------------------------ # + # get_accumulated_llm_text # + # ------------------------------------------------------------------ # + + def test_accumulated_llm_text_none_when_no_llm_text(self): + """Without llm_text, get_accumulated_llm_text always returns None.""" + tracker = WordCompletionTracker("hello world") + self.assertIsNone(tracker.get_accumulated_llm_text()) + tracker.add_word_and_check_complete("hello") + self.assertIsNone(tracker.get_accumulated_llm_text()) + + def test_accumulated_llm_text_before_any_words(self): + """Before any words, nothing has been consumed from llm_text.""" + tracker = WordCompletionTracker("4111 1111", llm_text="4111 1111") + self.assertEqual(tracker.get_accumulated_llm_text(), "") + + def test_accumulated_llm_text_after_partial_consumption(self): + """After one word, the consumed prefix includes the opening tag and first word.""" + tracker = WordCompletionTracker("4111 1111", llm_text="4111 1111") + tracker.add_word_and_check_complete("4111") + self.assertEqual(tracker.get_accumulated_llm_text(), "4111") + + def test_accumulated_llm_text_after_completion(self): + """After all words, the entire llm_text is accumulated (closing tag included).""" + tracker = WordCompletionTracker("4111 1111", llm_text="4111 1111") + tracker.add_word_and_check_complete("4111") + tracker.add_word_and_check_complete("1111") + self.assertEqual(tracker.get_accumulated_llm_text(), "4111 1111") + + def test_accumulated_llm_text_complements_remaining(self): + """accumulated + stripped-whitespace + remaining reconstructs llm_text.""" + llm = "4111 1111" + tracker = WordCompletionTracker("4111 1111", llm_text=llm) + tracker.add_word_and_check_complete("4111") + accumulated = tracker.get_accumulated_llm_text() + remaining = tracker.get_remaining_llm_text() + self.assertEqual(accumulated.rstrip() + " " + remaining, llm) + + def test_accumulated_llm_text_after_force_complete(self): + """Force-complete sweeps all remaining llm_text; accumulated covers the full string.""" + llm = "4111 1111 1111 1111" + tracker = WordCompletionTracker("4111 1111 1111 1111", llm_text=llm) + tracker.add_word_and_check_complete("4111") + tracker.add_word_and_check_complete("WRONG") # force-complete sweeps remaining llm_text + self.assertEqual(tracker.get_accumulated_llm_text(), llm) + + def test_accumulated_llm_text_resets_after_reset(self): + """reset() resets the LLM cursor; accumulated returns empty string again.""" + tracker = WordCompletionTracker("hello world", llm_text="hello world") + tracker.add_word_and_check_complete("hello") + tracker.add_word_and_check_complete("world") + self.assertEqual(tracker.get_accumulated_llm_text(), "hello world") + tracker.reset() + self.assertEqual(tracker.get_accumulated_llm_text(), "") + + +class TestWordCompletionTrackerUnicodeSymbolSubstitution(unittest.TestCase): + """Guards against the regression where ElevenLabs maps Unicode symbols such + as '→' to ASCII punctuation like '-' in word-timestamp events. + + The literal-substring check in _symbol_word_belongs_here failed to find '-' + inside '→ Santiago…', which caused premature force-completion of the whole + frame after 'Paulo' was consumed. The symbol-substitution fallback (check + whether the next non-space char in the TTS text is itself a non-alnum symbol) + must accept the substituted '-' and keep the tracker running normally. + """ + + # The exact sentence that revealed the bug. + SENTENCE = "- Example route: São Paulo → Santiago (Chile) → Auckland (New Zealand)." + + # Words as ElevenLabs emits them: both '→' chars are reported as '-'. + ELEVENLABS_WORDS = [ + "-", + "Example", + "route:", + "São", + "Paulo", + "-", # ElevenLabs substitution for first → + "Santiago", + "(Chile)", + "-", # ElevenLabs substitution for second → + "Auckland", + "(New", + "Zealand).", + ] + + def test_arrow_as_dash_belongs_here_before_first_arrow(self): + """After consuming up to 'Paulo', word_belongs_here('-') must return True. + + The next character in the TTS text is '→'. ElevenLabs sends '-' instead; + the symbol-substitution fallback must accept it without force-completing. + """ + tracker = WordCompletionTracker(self.SENTENCE) + for word in ["-", "Example", "route:", "São", "Paulo"]: + tracker.add_word_and_check_complete(word) + self.assertTrue(tracker.word_belongs_here("-")) + + def test_first_arrow_substitution_does_not_force_complete(self): + """Processing '-' in place of the first '→' must not complete the tracker. + + Without the fix, word_belongs_here('-') returned False here and the + force-complete path fired, marking the entire frame done prematurely. + """ + tracker = WordCompletionTracker(self.SENTENCE) + for word in ["-", "Example", "route:", "São", "Paulo"]: + tracker.add_word_and_check_complete(word) + result = tracker.add_word_and_check_complete("-") + self.assertFalse(result, "tracker must not be complete after the first → substitution") + self.assertFalse(tracker.is_complete) + + def test_arrow_as_dash_belongs_here_before_second_arrow(self): + """After consuming through '(Chile)', word_belongs_here('-') must return True again.""" + tracker = WordCompletionTracker(self.SENTENCE) + for word in ["-", "Example", "route:", "São", "Paulo", "-", "Santiago", "(Chile)"]: + tracker.add_word_and_check_complete(word) + self.assertTrue(tracker.word_belongs_here("-")) + + def test_completes_after_all_elevenlabs_words(self): + """Feeding the full ElevenLabs word-timestamp stream must complete the tracker.""" + tracker = WordCompletionTracker(self.SENTENCE) + results = [tracker.add_word_and_check_complete(w) for w in self.ELEVENLABS_WORDS] + self.assertTrue(results[-1], "tracker should be complete after the last word") + self.assertTrue(tracker.is_complete) + + def test_only_last_word_triggers_completion(self): + """No intermediate word should complete the tracker.""" + tracker = WordCompletionTracker(self.SENTENCE) + for word in self.ELEVENLABS_WORDS[:-1]: + result = tracker.add_word_and_check_complete(word) + self.assertFalse(result, f"should not be complete after '{word}'") + self.assertTrue(tracker.add_word_and_check_complete(self.ELEVENLABS_WORDS[-1])) + + +class TestWordCompletionTrackerCJK(unittest.TestCase): + """Completion tracking for CJK scripts: Korean, Japanese, and Chinese. + + Korean words are fed one at a time (Cartesia returns each Hangul word as a + separate timestamp event). Japanese and Chinese characters are fed as + combined groups (Cartesia merges all chars in one timestamp message into a + single token before calling add_word_timestamps). + """ + + # --- Korean --- + + def test_korean_word_by_word_completion(self): + """Korean words fed one at a time complete the tracker correctly.""" + sentence = "저는 여러분의 AI 어시스턴트입니다." + words = ["저는", "여러분의", "AI", "어시스턴트입니다."] + tracker = WordCompletionTracker(sentence) + for word in words[:-1]: + self.assertFalse(tracker.add_word_and_check_complete(word)) + self.assertTrue(tracker.add_word_and_check_complete(words[-1])) + + def test_korean_normalized_char_count_matches_raw_alnum(self): + """Each Hangul syllable must normalize to exactly one char. + + The NFKD decomposition would expand each syllable into 2-3 conjoining + jamo, making the normalized length much larger than the raw alnum count + and breaking the _advance_by_alnums cursor. NFD-per-char normalization + must keep each syllable as a single character. + """ + samples = ["저는여러분의", "안녕하세요", "어시스턴트"] + for text in samples: + raw_count = sum(1 for c in text if c.isalnum()) + norm_count = len(WordCompletionTracker._normalize(text)) + self.assertEqual( + norm_count, + raw_count, + f"_normalize({text!r}): got {norm_count} chars, want {raw_count}", + ) + + def test_korean_force_complete_remaining_text_is_correct(self): + """After the first Korean word the TTS cursor is at the right position. + + get_remaining_tts_text() must return the unspoken suffix verbatim so + force_complete can emit a correct TTSTextFrame for it. + """ + sentence = "저는 여러분의 AI 어시스턴트입니다." + tracker = WordCompletionTracker(sentence) + tracker.add_word_and_check_complete("저는") + self.assertEqual(tracker.get_remaining_tts_text(), "여러분의 AI 어시스턴트입니다.") + + def test_korean_mixed_with_latin(self): + """Latin tokens (e.g. 'AI') embedded in Korean text are handled correctly.""" + sentence = "AI 어시스턴트입니다." + tracker = WordCompletionTracker(sentence) + self.assertFalse(tracker.add_word_and_check_complete("AI")) + self.assertTrue(tracker.add_word_and_check_complete("어시스턴트입니다.")) + + def test_korean_word_belongs_here(self): + """word_belongs_here distinguishes Korean words based on remaining content.""" + tracker = WordCompletionTracker("저는 여러분의") + self.assertTrue(tracker.word_belongs_here("저는")) + # "여러분의" starts with chars that follow "저는", so before consuming + # "저는" the next word doesn't belong here yet. + self.assertFalse(tracker.word_belongs_here("여러분의")) + + tracker.add_word_and_check_complete("저는") + self.assertTrue(tracker.word_belongs_here("여러분의")) + + # --- Japanese --- + + def test_japanese_single_combined_group_completes(self): + """A single Cartesia-style combined group (all chars merged) completes the frame.""" + # Cartesia merges ["こ","ん","に","ち","は","。"] into "こんにちは。" + tracker = WordCompletionTracker("こんにちは。") + self.assertTrue(tracker.add_word_and_check_complete("こんにちは。")) + + def test_japanese_multiple_groups_for_one_frame(self): + """Two Cartesia timestamp groups for one Japanese frame complete in sequence.""" + sentence = "こんにちは、私はあなたの" + tracker = WordCompletionTracker(sentence) + self.assertFalse(tracker.add_word_and_check_complete("こんにちは、私")) + self.assertTrue(tracker.add_word_and_check_complete("はあなたの")) + + def test_japanese_force_complete_remaining_text(self): + """After the first Japanese group the cursor sits at the right position.""" + sentence = "こんにちは、私はあなたの" + tracker = WordCompletionTracker(sentence) + tracker.add_word_and_check_complete("こんにちは、私") + self.assertEqual(tracker.get_remaining_tts_text(), "はあなたの") + + def test_japanese_punctuation_not_counted_toward_completion(self): + """Japanese punctuation (。、) is not alphanumeric and must not block completion.""" + tracker = WordCompletionTracker("こんにちは。") + # "こんにちは" covers all 5 alnum chars; "。" adds 0. + self.assertTrue(tracker.add_word_and_check_complete("こんにちは。")) + + # --- Chinese --- + + def test_chinese_single_combined_group_completes(self): + """A single Cartesia-style combined Chinese group completes the frame.""" + # Cartesia merges ["你","好",",","我","是"] into "你好,我是" + tracker = WordCompletionTracker("你好,我是") + self.assertTrue(tracker.add_word_and_check_complete("你好,我是")) + + def test_chinese_multiple_groups_for_one_frame(self): + """Two Cartesia timestamp groups for one Chinese frame complete in sequence.""" + sentence = "你好,我是你的智能" + tracker = WordCompletionTracker(sentence) + self.assertFalse(tracker.add_word_and_check_complete("你好,我是")) + self.assertTrue(tracker.add_word_and_check_complete("你的智能")) + + def test_chinese_force_complete_remaining_text(self): + """After the first Chinese group the cursor sits at the right position.""" + sentence = "你好,我是你的智能" + tracker = WordCompletionTracker(sentence) + tracker.add_word_and_check_complete("你好,我是") + self.assertEqual(tracker.get_remaining_tts_text(), "你的智能") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_word_timestamp_utils.py b/tests/test_word_timestamp_utils.py new file mode 100644 index 000000000..9c330d286 --- /dev/null +++ b/tests/test_word_timestamp_utils.py @@ -0,0 +1,90 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import unittest + +from pipecat.utils.text.word_timestamp_utils import merge_punct_tokens + + +class TestMergePunctTokens(unittest.TestCase): + def test_empty_list(self): + self.assertEqual(merge_punct_tokens([]), []) + + def test_all_alnum_words_pass_through(self): + input = [("hello", 0.0), ("world", 1.0)] + self.assertEqual(merge_punct_tokens(input), [("hello", 0.0), ("world", 1.0)]) + + def test_trailing_space_merged_and_stripped(self): + input = [("I", 0.0), (" ", 0.2)] + self.assertEqual(merge_punct_tokens(input), [("I", 0.0)]) + + def test_comma_space_merged_and_stripped(self): + input = [("questions", 1.0), (", ", 1.2), ("explain", 1.4)] + self.assertEqual(merge_punct_tokens(input), [("questions,", 1.0), ("explain", 1.4)]) + + def test_leading_space_with_no_preceding_word_discarded(self): + input = [(" ", 0.0), ("hello", 0.5)] + self.assertEqual(merge_punct_tokens(input), [("hello", 0.5)]) + + def test_leading_empty_string_discarded(self): + input = [("", 0.0), ("hello", 0.5)] + self.assertEqual(merge_punct_tokens(input), [("hello", 0.5)]) + + def test_multiple_consecutive_punct_tokens_merged_and_stripped(self): + input = [("word", 0.0), (",", 0.1), (" ", 0.2), ("next", 0.3)] + self.assertEqual(merge_punct_tokens(input), [("word,", 0.0), ("next", 0.3)]) + + def test_timestamp_of_preceding_word_is_kept(self): + """Merged punct tokens adopt the preceding word's timestamp.""" + input = [("hello", 2.5), (",", 2.7)] + result = merge_punct_tokens(input) + self.assertEqual(result, [("hello,", 2.5)]) + + def test_xml_tag_only_token_is_treated_as_punct(self): + """A token that is only an XML tag (no alnum chars) merges into the preceding word.""" + input = [("word", 0.0), ("", 0.1), ("next", 0.3)] + self.assertEqual(merge_punct_tokens(input), [("word", 0.0), ("next", 0.3)]) + + def test_xml_tag_with_alnum_content_passes_through(self): + """A token like '123' has alnum chars after stripping tags.""" + input = [("123", 0.0), ("and", 0.5)] + self.assertEqual(merge_punct_tokens(input), [("123", 0.0), ("and", 0.5)]) + + def test_inworld_style_full_stream(self): + """Full Inworld-style raw stream produces expected merged and stripped output.""" + raw = [ + ("", 0.0), + ("I", 0.1), + (" ", 0.2), + ("can", 0.3), + (" ", 0.4), + ("answer", 0.5), + (" ", 0.6), + ("questions", 0.7), + (", ", 0.8), + ("explain", 0.9), + (" ", 1.0), + ("things", 1.1), + (".", 1.2), + ] + expected = [ + ("I", 0.1), + ("can", 0.3), + ("answer", 0.5), + ("questions,", 0.7), + ("explain", 0.9), + ("things.", 1.1), + ] + self.assertEqual(merge_punct_tokens(raw), expected) + + def test_only_punct_tokens_returns_empty(self): + """A list containing only punct/space tokens produces an empty result.""" + input = [(" ", 0.0), (",", 0.1), (".", 0.2)] + self.assertEqual(merge_punct_tokens(input), []) + + +if __name__ == "__main__": + unittest.main()