diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 54de1c807..decb76d75 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -383,10 +383,14 @@ class AggregatedTextFrame(TextFrame): Parameters: aggregated_by: Method used to aggregate the text frames. context_id: Unique identifier for the TTS context that generated this text. + raw_text: The full matched text including start/end pattern delimiters, set when + this frame was produced from a PatternMatch (e.g. a ``...`` block). + Other aggregations carry a copy of ``text``; None when never set. """ aggregated_by: AggregationType | str context_id: str | None = None + raw_text: str | None = None @dataclass diff --git a/src/pipecat/processors/aggregators/llm_response_universal.py b/src/pipecat/processors/aggregators/llm_response_universal.py index b50eac1f7..bc158042e 100644 --- a/src/pipecat/processors/aggregators/llm_response_universal.py +++ b/src/pipecat/processors/aggregators/llm_response_universal.py @@ -25,6 +25,7 @@ from pipecat.adapters.schemas.tools_schema import ToolsSchema from pipecat.audio.vad.vad_analyzer import VADAnalyzer from pipecat.audio.vad.vad_controller import VADController from pipecat.frames.frames import ( + AggregatedTextFrame, AssistantImageRawFrame, BotStartedSpeakingFrame, BotStoppedSpeakingFrame, @@ -1496,9 +1497,14 @@ class LLMAssistantAggregator(LLMContextAggregator): if len(frame.text) == 0: return + text = ( + frame.raw_text + if isinstance(frame, AggregatedTextFrame) and frame.raw_text + else frame.text + ) self._aggregation.append( TextPartForConcatenation( - frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces + text, includes_inter_part_spaces=frame.includes_inter_frame_spaces ) ) diff --git a/src/pipecat/processors/aggregators/llm_text_processor.py b/src/pipecat/processors/aggregators/llm_text_processor.py index 862cf138b..9a23c01b3 100644 --- a/src/pipecat/processors/aggregators/llm_text_processor.py +++ 
b/src/pipecat/processors/aggregators/llm_text_processor.py @@ -23,6 +23,7 @@ from pipecat.frames.frames import ( ) from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.utils.text.base_text_aggregator import BaseTextAggregator +from pipecat.utils.text.pattern_pair_aggregator import PatternMatch from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator @@ -85,7 +86,11 @@ class LLMTextProcessor(FrameProcessor): out_frame = AggregatedTextFrame( text=aggregation.text, aggregated_by=aggregation.type, + raw_text=aggregation.full_match + if isinstance(aggregation, PatternMatch) + else aggregation.text, ) + out_frame.append_to_context = True out_frame.skip_tts = in_frame.skip_tts await self.push_frame(out_frame) @@ -96,6 +101,9 @@ class LLMTextProcessor(FrameProcessor): out_frame = AggregatedTextFrame( text=remaining.text, aggregated_by=remaining.type, + raw_text=remaining.full_match + if isinstance(remaining, PatternMatch) + else remaining.text, ) out_frame.skip_tts = skip_tts await self.push_frame(out_frame) diff --git a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index 61eddb20c..eb7210f19 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -50,9 +50,13 @@ from pipecat.services.ai_service import AIService from pipecat.services.settings import TTSSettings, is_given from pipecat.services.websocket_service import WebsocketService from pipecat.transcriptions.language import Language +from pipecat.utils.context.aggregated_frame_sequencer import AggregatedFrameSequencer +from pipecat.utils.context.word_completion_tracker import WordCompletionTracker from pipecat.utils.frame_queue import FrameQueue from pipecat.utils.text.base_text_filter import BaseTextFilter +from pipecat.utils.text.pattern_pair_aggregator import PatternMatch from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator +from pipecat.utils.text.word_timestamp_utils import 
merge_punct_tokens from pipecat.utils.time import seconds_to_nanoseconds @@ -97,7 +101,6 @@ class _WordTimestampEntry: word: str timestamp: float context_id: str - includes_inter_frame_spaces: bool | None = None class TTSService(AIService): @@ -289,6 +292,16 @@ class TTSService(AIService): self._text_filters: Sequence[BaseTextFilter] = text_filters or [] self._transport_destination: str | None = transport_destination + # Ordered sequence of every AggregatedTextFrame slot that passes through + # _push_tts_frames (both spoken and skipped). Skipped frames are held here + # until all preceding spoken slots are complete, then flushed downstream so + # their append_to_context=True arrives at the assistant aggregator in the + # correct order relative to the TTSTextFrames from spoken sentences. + # Tracks all AggregatedTextFrame slots (spoken and skipped) in order. + # Skipped frames are held until preceding spoken slots complete, ensuring + # append_to_context=True reaches the assistant aggregator in the right order. + self._aggregated_frame_sequencer = AggregatedFrameSequencer(name=str(self)) + self._resampler = create_stream_resampler() self._processing_text: bool = False @@ -690,7 +703,15 @@ class TTSService(AIService): # Stop the aggregation metric (no-op if already stopped on first sentence). await self.stop_text_aggregation_metrics() if remaining: - await self._push_tts_frames(AggregatedTextFrame(remaining.text, remaining.type)) + await self._push_tts_frames( + AggregatedTextFrame( + remaining.text, + remaining.type, + raw_text=remaining.full_match + if isinstance(remaining, PatternMatch) + else remaining.text, + ) + ) # We pause processing incoming frames if the LLM response included # text (it might be that it's only a function calling response). 
We @@ -733,7 +754,7 @@ class TTSService(AIService): push_assistant_aggregation = frame.append_to_context and not self._llm_response_started # Assumption: text in TTSSpeakFrame does not include inter-frame spaces await self._push_tts_frames( - AggregatedTextFrame(frame.text, AggregationType.SENTENCE), + AggregatedTextFrame(frame.text, AggregationType.SENTENCE, raw_text=frame.text), append_tts_text_to_context=frame.append_to_context, push_assistant_aggregation=push_assistant_aggregation, ) @@ -887,6 +908,7 @@ class TTSService(AIService): self._llm_response_started = False self._streamed_text = "" self._text_aggregation_metrics_started = False + self._aggregated_frame_sequencer.clear() # discard all pending slots on interruption await self.reset_word_timestamps() await self._stop_audio_context_task() @@ -930,9 +952,23 @@ class TTSService(AIService): if aggregate.type != AggregationType.TOKEN: # Stop the aggregation metric on the first sentence only. await self.stop_text_aggregation_metrics() - await self._push_tts_frames( - AggregatedTextFrame(aggregate.text, aggregate.type), includes_inter_frame_spaces + raw_text = ( + aggregate.full_match if isinstance(aggregate, PatternMatch) else aggregate.text ) + await self._push_tts_frames( + AggregatedTextFrame(aggregate.text, aggregate.type, raw_text=raw_text), + includes_inter_frame_spaces, + ) + + async def _push_frame_respecting_previous_aggregated_frame( + self, frame: AggregatedTextFrame, context_id: str + ): + # Enqueue the skipped frame; returns it immediately if no spoken slot + # precedes it, or holds it until the sequencer can flush it in order. 
+ for f in self._aggregated_frame_sequencer.register_skipped( + frame, context_id, self._transport_destination + ): + await self.push_frame(f) async def _push_tts_frames( self, @@ -944,10 +980,13 @@ class TTSService(AIService): type = src_frame.aggregated_by text = src_frame.text + # Create context ID and store metadata + context_id = self.create_context_id() + # Skip sending to TTS if the aggregation type is in the skip list. Simply # push the original frame downstream. if type in self._skip_aggregator_types: - await self.push_frame(src_frame) + await self._push_frame_respecting_previous_aggregated_frame(src_frame, context_id) return # Whitespace gating depends on aggregation mode: @@ -998,9 +1037,6 @@ class TTSService(AIService): await self.stop_processing_metrics() return - # Create context ID and store metadata - context_id = self.create_context_id() - # To support use cases that may want to know the text before it's spoken, we # push the AggregatedTextFrame version before transforming and sending to TTS. # However, we do not want to add this text to the assistant context until it @@ -1045,6 +1081,21 @@ class TTSService(AIService): await self.start_ttfb_metrics() await self.append_to_audio_context(context_id, TTSStartedFrame(context_id=context_id)) + # Register this spoken frame so the sequencer can track its completion + # and unblock any skipped frames queued behind it. Word-timestamp services + # complete the slot via process_word; push_text_frames services complete it + # below after the TTSTextFrame is appended to the audio context. 
+ self._aggregated_frame_sequencer.register_spoken( + src_frame, + context_id, + tracker=WordCompletionTracker( + prepared_text, llm_text=src_frame.raw_text or src_frame.text + ) + if not self._push_text_frames + else None, + append_to_context=self._tts_contexts[context_id].append_to_context, + ) + await self.tts_process_generator(context_id, self.run_tts(prepared_text, context_id)) if not self._is_streaming_tokens: @@ -1066,6 +1117,10 @@ class TTSService(AIService): frame.append_to_context = append_tts_text_to_context # Appending to the context, so it preserves the ordering. await self.append_to_audio_context(context_id, frame) + # TTSTextFrame is queued; mark the spoken slot complete so any skipped + # frames (e.g. code blocks) waiting behind it can be flushed in order. + for f in self._aggregated_frame_sequencer.complete_spoken_slot(): + await self.push_frame(f) async def tts_process_generator( self, context_id: str, generator: AsyncGenerator[Frame | None, None] @@ -1114,10 +1169,8 @@ class TTSService(AIService): if self._initial_word_times: cached = self._initial_word_times.copy() self._initial_word_times = [] - for word, timestamp_seconds, ctx_id, ifs in cached: - await self._add_word_timestamps( - [(word, timestamp_seconds)], ctx_id, includes_inter_frame_spaces=ifs - ) + for word, timestamp_seconds, ctx_id in cached: + await self._add_word_timestamps([(word, timestamp_seconds)], ctx_id) async def reset_word_timestamps(self): """Reset word timestamp tracking.""" @@ -1139,6 +1192,11 @@ class TTSService(AIService): playback order by _handle_audio_context. Otherwise they are processed immediately via _add_word_timestamps. + When ``includes_inter_frame_spaces`` is True (e.g. Inworld TTS), punctuation and + space-only tokens are merged into the preceding word via ``merge_punct_tokens`` + before queuing, so the tracker always receives words with trailing punctuation + already attached. The flag itself is not forwarded past this point. 
+ Args: word_times: List of (word, timestamp) tuples where timestamp is in seconds. context_id: Unique identifier for the TTS context. @@ -1147,29 +1205,22 @@ class TTSService(AIService): consumers must not inject additional spaces between tokens. None leaves the frame's own default unchanged. """ + if includes_inter_frame_spaces: + word_times = merge_punct_tokens(word_times) + if context_id and self.audio_context_available(context_id): for word, timestamp in word_times: await self.append_to_audio_context( context_id, - _WordTimestampEntry( - word=word, - timestamp=timestamp, - context_id=context_id, - includes_inter_frame_spaces=includes_inter_frame_spaces, - ), + _WordTimestampEntry(word=word, timestamp=timestamp, context_id=context_id), ) else: - await self._add_word_timestamps( - word_times=word_times, - context_id=context_id, - includes_inter_frame_spaces=includes_inter_frame_spaces, - ) + await self._add_word_timestamps(word_times=word_times, context_id=context_id) async def _add_word_timestamps( self, word_times: list[tuple[str, float]], context_id: str | None = None, - includes_inter_frame_spaces: bool | None = None, ): """Process word timestamps directly, building and pushing TTSTextFrames inline. @@ -1185,19 +1236,15 @@ class TTSService(AIService): ts_ns = seconds_to_nanoseconds(timestamp) if self._initial_word_timestamp == -1: # Cache until we have audio and can compute PTS. 
- self._initial_word_times.append( - (word, timestamp, context_id, includes_inter_frame_spaces) - ) + self._initial_word_times.append((word, timestamp, context_id)) else: - frame = TTSTextFrame(word, aggregated_by=AggregationType.WORD) - if includes_inter_frame_spaces is not None: - frame.includes_inter_frame_spaces = includes_inter_frame_spaces - frame.pts = self._initial_word_timestamp + ts_ns - frame.context_id = context_id - if context_id in self._tts_contexts: - frame.append_to_context = self._tts_contexts[context_id].append_to_context - self._word_last_pts = frame.pts - await self.push_frame(frame) + pts = self._initial_word_timestamp + ts_ns + # Build TTSTextFrame(s) for this word token, advancing the active + # slot's tracker and flushing any skipped frames now unblocked. + for f in self._aggregated_frame_sequencer.process_word(word, pts, context_id): + if isinstance(f, TTSTextFrame): + self._word_last_pts = f.pts + await self.push_frame(f) # # Audio context methods (active when using websocket-based TTS with context management) @@ -1382,6 +1429,18 @@ class TTSService(AIService): frame.pts = self._word_last_pts await self.push_frame(frame) + async def _apply_force_complete(self): + """Force-complete all incomplete spoken slots and push any unblocked skipped frames. + + Called at end-of-context to handle TTS providers that silently drop word-timestamp + events. Emits a TTSTextFrame for any remaining unspoken text, then flushes skipped + frames that were blocked by those incomplete slots. 
+ """ + for f in self._aggregated_frame_sequencer.force_complete(self._word_last_pts): + if isinstance(f, TTSTextFrame): + self._word_last_pts = f.pts + await self.push_frame(f) + async def _handle_audio_context(self, context_id: str): """Process items from an audio context queue until it is exhausted.""" queue = self._audio_contexts[context_id] @@ -1402,7 +1461,6 @@ class TTSService(AIService): await self._add_word_timestamps( [(frame.word, frame.timestamp)], frame.context_id, - includes_inter_frame_spaces=frame.includes_inter_frame_spaces, ) continue elif isinstance(frame, TTSAudioRawFrame): @@ -1416,6 +1474,9 @@ class TTSService(AIService): if isinstance(frame, TTSStartedFrame): should_push_stop_frame = self._push_stop_frames elif isinstance(frame, TTSStoppedFrame): + # Checking if we have any remaining spoken slots before pushing the TTSStoppedFrame + await self._apply_force_complete() + should_push_stop_frame = False # Setting the last word timestamp as the TTSStoppedFrame PTS if not frame.pts: @@ -1433,8 +1494,11 @@ class TTSService(AIService): should_push_stop_frame = False break + await self._apply_force_complete() + if should_push_stop_frame and self._push_stop_frames: await self.push_frame(TTSStoppedFrame(context_id=context_id)) + await self._maybe_reset_word_timestamps() async def on_audio_context_interrupted(self, context_id: str): diff --git a/src/pipecat/utils/context/word_completion_tracker.py b/src/pipecat/utils/context/word_completion_tracker.py new file mode 100644 index 000000000..1d12769bf --- /dev/null +++ b/src/pipecat/utils/context/word_completion_tracker.py @@ -0,0 +1,490 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Word completion tracker for TTS context ordering.""" + +import re +import unicodedata + +from loguru import logger + + +class WordCompletionTracker: + """Tracks whether all words from a source AggregatedTextFrame have been spoken. 
+ + Compares normalized alphanumeric character counts between the TTS text and + accumulated spoken words, making the check robust to punctuation, spacing, + and XML/HTML tags (e.g. SSML tags like ``...`` returned by some + TTS providers in word-timestamp events). + + When ``llm_text`` is provided (e.g. the original pattern-matched text including + delimiters like ``4111 1111 1111 1111``), the tracker additionally + maps each spoken word back to its corresponding span in that LLM text. This + lets callers attach the original text to ``TTSTextFrame`` entries so the + conversation context receives properly-tagged content rather than the cleaned + words received from the TTS provider. + + Background: TTS providers apply their own SSML tags to the text before + synthesis and return word-timestamp events containing the raw spoken words + (e.g. ``"4111"``, ``"1111"``). Without LLM-text tracking, the conversation + context would only see those cleaned words and lose the original structure + (e.g. ``4111 1111 1111 1111``). By mapping normalized char counts + back to positions in ``llm_text``, each TTSTextFrame can carry the exact span + of original text it represents. + + Overflow handling: TTS providers sometimes return a single word token that + spans the boundary between two AggregatedTextFrames (e.g. ``"1111And"`` + when one frame ends with ``1111`` and the next begins with ``And``). The + tracker detects this and exposes the raw overflow suffix via ``get_overflow_word()``, + so callers can feed the remainder into the next frame's tracker and emit a + correctly-attributed TTSTextFrame for each part. + + Example:: + + tracker = WordCompletionTracker("Hello, world!") + tracker.add_word_and_check_complete("Hello") # False + tracker.add_word_and_check_complete("world") # True — normalized "helloworld" >= "helloworld" + """ + + def __init__( + self, + tts_text: str, + llm_text: str | None = None, + ): + """Initialize the tracker with the text of the frame being spoken. 
+ + Args: + tts_text: Full text of the AggregatedTextFrame sent to TTS (may include + TTS-specific SSML tags). Used for normalized char-count completion + tracking and as the cursor reference for the TTS word stream. + llm_text: Original LLM-produced text including pattern delimiters (e.g. + ``4111 1111 1111 1111``). When provided, each + ``add_word_and_check_complete`` call also returns the corresponding + LLM span via ``get_llm_consumed()``. Both texts normalize to the + same alphanumeric sequence, so the same char-count cursor drives + position tracking in both. + """ + self._tts_normalized = self._normalize(tts_text) + self._received = "" + + # _tts_text is the original tts_text before normalization. + # _tts_pos is a cursor into it, advanced by the same alnum count + # as the TTS word stream, so the force-complete path can emit the remaining + # unspoken text as a TTSTextFrame instead of silently dropping it. + self._tts_text = tts_text + self._tts_pos = 0 + + # _llm_text is the original LLM-produced text (with pattern delimiters like + # ...). We track _llm_pos as a cursor into it, advancing + # by the same number of alphanumeric chars consumed from the TTS word stream. + self._llm_text = llm_text + self._llm_pos = 0 + + self._overflow_word: str | None = None + self._llm_consumed: str | None = None + self._frame_word: str | None = None + logger.debug(f"WordCompletionTracker: {self._tts_normalized}") + + @staticmethod + def _normalize(text: str) -> str: + """Strip XML/HTML tags then keep only lowercase alphanumeric characters. + + Accented letters (e.g. ã, é) are reduced to their base letter so TTS output + can be matched against LLM text even when the provider strips diacritics. + Non-Latin scripts (CJK, Hangul) are kept as-is — each original character + contributes exactly one char to the result, keeping normalized length in sync + with raw alnum counts used by _advance_by_alnums. 
+ """ + text = re.sub(r"<[^>]+>", "", text) + result = [] + for char in text: + # Ignore punctuation, spaces, emojis, etc. + # Keep only letters and numbers. + if not char.isalnum(): + continue + # NFD decomposes accented characters into: + # é -> e + ◌́ + # ã -> a + ◌̃ + # + # Non-accented characters usually stay unchanged. + nfd = unicodedata.normalize("NFD", char) + # Unicode category "Mn" means: + # Mark, Nonspacing + # + # These are combining accent marks that modify + # the previous character but are not standalone. + # + # Example: + # "é" becomes: + # nfd[0] = "e" + # nfd[1] = "◌́" (category = "Mn") + # + # If the second character is a combining accent, + # keep only the base letter. + if len(nfd) >= 2 and unicodedata.category(nfd[1]) == "Mn": + # Accented letter: keep the base character only (drops the combining mark). + result.append(nfd[0].lower()) + else: + # Regular ASCII, numbers, CJK, Hangul, etc. + # are kept unchanged (except lowercase conversion). + result.append(char.lower()) + return "".join(result) + + # Typographic variants that LLMs commonly emit but TTS services normalize away. 
+ _TYPOGRAPHY_FOLD = str.maketrans( + { + "‘": "'", # ' LEFT SINGLE QUOTATION MARK + "’": "'", # ' RIGHT SINGLE QUOTATION MARK + "ʼ": "'", # ʼ MODIFIER LETTER APOSTROPHE + "“": '"', # " LEFT DOUBLE QUOTATION MARK + "”": '"', # " RIGHT DOUBLE QUOTATION MARK + "–": "-", # – EN DASH + "—": "-", # — EM DASH + } + ) + + @staticmethod + def _fold_typography(text: str) -> str: + """Replace typographic punctuation variants with their ASCII equivalents.""" + return text.translate(WordCompletionTracker._TYPOGRAPHY_FOLD) + + @staticmethod + def _remove_trailing_punctuation(text: str) -> str: + """Remove punctuation only at the very end of the given text.""" + i = len(text) + while i > 0 and unicodedata.category(text[i - 1]).startswith("P"): + i -= 1 + return text[:i] + + @staticmethod + def _advance_by_alnums(text: str, start_pos: int, n: int) -> int: + """Return the position in *text* after advancing past *n* alphanumeric chars. + + Moves through the text one character at a time, counting only alphanumeric + characters. XML/HTML tags (``<...>``) are skipped entirely — their content + is not counted against the budget, so the returned span includes the full tag. + Other non-alphanumeric characters (spaces, punctuation) are also passed over + without decrementing the budget. + + After the *n* alnum chars are consumed, advances further past any immediately + following punctuation (e.g. the ``,`` in ``"questions,"`` or the ``.`` in + ``"done."``), stopping before the next space, alnum char, or XML tag. + + Args: + text: The source text to scan. + start_pos: Starting position in *text*. + n: Number of alphanumeric characters to consume. 
+ """ + pos = start_pos + count = 0 + while pos < len(text) and count < n: + if text[pos] == "<": + end = text.find(">", pos) + pos = end + 1 if end != -1 else pos + 1 + elif text[pos].isalnum(): + count += 1 + pos += 1 + else: + pos += 1 + + while pos < len(text): + if text[pos] == "<": + break + if text[pos].isalnum() or text[pos].isspace(): + break + pos += 1 + + return pos + + def add_word_and_check_complete(self, word: str) -> bool: + """Record a spoken word from a word-timestamp event. + + Normalizes ``word``, appends it to the running total, and checks whether + all expected alphanumeric characters have been covered. + + Before advancing, checks whether the word belongs to this frame via + ``word_belongs_here``. If it does not (e.g. the TTS provider silently + dropped a word-timestamp), the slot is force-completed: the remaining + unspoken text from ``tts_text`` is stored in ``_frame_word`` so a + TTSTextFrame can still be emitted for the dropped portion, all remaining + ``llm_text`` is consumed, and the entire incoming word is set as overflow + so the caller's overflow path routes it to the next slot unchanged. + + If ``llm_text`` was provided at construction time, also advances the LLM + cursor by the same number of alphanumeric chars consumed from this word and + stores the corresponding LLM span in ``_llm_consumed``. When this word + completes the frame, the entire remaining LLM text (including any closing + tags) is consumed so nothing is lost. + + If the word overshoots the expected length (overflow), the raw suffix of + the word (everything after the last char belonging to this frame) is stored + in ``_overflow_word``, so the caller can attribute it to the next + AggregatedTextFrame. + + Args: + word: A single word token returned by the TTS service. TTS services that + emit spaces and punctuation as separate tokens (e.g. 
Inworld) must + pre-merge those tokens into the preceding word before calling this + method (see ``merge_punct_tokens`` in ``word_timestamp_utils``). + + Returns: + True when all expected content has been covered. + """ + normalized = self._normalize(word) + + prev_len = len(self._received) + expected_len = len(self._tts_normalized) + + self._overflow_word = None + self._llm_consumed = None + self._frame_word = None + + if prev_len > expected_len: + logger.warning(f"{self}, trying to add a word in an already complete frame") + return True + + # If the word doesn't match the next expected chars, the TTS provider + # likely dropped a word-timestamp event. Force-complete this slot: emit the + # remaining TTS text as _frame_word so a TTSTextFrame is still produced + # for the unspoken portion, consume all remaining llm_text, and route the + # entire incoming word as overflow for the next slot. + if not self.word_belongs_here(word): + self._frame_word = self._tts_text[self._tts_pos :] + if self._llm_text is not None: + self._llm_consumed = self._llm_text[self._llm_pos :] + self._llm_pos = len(self._llm_text) + # This should not happen: force-complete sweeps all remaining + # llm_text, so the span must contain the frame word. If it + # doesn't, tts_text and llm_text are out of sync in an + # unexpected way — discard rather than returning a corrupt span. + # Also removing punctuation from the frame word to match the + # expected text, since some TTS services may add punctuation to + # the raw text. 
+ word_without_punctuation = self._remove_trailing_punctuation(self._frame_word) + if word_without_punctuation and word_without_punctuation not in self._llm_consumed: + logger.warning( + f"WordCompletionTracker: force-complete llm_consumed {repr(self._llm_consumed)!s} " + f"does not contain frame_word {repr(self._frame_word)!s}, discarding" + ) + self._llm_consumed = None + self._received = self._tts_normalized # force-complete + self._overflow_word = word + return True + + self._received += normalized + + # How many normalized chars from this word belong to the current frame. + chars_for_frame = min(len(normalized), expected_len - prev_len) + + if prev_len + len(normalized) > expected_len: + # This word straddles the frame boundary. Split into: + # - _frame_word: the prefix of `word` up to the split point, used + # for the TTSTextFrame of the current slot. + # - raw overflow word: the raw suffix after the split point, used + # to build a TTSTextFrame attributed to the next AggregatedTextFrame. + split_pos = self._advance_by_alnums(word, 0, chars_for_frame) + self._frame_word = word[:split_pos] + self._overflow_word = word[split_pos:] + else: + # Word fits entirely in this frame. + self._frame_word = word + + # Advance the TTS cursor by the same alnum count so the force-complete + # path knows where in _tts_text to start from. + self._tts_pos = self._advance_by_alnums(self._tts_text, self._tts_pos, chars_for_frame) + + if self._llm_text is not None: + if self.is_complete: + # Consume ALL remaining LLM text: closing tags (e.g. ) + # and any trailing punctuation that the TTS will not send separately. + self._llm_consumed = self._llm_text[self._llm_pos :] + self._llm_pos = len(self._llm_text) + else: + if chars_for_frame == 0: + # Consume exactly the raw word in llm_text, skipping any + # leading spaces that belong to the previous token's span. 
+ start = self._llm_pos + while start < len(self._llm_text) and self._llm_text[start].isspace(): + start += 1 + end = start + len(word) + self._llm_consumed = self._llm_text[start:end] + self._llm_pos = end + else: + # Advance through llm_text by exactly chars_for_frame alphanumeric + # chars. Non-alnum chars (spaces, opening tags) are included in the + # slice, preserving the original formatting for the context. + new_pos = self._advance_by_alnums( + self._llm_text, self._llm_pos, chars_for_frame + ) + self._llm_consumed = self._llm_text[self._llm_pos : new_pos] + self._llm_pos = new_pos + # This should not happen: the LLM cursor is driven by the same + # alnum count as the word stream, so the consumed span must contain + # the frame word. If it doesn't, the cursors drifted out of sync + # in an unexpected way — discard rather than returning a corrupt span. + # Also removing punctuation from the frame word to match the + # expected text, since some TTS services may add punctuation to + # the raw text. + word_without_punctuation = self._remove_trailing_punctuation(self._frame_word) + if word_without_punctuation and self._fold_typography( + word_without_punctuation + ) not in self._fold_typography(self._llm_consumed): + logger.warning( + f"WordCompletionTracker: llm_consumed {repr(self._llm_consumed)!s} " + f"does not contain frame_word {repr(self._frame_word)!s}, discarding" + ) + self._llm_consumed = None + + return self.is_complete + + def word_belongs_here(self, word: str) -> bool: + """Return True if this word plausibly belongs to the remaining TTS text. + + Dispatches to one of two checks depending on whether the word contains + any alphanumeric characters after normalization: + + - Alnum words: prefix-match against the remaining expected chars. + - Symbol/punctuation words (empty after normalization): literal substring + search in the remaining raw TTS text, with a fallback for TTS providers + that substitute Unicode symbols with ASCII punctuation. 
+ + Used to detect when the TTS provider silently dropped a word-timestamp + event: if the incoming word does not match this slot's remaining content, + the caller should force-complete this slot and route the word to the next. + """ + normalized = self._normalize(word) + if normalized: + return self._alnum_word_belongs_here(normalized) + else: + return self._symbol_word_belongs_here(word) + + def _alnum_word_belongs_here(self, normalized: str) -> bool: + """Return True if an alnum-containing word matches this frame's remaining expected chars. + + Accepts both full words and partial tokens — the word belongs here as long + as its normalized characters are a prefix of what is still expected. This + also handles the overflow case where the word is longer than the remaining + content (the excess is detected and split in ``add_word_and_check_complete``). + """ + remaining = self._tts_normalized[len(self._received) :] + if not remaining: + return False + check_len = min(len(normalized), len(remaining)) + return remaining.startswith(normalized[:check_len]) + + def _symbol_word_belongs_here(self, word: str) -> bool: + """Return True if a non-alnum word (emoji, punctuation, symbol) belongs to this frame. + + Two checks are applied in order: + + 1. **Literal substring**: search for the raw word in the remaining TTS text. + ``_advance_by_alnums`` may have already moved ``_tts_pos`` past some trailing + punctuation, so the search window is backed up to include those characters. + + 2. **Symbol substitution fallback**: some TTS providers substitute Unicode symbols + with ASCII punctuation in word-timestamp events (e.g. ElevenLabs reports ``→`` + as ``-``), so check 1 always fails even though the word belongs here. If alnum + content still remains unconsumed and the next non-space character in the TTS + text is itself a non-alnum symbol, accept the word as a substitution. 
+ """ + search_start = self._tts_pos + while search_start > 0: + ch = self._tts_text[search_start - 1] + if ch.isalnum() or ch.isspace() or ch == ">": + break + search_start -= 1 + if word in self._tts_text[search_start:]: + return True + + if len(self._received) >= len(self._tts_normalized): + return False + + pos = self._tts_pos + while pos < len(self._tts_text) and self._tts_text[pos].isspace(): + pos += 1 + return pos < len(self._tts_text) and not self._tts_text[pos].isalnum() + + def get_word_for_frame(self) -> str | None: + """Return the portion of the last word that belongs to this frame. + + - Normal word (no overflow): the full word. + - Straddling word: the prefix up to the frame boundary (e.g. ``"1111"`` + from ``"1111 And"``). + - Force-completed (word didn't belong): the remaining unspoken text from + ``tts_text`` so a TTSTextFrame can still be emitted for the dropped + portion. The incoming word is routed as overflow to the next slot. + """ + return self._frame_word.strip() if self._frame_word else self._frame_word + + def get_overflow_word(self) -> str | None: + """Return the raw suffix of the last word that overflows into the next frame. + + Preserves the original casing and any non-alphanumeric characters so the + overflow TTSTextFrame has natural word text. Returns None when there is no + overflow (the word fit entirely within this frame). + """ + return self._overflow_word.strip() if self._overflow_word else self._overflow_word + + def get_llm_consumed(self) -> str | None: + """Return the LLM text span consumed for the last added word. + + Returns None if no llm_text was provided at construction time. + """ + return self._llm_consumed.strip() if self._llm_consumed else self._llm_consumed + + def get_accumulated_tts_text(self) -> str: + """Return all consumed text from tts_text up to the current cursor position. 
+ + Unlike ``get_word_for_frame()`` (which reflects only the last word), this returns + everything that has been consumed since construction or the last ``reset()``. + """ + return self._tts_text[: self._tts_pos] + + def get_accumulated_llm_text(self) -> str | None: + """Return all consumed text from llm_text up to the current cursor position. + + Unlike ``get_llm_consumed()`` (which reflects only the last word), this returns + everything that has been consumed since construction or the last ``reset()``. + Returns None if no llm_text was provided at construction time. + """ + if self._llm_text is None: + return None + return self._llm_text[: self._llm_pos] + + def get_remaining_tts_text(self) -> str: + """Return the unspoken portion of tts_text, stripped of leading/trailing whitespace. + + This is the text that the TTS provider has not yet confirmed via word-timestamp + events. Useful for force-completing a slot when the audio context ends before all + word-timestamp events have arrived. + """ + return self._tts_text[self._tts_pos :].strip() + + def get_remaining_llm_text(self) -> str | None: + """Return the unspoken portion of llm_text, stripped of leading/trailing whitespace. + + Returns None if no llm_text was provided at construction time. Like + ``get_remaining_tts_text()``, intended for force-completing a slot so that the + conversation context receives the full original text. 
def merge_punct_tokens(
    word_times: list[tuple[str, float]],
) -> list[tuple[str, float]]:
    """Fold punctuation/space-only tokens into the word that precedes them.

    Some TTS services (e.g. Inworld) report spaces and punctuation as their
    own word-timestamp tokens. This collapses them so consumers receive words
    with trailing punctuation already attached, the same shape produced by
    services such as ElevenLabs or Cartesia.

    A token counts as punct/space-only when, after XML/HTML tags are removed,
    it has no alphanumeric character. Such a token's raw text is appended to
    the previous word and its own timestamp is dropped; if there is no
    previous word yet, the token is discarded. Each resulting word is stripped
    of leading and trailing whitespace.

    Args:
        word_times: Raw ``(word, timestamp)`` pairs from the TTS service.

    Returns:
        Merged pairs, each carrying the timestamp of the word that started it
        and text with no leading or trailing whitespace.

    Example::

        merge_punct_tokens([("questions", 1.0), (", ", 1.2), ("explain", 1.4)])
        # → [("questions,", 1.0), ("explain", 1.4)]
    """
    texts: list[str] = []
    stamps: list[float] = []
    for token, stamp in word_times:
        # Tags are ignored only for classification; the raw token is kept.
        visible = re.sub(r"<[^>]+>", "", token)
        if any(map(str.isalnum, visible)):
            texts.append(token)
            stamps.append(stamp)
        elif texts:
            # Punct/space-only: glue onto the previous word, keep its timestamp.
            texts[-1] += token
        # Otherwise: leading punct/space with no preceding word is dropped.
    return [(text.strip(), stamp) for text, stamp in zip(texts, stamps)]