test: add automated tests for word tracking, frame sequencing, and Cartesia TTS

Adds tests for AggregatedFrameSequencer, WordCompletionTracker, and
word_timestamp_utils (including CJK language scenarios). Updates existing
Cartesia TTS and TTS frame ordering tests to cover the new behaviours.
This commit is contained in:
filipi87
2026-05-20 10:03:26 -03:00
parent e1bdee598c
commit 81bb81c1d0
5 changed files with 2767 additions and 27 deletions

View File

@@ -0,0 +1,612 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Tests for AggregatedFrameSequencer.
All methods on the sequencer are synchronous and return lists of frames,
so no async machinery is needed here.
Test groups:
- register_skipped: immediate flush vs. blocked by a preceding spoken slot
- register_spoken / complete_spoken_slot: push_text_frames=True path
- flush: pts propagation, transport_destination, stops at incomplete spoken slot
- process_word: normal, completing, passthrough, raw_text propagation
- process_word overflow: single token spanning two slot boundaries
- process_word force-complete via belongs_here failure
- force_complete: remaining text emission, raw_text, corrupt raw discard, slot ordering
- clear: resets all state
"""
import unittest
from pipecat.frames.frames import AggregatedTextFrame, AggregationType, TTSTextFrame
from pipecat.utils.context.aggregated_frame_sequencer import AggregatedFrameSequencer
from pipecat.utils.context.word_completion_tracker import WordCompletionTracker
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _seq() -> AggregatedFrameSequencer:
return AggregatedFrameSequencer(name="test")
def _spoken_frame(text: str) -> AggregatedTextFrame:
return AggregatedTextFrame(text, AggregationType.SENTENCE)
def _skipped_frame(text: str) -> AggregatedTextFrame:
return AggregatedTextFrame(text, "code")
def _tracker(tts_text: str, llm_text: str | None = None) -> WordCompletionTracker:
return WordCompletionTracker(tts_text, llm_text=llm_text)
# ---------------------------------------------------------------------------
# register_skipped
# ---------------------------------------------------------------------------
class TestRegisterSkipped(unittest.TestCase):
def test_emits_immediately_with_empty_queue(self):
seq = _seq()
frame = _skipped_frame("code block")
result = seq.register_skipped(frame, "ctx1", None)
self.assertEqual(len(result), 1)
self.assertIs(result[0], frame)
def test_sets_append_to_context_true(self):
seq = _seq()
frame = _skipped_frame("code")
seq.register_skipped(frame, "ctx1", None)
self.assertTrue(frame.append_to_context)
def test_sets_context_id_on_frame(self):
seq = _seq()
frame = _skipped_frame("code")
seq.register_skipped(frame, "ctx42", None)
self.assertEqual(frame.context_id, "ctx42")
def test_sets_transport_destination(self):
seq = _seq()
frame = _skipped_frame("code")
result = seq.register_skipped(frame, "ctx1", "dest-A")
self.assertEqual(result[0].transport_destination, "dest-A")
def test_blocked_by_incomplete_spoken_slot(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True)
result = seq.register_skipped(_skipped_frame("code"), "ctx2", None)
self.assertEqual(result, [])
def test_emits_immediately_after_already_complete_spoken_slot(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hi"), "ctx1", tracker=None, append_to_context=True)
seq.complete_spoken_slot()
result = seq.register_skipped(_skipped_frame("code"), "ctx2", None)
self.assertEqual(len(result), 1)
def test_multiple_skipped_before_any_spoken_all_emit(self):
seq = _seq()
r1 = seq.register_skipped(_skipped_frame("code1"), "ctx1", None)
r2 = seq.register_skipped(_skipped_frame("code2"), "ctx2", None)
self.assertEqual(len(r1), 1)
self.assertEqual(len(r2), 1)
# ---------------------------------------------------------------------------
# register_spoken / complete_spoken_slot (push_text_frames=True path)
# ---------------------------------------------------------------------------
class TestCompleteSpokenSlot(unittest.TestCase):
def test_noop_with_empty_queue(self):
seq = _seq()
self.assertEqual(seq.complete_spoken_slot(), [])
def test_marks_slot_complete_and_flushes_skipped(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True)
skipped = _skipped_frame("code")
seq.register_skipped(skipped, "ctx2", None) # blocked
result = seq.complete_spoken_slot()
self.assertEqual(len(result), 1)
self.assertIs(result[0], skipped)
self.assertTrue(skipped.append_to_context)
def test_only_first_pending_slot_is_marked(self):
seq = _seq()
seq.register_spoken(_spoken_frame("one"), "ctx1", tracker=None, append_to_context=True)
seq.register_spoken(_spoken_frame("two"), "ctx2", tracker=None, append_to_context=True)
skipped = _skipped_frame("code")
seq.register_skipped(skipped, "ctx3", None)
# ctx2 still blocks the skipped frame
result = seq.complete_spoken_slot()
self.assertEqual(result, [])
def test_skipped_flushes_after_all_preceding_spoken_complete(self):
seq = _seq()
seq.register_spoken(_spoken_frame("one"), "ctx1", tracker=None, append_to_context=True)
seq.register_spoken(_spoken_frame("two"), "ctx2", tracker=None, append_to_context=True)
skipped = _skipped_frame("code")
seq.register_skipped(skipped, "ctx3", None)
seq.complete_spoken_slot() # completes ctx1
result = seq.complete_spoken_slot() # completes ctx2 → flush skipped
self.assertEqual(len(result), 1)
self.assertIs(result[0], skipped)
# ---------------------------------------------------------------------------
# flush
# ---------------------------------------------------------------------------
class TestFlush(unittest.TestCase):
def test_empty_queue_returns_empty(self):
self.assertEqual(_seq().flush(), [])
def test_stops_at_incomplete_spoken_slot(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True)
seq.register_skipped(_skipped_frame("code"), "ctx2", None)
self.assertEqual(seq.flush(), [])
def test_last_word_pts_assigned_to_skipped_frame(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
skipped = _skipped_frame("code")
seq.register_skipped(skipped, "ctx2", None)
# process_word("hello") completes the spoken slot and calls flush(last_word_pts=77)
result = seq.process_word("hello", pts=77, context_id="ctx1")
flushed = [f for f in result if isinstance(f, AggregatedTextFrame) and f.text == "code"]
self.assertEqual(len(flushed), 1)
self.assertEqual(flushed[0].pts, 77)
def test_complete_spoken_slots_are_swept(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True)
seq.complete_spoken_slot()
# Queue should be empty after sweeping the complete spoken slot
self.assertEqual(seq._slots, [])
# ---------------------------------------------------------------------------
# process_word — basic
# ---------------------------------------------------------------------------
class TestProcessWordBasic(unittest.TestCase):
def _seq_with_spoken(self, text: str, ctx: str = "ctx1", append: bool = True):
seq = _seq()
seq.register_spoken(_spoken_frame(text), ctx, _tracker(text), append)
return seq
def test_returns_tts_text_frame(self):
seq = self._seq_with_spoken("hello")
result = seq.process_word("hello", pts=100, context_id="ctx1")
self.assertEqual(len(result), 1)
self.assertIsInstance(result[0], TTSTextFrame)
def test_frame_text_and_pts(self):
seq = self._seq_with_spoken("hello")
result = seq.process_word("hello", pts=100, context_id="ctx1")
self.assertEqual(result[0].text, "hello")
self.assertEqual(result[0].pts, 100)
def test_frame_context_id(self):
seq = self._seq_with_spoken("hello", ctx="ctx99")
result = seq.process_word("hello", pts=1, context_id="ctx99")
self.assertEqual(result[0].context_id, "ctx99")
def test_append_to_context_true(self):
seq = self._seq_with_spoken("hello", append=True)
result = seq.process_word("hello", pts=1, context_id="ctx1")
self.assertTrue(result[0].append_to_context)
def test_append_to_context_false(self):
seq = self._seq_with_spoken("hello", append=False)
result = seq.process_word("hello", pts=1, context_id="ctx1")
self.assertFalse(result[0].append_to_context)
def test_non_completing_word_does_not_flush_skipped(self):
seq = self._seq_with_spoken("hello world")
seq.register_skipped(_skipped_frame("code"), "ctx2", None)
result = seq.process_word("hello", pts=10, context_id="ctx1")
self.assertEqual(len(result), 1)
self.assertIsInstance(result[0], TTSTextFrame)
def test_completing_word_flushes_blocked_skipped_frame(self):
seq = self._seq_with_spoken("hello")
skipped = _skipped_frame("code")
seq.register_skipped(skipped, "ctx2", None)
result = seq.process_word("hello", pts=50, context_id="ctx1")
self.assertEqual(len(result), 2)
self.assertIsInstance(result[0], TTSTextFrame)
self.assertIs(result[1], skipped)
def test_last_of_multiple_words_flushes_skipped(self):
seq = self._seq_with_spoken("hello world")
skipped = _skipped_frame("code")
seq.register_skipped(skipped, "ctx2", None)
seq.process_word("hello", pts=10, context_id="ctx1")
result = seq.process_word("world", pts=20, context_id="ctx1")
self.assertTrue(any(f is skipped for f in result))
def test_no_active_slot_emits_passthrough(self):
seq = _seq()
result = seq.process_word("hello", pts=1, context_id="ctx-unknown")
self.assertEqual(len(result), 1)
self.assertIsInstance(result[0], TTSTextFrame)
self.assertEqual(result[0].text, "hello")
self.assertEqual(result[0].context_id, "ctx-unknown")
def test_passthrough_uses_default_append_to_context_true(self):
seq = _seq()
result = seq.process_word("hello", pts=1, context_id="ctx-unknown")
self.assertTrue(result[0].append_to_context)
def test_unrecognised_word_emits_passthrough(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True)
# "zzz" doesn't belong to "hello world" and there is no next slot
result = seq.process_word("zzz", pts=5, context_id="ctx1")
self.assertEqual(len(result), 1)
self.assertEqual(result[0].text, "zzz")
# ---------------------------------------------------------------------------
# process_word — raw_text propagation
# ---------------------------------------------------------------------------
class TestProcessWordRawText(unittest.TestCase):
def test_raw_text_split_across_word_frames(self):
seq = _seq()
seq.register_spoken(
_spoken_frame("4111 1111"),
"ctx1",
WordCompletionTracker("4111 1111", llm_text="<card>4111 1111</card>"),
append_to_context=True,
)
r1 = seq.process_word("4111", pts=10, context_id="ctx1")
r2 = seq.process_word("1111", pts=20, context_id="ctx1")
self.assertEqual(r1[0].raw_text, "<card>4111")
last_word_frames = [f for f in r2 if isinstance(f, TTSTextFrame)]
self.assertEqual(last_word_frames[0].raw_text, "1111</card>")
def test_raw_text_none_when_no_llm_text(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
result = seq.process_word("hello", pts=1, context_id="ctx1")
self.assertIsNone(result[0].raw_text)
# ---------------------------------------------------------------------------
# process_word — overflow (single token spanning two slots)
# ---------------------------------------------------------------------------
class TestProcessWordOverflow(unittest.TestCase):
def test_overflow_produces_two_tts_text_frames(self):
seq = _seq()
seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True)
seq.register_spoken(_spoken_frame("def"), "ctx2", _tracker("def"), True)
result = seq.process_word("abcdef", pts=100, context_id="ctx1")
word_frames = [f for f in result if isinstance(f, TTSTextFrame)]
self.assertEqual(len(word_frames), 2)
self.assertEqual(word_frames[0].text, "abc")
self.assertEqual(word_frames[1].text, "def")
def test_overflow_assigns_correct_context_ids(self):
seq = _seq()
seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True)
seq.register_spoken(_spoken_frame("def"), "ctx2", _tracker("def"), True)
result = seq.process_word("abcdef", pts=100, context_id="ctx1")
word_frames = [f for f in result if isinstance(f, TTSTextFrame)]
self.assertEqual(word_frames[0].context_id, "ctx1")
self.assertEqual(word_frames[1].context_id, "ctx2")
def test_overflow_completing_next_slot_flushes_skipped(self):
seq = _seq()
seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True)
seq.register_spoken(_spoken_frame("def"), "ctx2", _tracker("def"), True)
skipped = _skipped_frame("code")
seq.register_skipped(skipped, "ctx3", None) # blocked behind ctx2
result = seq.process_word("abcdef", pts=100, context_id="ctx1")
self.assertTrue(any(f is skipped for f in result))
def test_overflow_not_completing_next_slot_does_not_flush_skipped(self):
seq = _seq()
seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True)
seq.register_spoken(_spoken_frame("def ghi"), "ctx2", _tracker("def ghi"), True)
skipped = _skipped_frame("code")
seq.register_skipped(skipped, "ctx3", None)
# "abcdef" overflows: "def" goes to ctx2, but ctx2 still expects " ghi"
result = seq.process_word("abcdef", pts=100, context_id="ctx1")
self.assertFalse(any(f is skipped for f in result))
# ---------------------------------------------------------------------------
# process_word — force-complete via word_belongs_here failure
# ---------------------------------------------------------------------------
class TestProcessWordForcesComplete(unittest.TestCase):
def test_word_for_next_slot_force_completes_current(self):
"""When a word belongs to the next slot but not the current, the current
slot is force-completed and the word is routed to the next slot."""
seq = _seq()
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
seq.register_spoken(_spoken_frame("world"), "ctx2", _tracker("world"), True)
# "world" doesn't belong to ctx1 but belongs to ctx2
result = seq.process_word("world", pts=50, context_id="ctx2")
word_frames = [f for f in result if isinstance(f, TTSTextFrame)]
texts = {f.text for f in word_frames}
self.assertIn("world", texts)
def test_force_complete_then_overflow_flushes_skipped(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
seq.register_spoken(_spoken_frame("world"), "ctx2", _tracker("world"), True)
skipped = _skipped_frame("code")
seq.register_skipped(skipped, "ctx3", None)
# "world" force-completes ctx1 and completes ctx2 via overflow
result = seq.process_word("world", pts=50, context_id="ctx2")
self.assertTrue(any(f is skipped for f in result))
# ---------------------------------------------------------------------------
# force_complete
# ---------------------------------------------------------------------------
class TestForceComplete(unittest.TestCase):
def test_emits_remaining_text_when_word_dropped(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True)
seq.process_word("hello", pts=10, context_id="ctx1") # "world" never arrives
result = seq.force_complete(last_word_pts=50)
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
self.assertEqual(len(tts_frames), 1)
self.assertEqual(tts_frames[0].text, "world")
self.assertEqual(tts_frames[0].pts, 50)
def test_emits_full_text_when_no_words_arrived(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True)
result = seq.force_complete(last_word_pts=0)
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
self.assertEqual(len(tts_frames), 1)
self.assertEqual(tts_frames[0].text, "hello world")
def test_already_complete_slot_emits_nothing(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hi"), "ctx1", _tracker("hi"), True)
seq.process_word("hi", pts=5, context_id="ctx1") # completes normally
result = seq.force_complete(last_word_pts=10)
self.assertEqual(result, [])
def test_flushes_skipped_frames_after_completing(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
skipped = _skipped_frame("code")
seq.register_skipped(skipped, "ctx2", None)
result = seq.force_complete(last_word_pts=20)
self.assertTrue(any(f is skipped for f in result))
self.assertTrue(skipped.append_to_context)
def test_propagates_raw_text(self):
seq = _seq()
seq.register_spoken(
_spoken_frame("4111 1111"),
"ctx1",
WordCompletionTracker("4111 1111", llm_text="<card>4111 1111</card>"),
append_to_context=True,
)
seq.process_word("4111", pts=10, context_id="ctx1") # "1111" never arrives
result = seq.force_complete(last_word_pts=20)
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
self.assertEqual(tts_frames[0].text, "1111")
self.assertEqual(tts_frames[0].raw_text, "1111</card>")
def test_discards_corrupt_raw_remaining(self):
"""raw_remaining is discarded when it does not contain remaining_text."""
seq = _seq()
# "abc" normalized ≠ "xyz" normalized — any remaining won't be in raw_remaining
seq.register_spoken(
_spoken_frame("abc"),
"ctx1",
WordCompletionTracker("abc", llm_text="xyz"),
append_to_context=True,
)
result = seq.force_complete(last_word_pts=0)
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
self.assertEqual(len(tts_frames), 1)
self.assertEqual(tts_frames[0].text, "abc")
self.assertIsNone(tts_frames[0].raw_text) # discarded due to corruption
def test_slot_without_tracker_just_marks_complete_and_flushes(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True)
skipped = _skipped_frame("code")
seq.register_skipped(skipped, "ctx2", None)
result = seq.force_complete(last_word_pts=0)
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
self.assertEqual(tts_frames, []) # no tracker → no word frame
self.assertTrue(any(f is skipped for f in result))
def test_multiple_incomplete_slots_all_emitted(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
seq.register_spoken(_spoken_frame("world"), "ctx2", _tracker("world"), True)
result = seq.force_complete(last_word_pts=0)
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
texts = {f.text for f in tts_frames}
self.assertIn("hello", texts)
self.assertIn("world", texts)
# ---------------------------------------------------------------------------
# clear
# ---------------------------------------------------------------------------
class TestClear(unittest.TestCase):
def test_clears_slots(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
seq.register_skipped(_skipped_frame("code"), "ctx2", None)
seq.clear()
self.assertEqual(seq._slots, [])
def test_clears_context_map(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
seq.clear()
self.assertEqual(seq._context_append_to_context, {})
def test_after_clear_skipped_emits_immediately(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
seq.clear()
frame = _skipped_frame("code")
result = seq.register_skipped(frame, "ctx2", None)
self.assertEqual(len(result), 1)
def test_after_clear_process_word_uses_passthrough(self):
seq = _seq()
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
seq.clear()
result = seq.process_word("hello", pts=1, context_id="ctx1")
# No active slot after clear → passthrough
self.assertEqual(len(result), 1)
self.assertEqual(result[0].text, "hello")
# ---------------------------------------------------------------------------
# CJK languages — Korean, Japanese, Chinese
# ---------------------------------------------------------------------------
class TestCJKLanguages(unittest.TestCase):
"""Sequencer behaviour for CJK language scenarios.
Korean: Cartesia returns each word as a separate timestamp event (one word
per process_word call). Japanese/Chinese: Cartesia merges all characters
in one timestamp message into a single combined token before calling
process_word.
"""
# --- Korean ---
def test_korean_word_by_word_completes_slot_and_flushes_skipped(self):
"""Korean words fed one at a time complete the spoken slot and unblock a skipped frame."""
seq = _seq()
sentence = "저는 여러분의 AI 어시스턴트입니다."
words = ["저는", "여러분의", "AI", "어시스턴트입니다."]
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
skipped = _skipped_frame("[code]")
seq.register_skipped(skipped, "ctx2", None)
# Skipped stays blocked until the last word arrives
for word in words[:-1]:
partial = seq.process_word(word, pts=100, context_id="ctx1")
self.assertFalse(any(f is skipped for f in partial))
result = seq.process_word(words[-1], pts=200, context_id="ctx1")
self.assertTrue(any(f is skipped for f in result))
def test_korean_force_complete_emits_correct_remaining_text(self):
"""After one Korean word, force_complete emits the correct unspoken suffix."""
seq = _seq()
sentence = "저는 여러분의 AI 어시스턴트입니다."
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
seq.process_word("저는", pts=10, context_id="ctx1")
result = seq.force_complete(last_word_pts=50)
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
self.assertEqual(len(tts_frames), 1)
self.assertEqual(tts_frames[0].text, "여러분의 AI 어시스턴트입니다.")
self.assertEqual(tts_frames[0].pts, 50)
# --- Japanese ---
def test_japanese_combined_groups_complete_spoken_slot(self):
"""Two Cartesia-style combined Japanese groups complete the slot and flush skipped."""
seq = _seq()
sentence = "こんにちは、私はあなたの"
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
skipped = _skipped_frame("[skipped]")
seq.register_skipped(skipped, "ctx2", None)
r1 = seq.process_word("こんにちは、私", pts=100, context_id="ctx1")
self.assertFalse(any(f is skipped for f in r1))
r2 = seq.process_word("はあなたの", pts=200, context_id="ctx1")
self.assertTrue(any(f is skipped for f in r2))
def test_japanese_force_complete_emits_remaining_chars(self):
"""After the first Japanese combined group, force_complete emits the rest."""
seq = _seq()
sentence = "こんにちは、私はあなたの"
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
seq.process_word("こんにちは、私", pts=10, context_id="ctx1")
result = seq.force_complete(last_word_pts=50)
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
self.assertEqual(len(tts_frames), 1)
self.assertEqual(tts_frames[0].text, "はあなたの")
# --- Chinese ---
def test_chinese_combined_groups_complete_spoken_slot(self):
"""Two Cartesia-style combined Chinese groups complete the slot and flush skipped."""
seq = _seq()
sentence = "你好,我是你的智能"
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
skipped = _skipped_frame("[skipped]")
seq.register_skipped(skipped, "ctx2", None)
r1 = seq.process_word("你好,我是", pts=100, context_id="ctx1")
self.assertFalse(any(f is skipped for f in r1))
r2 = seq.process_word("你的智能", pts=200, context_id="ctx1")
self.assertTrue(any(f is skipped for f in r2))
def test_chinese_force_complete_emits_remaining_chars(self):
"""After the first Chinese combined group, force_complete emits the rest."""
seq = _seq()
sentence = "你好,我是你的智能"
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
seq.process_word("你好,我是", pts=10, context_id="ctx1")
result = seq.force_complete(last_word_pts=50)
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
self.assertEqual(len(tts_frames), 1)
self.assertEqual(tts_frames[0].text, "你的智能")
if __name__ == "__main__":
unittest.main()

View File

@@ -18,7 +18,7 @@ def _service(language: str) -> CartesiaTTSService:
def _process_word_timestamps(
words: list[str], starts: list[float], language: str
) -> list[tuple[str, float]]:
return _service(language)._process_word_timestamps_for_language(words, starts)
return _service(language)._normalize_word_timestamps(words, starts)
def _concatenate_processed_timestamps(
@@ -27,7 +27,7 @@ def _concatenate_processed_timestamps(
service = _service(language)
text_parts = []
for words, starts in timestamp_groups:
processed_timestamps = service._process_word_timestamps_for_language(words, starts)
processed_timestamps = service._normalize_word_timestamps(words, starts)
includes_inter_frame_spaces = service._word_timestamps_include_inter_frame_spaces()
text_parts.extend(
TextPartForConcatenation(

View File

@@ -21,6 +21,13 @@ repeated for each TTSSpeakFrame, with no cross-group contamination.
Also covers LLM response flow with push_text_frames=True (non-word-timestamp TTS):
verifies TTSTextFrame ordering relative to LLMFullResponseEndFrame.
Also covers smart-text / WordCompletionTracker features:
- Skipped frames (skip_aggregator_types) held until preceding spoken slots complete.
- raw_text on AggregatedTextFrame propagated as spans to TTSTextFrames.
- Overflow: a single TTS word straddling two AggregatedTextFrame boundaries produces
two correctly-attributed TTSTextFrames.
- Force-complete safety net: skipped frames flush even when TTS drops word timestamps.
Also covers the interruption-during-pause deadlock scenario (see test_no_deadlock_on_interrupt_*).
"""
@@ -50,6 +57,7 @@ from pipecat.frames.frames import (
)
from pipecat.services.tts_service import TTSService
from pipecat.tests.utils import SleepFrame, run_test
from pipecat.utils.text.base_text_aggregator import AggregationType
# ---------------------------------------------------------------------------
# Test-only frame
@@ -422,7 +430,7 @@ def _assert_group_ordering(
# All frames between TTSStartedFrame and TTSStoppedFrame must be audio.
mid_types = types[started_idx + 1 : stopped_idx]
for t in mid_types:
assert t is TTSAudioRawFrame, (
assert t in (TTSAudioRawFrame, TTSTextFrame), (
f"Group {foo_label!r}: unexpected frame {t.__name__!r} between "
f"TTSStartedFrame and TTSStoppedFrame. Got: {type_names}"
)
@@ -551,7 +559,7 @@ async def test_http_push_text_llm_response_end_after_tts_text():
@pytest.mark.asyncio
async def test_http_word_timestamps_verbatim_tokens():
"""HTTP path: text, PTS order, flag, and text-before-audio are all verified.
"""HTTP path: text, PTS order, and text-before-audio are all verified.
Word timestamps arrive in the audio context queue before the audio frame.
_handle_audio_context caches them, then flushes when the first audio frame
@@ -572,7 +580,6 @@ async def test_http_word_timestamps_verbatim_tokens():
audio_frames = [f for f in down if isinstance(f, TTSAudioRawFrame)]
assert [f.text for f in tts_text_frames] == ["hello", "world"]
assert all(f.includes_inter_frame_spaces is True for f in tts_text_frames)
pts_values = [f.pts for f in tts_text_frames]
assert pts_values == sorted(pts_values) and len(set(pts_values)) == len(pts_values), (
@@ -590,15 +597,14 @@ async def test_http_word_timestamps_verbatim_tokens():
@pytest.mark.asyncio
async def test_http_word_timestamps_punctuation_tokens():
"""Verbatim punctuation tokens are preserved with flag=True; default flag is False.
"""Punct-only tokens are merged into the preceding word when includes_inter_frame_spaces=True.
Models the Inworld API scenario: the TTS returns tokens exactly as sent.
Space placement rule:
- word-follows-word: space is the leading char of the next word (e.g. " world")
- word-follows-punctuation: space is the trailing char of the punctuation token
(e.g. "! "), so the following word token carries no leading space.
The flag must reach every frame and the text must not be modified.
Also acts as a regression guard that flag=False is the default.
Models the Inworld API scenario: the TTS returns separate space and punctuation
tokens. add_word_timestamps calls merge_punct_tokens when includes_inter_frame_spaces
is True, collapsing those tokens into the preceding word before the tracker sees them.
With flag=False (default) tokens are forwarded as-is; the tracker strips leading/
trailing whitespace from each frame word via get_word_for_frame().
"""
verbatim_tokens = [
("hello", 0.0),
@@ -609,9 +615,9 @@ async def test_http_word_timestamps_punctuation_tokens():
(" you", 0.75),
("?", 0.9),
]
expected_texts = ["hello", " world", "! ", "How", " are", " you", "?"]
# With flag=True: all tokens verbatim, all frames carry the flag.
# With flag=True: punct-only tokens ("! " and "?") are merged into the preceding
# words (" world" → " world! " and " you" → " you?"), then stripped by the tracker.
tts_ifs = _MockWordTimestampHttpTTSService(
includes_inter_frame_spaces=True,
word_times=verbatim_tokens,
@@ -621,12 +627,11 @@ async def test_http_word_timestamps_punctuation_tokens():
frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)],
)
text_frames_ifs = [f for f in frames_ifs[0] if isinstance(f, TTSTextFrame)]
assert [f.text for f in text_frames_ifs] == expected_texts, (
"Verbatim tokens must not be modified"
assert [f.text for f in text_frames_ifs] == ["hello", "world!", "How", "are", "you?"], (
"Punct-only tokens must be merged into the preceding word"
)
assert all(f.includes_inter_frame_spaces is True for f in text_frames_ifs)
# With flag=False (default): same tokens, flag must be False on every frame.
# With flag=False (default): no merging; tracker strips leading/trailing spaces.
tts_plain = _MockWordTimestampHttpTTSService(
word_times=verbatim_tokens,
)
@@ -635,13 +640,12 @@ async def test_http_word_timestamps_punctuation_tokens():
frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)],
)
text_frames_plain = [f for f in frames_plain[0] if isinstance(f, TTSTextFrame)]
assert [f.text for f in text_frames_plain] == expected_texts
assert all(f.includes_inter_frame_spaces is False for f in text_frames_plain)
assert [f.text for f in text_frames_plain] == ["hello", "world", "!", "How", "are", "you", "?"]
@pytest.mark.asyncio
async def test_websocket_word_timestamps_verbatim_tokens():
"""WebSocket path: _WordTimestampEntry carries verbatim text, PTS, and flag.
"""WebSocket path: text, PTS order, and text-before-audio are all verified.
Unlike the HTTP path the word timestamps are sent asynchronously from a
background task. They arrive before the audio frame and are cached until
@@ -662,7 +666,6 @@ async def test_websocket_word_timestamps_verbatim_tokens():
audio_frames = [f for f in down if isinstance(f, TTSAudioRawFrame)]
assert [f.text for f in tts_text_frames] == ["hello", "world"]
assert all(f.includes_inter_frame_spaces is True for f in tts_text_frames)
pts_values = [f.pts for f in tts_text_frames]
assert pts_values == sorted(pts_values) and len(set(pts_values)) == len(pts_values), (
@@ -678,7 +681,7 @@ async def test_websocket_word_timestamps_verbatim_tokens():
@pytest.mark.asyncio
async def test_websocket_word_timestamps_punctuation_tokens():
"""WebSocket path: verbatim punctuation tokens reach TTSTextFrame unchanged."""
"""WebSocket path: punct-only tokens are merged into the preceding word."""
verbatim_tokens = [
("hello", 0.0),
(" world", 0.15),
@@ -697,10 +700,443 @@ async def test_websocket_word_timestamps_punctuation_tokens():
frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)],
)
text_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
assert [f.text for f in text_frames] == ["hello", " world", "! ", "How", " are", " you", "?"], (
"Verbatim tokens must not be modified"
assert [f.text for f in text_frames] == ["hello", "world!", "How", "are", "you?"], (
"Punct-only tokens must be merged into the preceding word"
)
# ---------------------------------------------------------------------------
# Per-call word-timestamp mock (for overflow tests)
# ---------------------------------------------------------------------------
class _MockPerCallWordTimestampHttpTTSService(TTSService):
"""HTTP-style TTS where each run_tts() call consumes its own word-time list.
Designed for tests that need different word tokens per sentence. The
``word_times_per_call`` list is consumed in order; an empty inner list means
no word-timestamp events are emitted for that call.
"""
def __init__(
self,
word_times_per_call: list[list[tuple[str, float]]],
**kwargs,
):
super().__init__(
push_start_frame=True,
push_stop_frames=True,
push_text_frames=False,
sample_rate=_SAMPLE_RATE,
**kwargs,
)
self._word_times_queue = list(word_times_per_call)
def can_generate_metrics(self) -> bool:
return False
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
word_times = self._word_times_queue.pop(0) if self._word_times_queue else []
if word_times:
await self.add_word_timestamps(word_times, context_id=context_id)
yield TTSAudioRawFrame(
audio=_FAKE_AUDIO,
sample_rate=_SAMPLE_RATE,
num_channels=1,
context_id=context_id,
)
# ---------------------------------------------------------------------------
# Tests: skipped frame ordering (skip_aggregator_types)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_http_skipped_frame_waits_for_spoken_words():
"""Skipped frames are held until the preceding spoken slot's word timestamps
are all processed, then flushed in order (HTTP / synchronous audio path).
Sequence sent:
AggregatedTextFrame("hello world", SENTENCE) — spoken; yields 2 TTSTextFrames
AggregatedTextFrame("some code", "code") — in skip_aggregator_types; must wait
Expected downstream order:
TTSTextFrame("hello")
TTSTextFrame("world")
AggregatedTextFrame("some code", append_to_context=True)
"""
tts = _MockWordTimestampHttpTTSService(skip_aggregator_types=["code"])
frames_received = await run_test(
tts,
frames_to_send=[
AggregatedTextFrame("hello world", AggregationType.SENTENCE),
AggregatedTextFrame("some code", "code"),
],
)
down = frames_received[0]
word_frames = [f for f in down if isinstance(f, TTSTextFrame)]
skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"]
assert [f.text for f in word_frames] == ["hello", "world"]
assert len(skipped) == 1
assert skipped[0].append_to_context is True
last_word_idx = max(down.index(f) for f in word_frames)
skipped_idx = down.index(skipped[0])
assert skipped_idx > last_word_idx, (
f"Skipped frame (pos {skipped_idx}) must appear after last word frame (pos {last_word_idx})"
)
@pytest.mark.asyncio
async def test_ws_skipped_frame_waits_for_spoken_words():
"""Same ordering guarantee on the WebSocket / async audio delivery path.
Because audio is delivered from a background task after asyncio.sleep(), the
skipped frame arrives at _push_frame_respecting_previous_aggregated_frame
*before* the spoken slot's word timestamps have been processed, directly
exercising the hold-and-flush path.
"""
tts = _MockWordTimestampWSTTSService(skip_aggregator_types=["code"])
frames_received = await run_test(
tts,
frames_to_send=[
AggregatedTextFrame("hello world", AggregationType.SENTENCE),
AggregatedTextFrame("some code", "code"),
],
)
down = frames_received[0]
word_frames = [f for f in down if isinstance(f, TTSTextFrame)]
skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"]
assert [f.text for f in word_frames] == ["hello", "world"]
assert len(skipped) == 1
assert skipped[0].append_to_context is True
last_word_idx = max(down.index(f) for f in word_frames)
skipped_idx = down.index(skipped[0])
assert skipped_idx > last_word_idx, (
f"Skipped frame (pos {skipped_idx}) must appear after last word frame (pos {last_word_idx})"
)
@pytest.mark.asyncio
async def test_skipped_frame_before_spoken_emits_immediately():
"""A skipped frame with no preceding spoken slot is emitted right away.
Sequence:
AggregatedTextFrame("some code", "code") — no spoken slot before it → emits now
AggregatedTextFrame("hello world", SENTENCE) — spoken; TTSTextFrames follow
Expected: AggregatedTextFrame("some code") appears *before* TTSTextFrame("hello").
"""
tts = _MockWordTimestampHttpTTSService(skip_aggregator_types=["code"])
frames_received = await run_test(
tts,
frames_to_send=[
AggregatedTextFrame("some code", "code"),
AggregatedTextFrame("hello world", AggregationType.SENTENCE),
],
)
down = frames_received[0]
word_frames = [f for f in down if isinstance(f, TTSTextFrame)]
skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"]
assert len(skipped) == 1
assert skipped[0].append_to_context is True
assert len(word_frames) >= 1
skipped_idx = down.index(skipped[0])
first_word_idx = down.index(word_frames[0])
assert skipped_idx < first_word_idx, (
f"Skipped frame (pos {skipped_idx}) must appear before first word frame (pos {first_word_idx})"
)
@pytest.mark.asyncio
async def test_skipped_frame_flushed_when_word_timestamps_incomplete():
"""Force-complete path: skipped frame still emits when the TTS drops word timestamps.
Only one of the two expected tokens ("hello") is returned. The spoken slot never
reaches its expected character count through the normal path. When
on_audio_context_done fires it force-completes any remaining spoken slots and
flushes the waiting skipped frame.
"""
tts = _MockWordTimestampHttpTTSService(
word_times=[("hello", 0.0)], # "world" is never sent
skip_aggregator_types=["code"],
)
frames_received = await run_test(
tts,
frames_to_send=[
AggregatedTextFrame("hello world", AggregationType.SENTENCE),
AggregatedTextFrame("some code", "code"),
],
)
down = frames_received[0]
skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"]
assert len(skipped) == 1, "Skipped frame must be flushed via force-complete safety net"
assert skipped[0].append_to_context is True
# ---------------------------------------------------------------------------
# Tests: raw_text propagation through WordCompletionTracker
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_raw_text_propagated_to_tts_text_frames():
"""raw_text on AggregatedTextFrame is split across TTSTextFrames by the tracker.
The frame carries raw_text="<card>4111 1111</card>" while the TTS-prepared
text is "4111 1111". The WordCompletionTracker advances a cursor through the
raw text in step with incoming word tokens, so each TTSTextFrame receives the
exact raw span it represents.
Expected (trailing whitespace stripped because includes_inter_frame_spaces=False):
TTSTextFrame("4111").raw_text == "<card>4111"
TTSTextFrame("1111").raw_text == "1111</card>"
"""
tts = _MockWordTimestampHttpTTSService()
frames_received = await run_test(
tts,
frames_to_send=[
AggregatedTextFrame(
"4111 1111", AggregationType.SENTENCE, raw_text="<card>4111 1111</card>"
)
],
)
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
assert [f.text for f in word_frames] == ["4111", "1111"]
# get_raw_consumed() strips trailing whitespace when includes_inter_frame_spaces=False
assert word_frames[0].raw_text == "<card>4111"
assert word_frames[1].raw_text == "1111</card>"
# ---------------------------------------------------------------------------
# Tests: overflow — TTS word spanning two AggregatedTextFrame boundaries
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_overflow_word_spanning_two_aggregated_frames():
"""A single TTS token straddling two AggregatedTextFrame boundaries produces
two correctly-attributed TTSTextFrames.
Setup:
Frame 1: AggregatedTextFrame("abc", SENTENCE)
Frame 2: AggregatedTextFrame("def", SENTENCE)
The TTS for frame 1 returns the single token "abcdef", which overshoots
frame 1 by three characters. _emit_overflow_word splits it:
TTSTextFrame("abc") — frame 1's portion (context_id = ctx1)
TTSTextFrame("def") — overflow attributed to frame 2 (context_id = ctx2)
Frame 2 receives no word-timestamp events because the overflow already
consumed its expected text.
"""
tts = _MockPerCallWordTimestampHttpTTSService(
word_times_per_call=[
[("abcdef", 0.0)], # frame 1: single token spanning both frames
[], # frame 2: no word timestamps (overflow already covered it)
]
)
frames_received = await run_test(
tts,
frames_to_send=[
AggregatedTextFrame("abc", AggregationType.SENTENCE),
AggregatedTextFrame("def", AggregationType.SENTENCE),
],
)
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
assert [f.text for f in word_frames] == ["abc", "def"], (
f"Expected ['abc', 'def'] but got {[f.text for f in word_frames]}"
)
assert word_frames[0].context_id != word_frames[1].context_id, (
"Overflow TTSTextFrame must carry frame 2's context_id, not frame 1's"
)
# ---------------------------------------------------------------------------
# Per-call word-timestamp mock for WebSocket path (for force-complete tests)
# ---------------------------------------------------------------------------
class _MockPerCallWordTimestampWSTTSService(TTSService):
"""WebSocket-style TTS where each run_tts() call consumes its own word-time list.
Mirrors _MockPerCallWordTimestampHttpTTSService but uses the async audio-context
delivery pattern so it exercises _handle_audio_context (the WebSocket path).
An empty inner list means no word-timestamp events are emitted for that call.
"""
def __init__(
self,
word_times_per_call: list[list[tuple[str, float]]],
**kwargs,
):
super().__init__(
push_start_frame=True,
push_text_frames=False,
pause_frame_processing=False,
sample_rate=_SAMPLE_RATE,
**kwargs,
)
self._word_times_queue = list(word_times_per_call)
def can_generate_metrics(self) -> bool:
return False
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
word_times = self._word_times_queue.pop(0) if self._word_times_queue else []
async def _deliver():
await asyncio.sleep(0.01)
if word_times:
await self.add_word_timestamps(word_times, context_id=context_id)
await self.append_to_audio_context(
context_id,
TTSAudioRawFrame(
audio=_FAKE_AUDIO,
sample_rate=_SAMPLE_RATE,
num_channels=1,
context_id=context_id,
),
)
await self.append_to_audio_context(context_id, TTSStoppedFrame(context_id=context_id))
await self.remove_audio_context(context_id)
self.create_task(_deliver(), name=f"mock_ws_per_call_deliver_{context_id}")
if False:
yield
# ---------------------------------------------------------------------------
# Tests: _force_complete_spoken_slots — TTSTextFrame emission for dropped timestamps
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_http_force_complete_partial_timestamps_emits_remaining_text():
"""_force_complete_spoken_slots emits a TTSTextFrame for the unspoken word suffix.
Only the first token ("hello") is delivered as a word-timestamp event; "world"
is never sent. When the audio context ends _force_complete_spoken_slots fires,
reads get_remaining_text() from the tracker, and emits TTSTextFrame("world").
Expected TTSTextFrames in order: ["hello", "world"].
"""
tts = _MockWordTimestampHttpTTSService(word_times=[("hello", 0.0)])
frames_received = await run_test(
tts,
frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)],
)
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
assert [f.text for f in word_frames] == ["hello", "world"], (
f"Expected ['hello', 'world'] but got {[f.text for f in word_frames]}"
)
@pytest.mark.asyncio
async def test_http_force_complete_no_timestamps_emits_full_text():
"""_force_complete_spoken_slots emits the full text when no word timestamps arrive.
No word-timestamp events are sent for "hello world". The slot remains incomplete
when the audio context ends; force-complete reads the full remaining text from the
tracker and emits TTSTextFrame("hello world").
"""
tts = _MockPerCallWordTimestampHttpTTSService(word_times_per_call=[[]])
frames_received = await run_test(
tts,
frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)],
)
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
assert len(word_frames) == 1, (
f"Expected exactly 1 TTSTextFrame, got {len(word_frames)}: {[f.text for f in word_frames]}"
)
assert word_frames[0].text == "hello world", (
f"Expected TTSTextFrame('hello world'), got {word_frames[0].text!r}"
)
@pytest.mark.asyncio
async def test_http_force_complete_raw_text_propagated():
"""force-complete carries the correct raw_text span on the emitted TTSTextFrame.
AggregatedTextFrame carries raw_text="<card>4111 1111</card>". Only "4111" arrives
as a word-timestamp; "1111" is force-completed.
Expected:
TTSTextFrame("4111").raw_text == "<card>4111" — from normal word path
TTSTextFrame("1111").raw_text == "1111</card>" — from force-complete path
"""
tts = _MockPerCallWordTimestampHttpTTSService(word_times_per_call=[[("4111", 0.0)]])
frames_received = await run_test(
tts,
frames_to_send=[
AggregatedTextFrame(
"4111 1111", AggregationType.SENTENCE, raw_text="<card>4111 1111</card>"
)
],
)
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
assert [f.text for f in word_frames] == ["4111", "1111"], (
f"Expected ['4111', '1111'] but got {[f.text for f in word_frames]}"
)
assert word_frames[0].raw_text == "<card>4111", (
f"Expected raw_text '<card>4111' on first frame, got {word_frames[0].raw_text!r}"
)
assert word_frames[1].raw_text == "1111</card>", (
f"Expected raw_text '1111</card>' on force-complete frame, got {word_frames[1].raw_text!r}"
)
@pytest.mark.asyncio
async def test_ws_force_complete_partial_timestamps_emits_remaining_text():
"""WebSocket path: _force_complete_spoken_slots emits TTSTextFrame for dropped token.
Mirrors test_http_force_complete_partial_timestamps_emits_remaining_text on the
async audio delivery path to confirm force-complete fires correctly from
_handle_audio_context when TTSStoppedFrame arrives before all word timestamps.
"""
tts = _MockWordTimestampWSTTSService(word_times=[("hello", 0.0)])
frames_received = await run_test(
tts,
frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)],
)
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
assert [f.text for f in word_frames] == ["hello", "world"], (
f"Expected ['hello', 'world'] but got {[f.text for f in word_frames]}"
)
@pytest.mark.asyncio
async def test_ws_force_complete_no_timestamps_emits_full_text():
"""WebSocket path: full text emitted as single TTSTextFrame when no timestamps arrive."""
tts = _MockPerCallWordTimestampWSTTSService(word_times_per_call=[[]])
frames_received = await run_test(
tts,
frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)],
)
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
assert len(word_frames) == 1, (
f"Expected exactly 1 TTSTextFrame, got {len(word_frames)}: {[f.text for f in word_frames]}"
)
assert word_frames[0].text == "hello world", (
f"Expected TTSTextFrame('hello world'), got {word_frames[0].text!r}"
)
assert all(f.includes_inter_frame_spaces is True for f in text_frames)
@pytest.mark.asyncio

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,90 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
from pipecat.utils.text.word_timestamp_utils import merge_punct_tokens
class TestMergePunctTokens(unittest.TestCase):
def test_empty_list(self):
self.assertEqual(merge_punct_tokens([]), [])
def test_all_alnum_words_pass_through(self):
input = [("hello", 0.0), ("world", 1.0)]
self.assertEqual(merge_punct_tokens(input), [("hello", 0.0), ("world", 1.0)])
def test_trailing_space_merged_and_stripped(self):
input = [("I", 0.0), (" ", 0.2)]
self.assertEqual(merge_punct_tokens(input), [("I", 0.0)])
def test_comma_space_merged_and_stripped(self):
input = [("questions", 1.0), (", ", 1.2), ("explain", 1.4)]
self.assertEqual(merge_punct_tokens(input), [("questions,", 1.0), ("explain", 1.4)])
def test_leading_space_with_no_preceding_word_discarded(self):
input = [(" ", 0.0), ("hello", 0.5)]
self.assertEqual(merge_punct_tokens(input), [("hello", 0.5)])
def test_leading_empty_string_discarded(self):
input = [("", 0.0), ("hello", 0.5)]
self.assertEqual(merge_punct_tokens(input), [("hello", 0.5)])
def test_multiple_consecutive_punct_tokens_merged_and_stripped(self):
input = [("word", 0.0), (",", 0.1), (" ", 0.2), ("next", 0.3)]
self.assertEqual(merge_punct_tokens(input), [("word,", 0.0), ("next", 0.3)])
def test_timestamp_of_preceding_word_is_kept(self):
"""Merged punct tokens adopt the preceding word's timestamp."""
input = [("hello", 2.5), (",", 2.7)]
result = merge_punct_tokens(input)
self.assertEqual(result, [("hello,", 2.5)])
def test_xml_tag_only_token_is_treated_as_punct(self):
"""A token that is only an XML tag (no alnum chars) merges into the preceding word."""
input = [("word", 0.0), ("<break/>", 0.1), ("next", 0.3)]
self.assertEqual(merge_punct_tokens(input), [("word<break/>", 0.0), ("next", 0.3)])
def test_xml_tag_with_alnum_content_passes_through(self):
"""A token like '<spell>123</spell>' has alnum chars after stripping tags."""
input = [("<spell>123</spell>", 0.0), ("and", 0.5)]
self.assertEqual(merge_punct_tokens(input), [("<spell>123</spell>", 0.0), ("and", 0.5)])
def test_inworld_style_full_stream(self):
"""Full Inworld-style raw stream produces expected merged and stripped output."""
raw = [
("", 0.0),
("I", 0.1),
(" ", 0.2),
("can", 0.3),
(" ", 0.4),
("answer", 0.5),
(" ", 0.6),
("questions", 0.7),
(", ", 0.8),
("explain", 0.9),
(" ", 1.0),
("things", 1.1),
(".", 1.2),
]
expected = [
("I", 0.1),
("can", 0.3),
("answer", 0.5),
("questions,", 0.7),
("explain", 0.9),
("things.", 1.1),
]
self.assertEqual(merge_punct_tokens(raw), expected)
def test_only_punct_tokens_returns_empty(self):
"""A list containing only punct/space tokens produces an empty result."""
input = [(" ", 0.0), (",", 0.1), (".", 0.2)]
self.assertEqual(merge_punct_tokens(input), [])
if __name__ == "__main__":
unittest.main()