fix: preserve raw_text through TTS pipeline for correct LLM context attribution
TTSTextFrame entries were losing their original text structure when word timestamps were enabled. AggregatedTextFrame now carries a raw_text field with the original LLM-produced text (including pattern delimiters such as <card>...</card>). The assistant context receives properly-tagged content rather than the cleaned words returned by the TTS provider. Also handles words that straddle two sentence boundaries by splitting and attributing each part to its correct source frame.
This commit is contained in:
@@ -383,10 +383,14 @@ class AggregatedTextFrame(TextFrame):
|
||||
Parameters:
|
||||
aggregated_by: Method used to aggregate the text frames.
|
||||
context_id: Unique identifier for the TTS context that generated this text.
|
||||
raw_text: The full matched text including start/end pattern delimiters, set when
|
||||
this frame was produced from a PatternMatch (e.g. a ``<code>...</code>`` block).
|
||||
None for ordinary sentence aggregations.
|
||||
"""
|
||||
|
||||
aggregated_by: AggregationType | str
|
||||
context_id: str | None = None
|
||||
raw_text: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -25,6 +25,7 @@ from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer
|
||||
from pipecat.audio.vad.vad_controller import VADController
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
AssistantImageRawFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
@@ -1496,9 +1497,14 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
if len(frame.text) == 0:
|
||||
return
|
||||
|
||||
text = (
|
||||
frame.raw_text
|
||||
if isinstance(frame, AggregatedTextFrame) and frame.raw_text
|
||||
else frame.text
|
||||
)
|
||||
self._aggregation.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch
|
||||
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
|
||||
|
||||
|
||||
@@ -85,7 +86,11 @@ class LLMTextProcessor(FrameProcessor):
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=aggregation.text,
|
||||
aggregated_by=aggregation.type,
|
||||
raw_text=aggregation.full_match
|
||||
if isinstance(aggregation, PatternMatch)
|
||||
else aggregation.text,
|
||||
)
|
||||
out_frame.append_to_context = True
|
||||
out_frame.skip_tts = in_frame.skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
|
||||
@@ -96,6 +101,9 @@ class LLMTextProcessor(FrameProcessor):
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=remaining.text,
|
||||
aggregated_by=remaining.type,
|
||||
raw_text=remaining.full_match
|
||||
if isinstance(remaining, PatternMatch)
|
||||
else remaining.text,
|
||||
)
|
||||
out_frame.skip_tts = skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
|
||||
@@ -50,9 +50,13 @@ from pipecat.services.ai_service import AIService
|
||||
from pipecat.services.settings import TTSSettings, is_given
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.context.aggregated_frame_sequencer import AggregatedFrameSequencer
|
||||
from pipecat.utils.context.word_completion_tracker import WordCompletionTracker
|
||||
from pipecat.utils.frame_queue import FrameQueue
|
||||
from pipecat.utils.text.base_text_filter import BaseTextFilter
|
||||
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch
|
||||
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
|
||||
from pipecat.utils.text.word_timestamp_utils import merge_punct_tokens
|
||||
from pipecat.utils.time import seconds_to_nanoseconds
|
||||
|
||||
|
||||
@@ -97,7 +101,6 @@ class _WordTimestampEntry:
|
||||
word: str
|
||||
timestamp: float
|
||||
context_id: str
|
||||
includes_inter_frame_spaces: bool | None = None
|
||||
|
||||
|
||||
class TTSService(AIService):
|
||||
@@ -289,6 +292,16 @@ class TTSService(AIService):
|
||||
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
|
||||
self._transport_destination: str | None = transport_destination
|
||||
|
||||
# Ordered sequence of every AggregatedTextFrame slot that passes through
|
||||
# _push_tts_frames (both spoken and skipped). Skipped frames are held here
|
||||
# until all preceding spoken slots are complete, then flushed downstream so
|
||||
# their append_to_context=True arrives at the assistant aggregator in the
|
||||
# correct order relative to the TTSTextFrames from spoken sentences.
|
||||
# Tracks all AggregatedTextFrame slots (spoken and skipped) in order.
|
||||
# Skipped frames are held until preceding spoken slots complete, ensuring
|
||||
# append_to_context=True reaches the assistant aggregator in the right order.
|
||||
self._aggregated_frame_sequencer = AggregatedFrameSequencer(name=str(self))
|
||||
|
||||
self._resampler = create_stream_resampler()
|
||||
|
||||
self._processing_text: bool = False
|
||||
@@ -690,7 +703,15 @@ class TTSService(AIService):
|
||||
# Stop the aggregation metric (no-op if already stopped on first sentence).
|
||||
await self.stop_text_aggregation_metrics()
|
||||
if remaining:
|
||||
await self._push_tts_frames(AggregatedTextFrame(remaining.text, remaining.type))
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(
|
||||
remaining.text,
|
||||
remaining.type,
|
||||
raw_text=remaining.full_match
|
||||
if isinstance(remaining, PatternMatch)
|
||||
else remaining.text,
|
||||
)
|
||||
)
|
||||
|
||||
# We pause processing incoming frames if the LLM response included
|
||||
# text (it might be that it's only a function calling response). We
|
||||
@@ -733,7 +754,7 @@ class TTSService(AIService):
|
||||
push_assistant_aggregation = frame.append_to_context and not self._llm_response_started
|
||||
# Assumption: text in TTSSpeakFrame does not include inter-frame spaces
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(frame.text, AggregationType.SENTENCE),
|
||||
AggregatedTextFrame(frame.text, AggregationType.SENTENCE, raw_text=frame.text),
|
||||
append_tts_text_to_context=frame.append_to_context,
|
||||
push_assistant_aggregation=push_assistant_aggregation,
|
||||
)
|
||||
@@ -887,6 +908,7 @@ class TTSService(AIService):
|
||||
self._llm_response_started = False
|
||||
self._streamed_text = ""
|
||||
self._text_aggregation_metrics_started = False
|
||||
self._aggregated_frame_sequencer.clear() # discard all pending slots on interruption
|
||||
await self.reset_word_timestamps()
|
||||
|
||||
await self._stop_audio_context_task()
|
||||
@@ -930,9 +952,23 @@ class TTSService(AIService):
|
||||
if aggregate.type != AggregationType.TOKEN:
|
||||
# Stop the aggregation metric on the first sentence only.
|
||||
await self.stop_text_aggregation_metrics()
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(aggregate.text, aggregate.type), includes_inter_frame_spaces
|
||||
raw_text = (
|
||||
aggregate.full_match if isinstance(aggregate, PatternMatch) else aggregate.text
|
||||
)
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(aggregate.text, aggregate.type, raw_text=raw_text),
|
||||
includes_inter_frame_spaces,
|
||||
)
|
||||
|
||||
async def _push_frame_respecting_previous_aggregated_frame(
|
||||
self, frame: AggregatedTextFrame, context_id: str
|
||||
):
|
||||
# Enqueue the skipped frame; returns it immediately if no spoken slot
|
||||
# precedes it, or holds it until the sequencer can flush it in order.
|
||||
for f in self._aggregated_frame_sequencer.register_skipped(
|
||||
frame, context_id, self._transport_destination
|
||||
):
|
||||
await self.push_frame(f)
|
||||
|
||||
async def _push_tts_frames(
|
||||
self,
|
||||
@@ -944,10 +980,13 @@ class TTSService(AIService):
|
||||
type = src_frame.aggregated_by
|
||||
text = src_frame.text
|
||||
|
||||
# Create context ID and store metadata
|
||||
context_id = self.create_context_id()
|
||||
|
||||
# Skip sending to TTS if the aggregation type is in the skip list. Simply
|
||||
# push the original frame downstream.
|
||||
if type in self._skip_aggregator_types:
|
||||
await self.push_frame(src_frame)
|
||||
await self._push_frame_respecting_previous_aggregated_frame(src_frame, context_id)
|
||||
return
|
||||
|
||||
# Whitespace gating depends on aggregation mode:
|
||||
@@ -998,9 +1037,6 @@ class TTSService(AIService):
|
||||
await self.stop_processing_metrics()
|
||||
return
|
||||
|
||||
# Create context ID and store metadata
|
||||
context_id = self.create_context_id()
|
||||
|
||||
# To support use cases that may want to know the text before it's spoken, we
|
||||
# push the AggregatedTextFrame version before transforming and sending to TTS.
|
||||
# However, we do not want to add this text to the assistant context until it
|
||||
@@ -1045,6 +1081,21 @@ class TTSService(AIService):
|
||||
await self.start_ttfb_metrics()
|
||||
await self.append_to_audio_context(context_id, TTSStartedFrame(context_id=context_id))
|
||||
|
||||
# Register this spoken frame so the sequencer can track its completion
|
||||
# and unblock any skipped frames queued behind it. Word-timestamp services
|
||||
# complete the slot via process_word; push_text_frames services complete it
|
||||
# below after the TTSTextFrame is appended to the audio context.
|
||||
self._aggregated_frame_sequencer.register_spoken(
|
||||
src_frame,
|
||||
context_id,
|
||||
tracker=WordCompletionTracker(
|
||||
prepared_text, llm_text=src_frame.raw_text or src_frame.text
|
||||
)
|
||||
if not self._push_text_frames
|
||||
else None,
|
||||
append_to_context=self._tts_contexts[context_id].append_to_context,
|
||||
)
|
||||
|
||||
await self.tts_process_generator(context_id, self.run_tts(prepared_text, context_id))
|
||||
|
||||
if not self._is_streaming_tokens:
|
||||
@@ -1066,6 +1117,10 @@ class TTSService(AIService):
|
||||
frame.append_to_context = append_tts_text_to_context
|
||||
# Appending to the context, so it preserves the ordering.
|
||||
await self.append_to_audio_context(context_id, frame)
|
||||
# TTSTextFrame is queued; mark the spoken slot complete so any skipped
|
||||
# frames (e.g. code blocks) waiting behind it can be flushed in order.
|
||||
for f in self._aggregated_frame_sequencer.complete_spoken_slot():
|
||||
await self.push_frame(f)
|
||||
|
||||
async def tts_process_generator(
|
||||
self, context_id: str, generator: AsyncGenerator[Frame | None, None]
|
||||
@@ -1114,10 +1169,8 @@ class TTSService(AIService):
|
||||
if self._initial_word_times:
|
||||
cached = self._initial_word_times.copy()
|
||||
self._initial_word_times = []
|
||||
for word, timestamp_seconds, ctx_id, ifs in cached:
|
||||
await self._add_word_timestamps(
|
||||
[(word, timestamp_seconds)], ctx_id, includes_inter_frame_spaces=ifs
|
||||
)
|
||||
for word, timestamp_seconds, ctx_id in cached:
|
||||
await self._add_word_timestamps([(word, timestamp_seconds)], ctx_id)
|
||||
|
||||
async def reset_word_timestamps(self):
|
||||
"""Reset word timestamp tracking."""
|
||||
@@ -1139,6 +1192,11 @@ class TTSService(AIService):
|
||||
playback order by _handle_audio_context. Otherwise they are processed immediately
|
||||
via _add_word_timestamps.
|
||||
|
||||
When ``includes_inter_frame_spaces`` is True (e.g. Inworld TTS), punctuation and
|
||||
space-only tokens are merged into the preceding word via ``_merge_punct_tokens``
|
||||
before queuing, so the tracker always receives words with trailing punctuation
|
||||
already attached. ``includes_inter_frame_spaces`` is reset to None after merging.
|
||||
|
||||
Args:
|
||||
word_times: List of (word, timestamp) tuples where timestamp is in seconds.
|
||||
context_id: Unique identifier for the TTS context.
|
||||
@@ -1147,29 +1205,22 @@ class TTSService(AIService):
|
||||
consumers must not inject additional spaces between tokens. None leaves
|
||||
the frame's own default unchanged.
|
||||
"""
|
||||
if includes_inter_frame_spaces:
|
||||
word_times = merge_punct_tokens(word_times)
|
||||
|
||||
if context_id and self.audio_context_available(context_id):
|
||||
for word, timestamp in word_times:
|
||||
await self.append_to_audio_context(
|
||||
context_id,
|
||||
_WordTimestampEntry(
|
||||
word=word,
|
||||
timestamp=timestamp,
|
||||
context_id=context_id,
|
||||
includes_inter_frame_spaces=includes_inter_frame_spaces,
|
||||
),
|
||||
_WordTimestampEntry(word=word, timestamp=timestamp, context_id=context_id),
|
||||
)
|
||||
else:
|
||||
await self._add_word_timestamps(
|
||||
word_times=word_times,
|
||||
context_id=context_id,
|
||||
includes_inter_frame_spaces=includes_inter_frame_spaces,
|
||||
)
|
||||
await self._add_word_timestamps(word_times=word_times, context_id=context_id)
|
||||
|
||||
async def _add_word_timestamps(
|
||||
self,
|
||||
word_times: list[tuple[str, float]],
|
||||
context_id: str | None = None,
|
||||
includes_inter_frame_spaces: bool | None = None,
|
||||
):
|
||||
"""Process word timestamps directly, building and pushing TTSTextFrames inline.
|
||||
|
||||
@@ -1185,19 +1236,15 @@ class TTSService(AIService):
|
||||
ts_ns = seconds_to_nanoseconds(timestamp)
|
||||
if self._initial_word_timestamp == -1:
|
||||
# Cache until we have audio and can compute PTS.
|
||||
self._initial_word_times.append(
|
||||
(word, timestamp, context_id, includes_inter_frame_spaces)
|
||||
)
|
||||
self._initial_word_times.append((word, timestamp, context_id))
|
||||
else:
|
||||
frame = TTSTextFrame(word, aggregated_by=AggregationType.WORD)
|
||||
if includes_inter_frame_spaces is not None:
|
||||
frame.includes_inter_frame_spaces = includes_inter_frame_spaces
|
||||
frame.pts = self._initial_word_timestamp + ts_ns
|
||||
frame.context_id = context_id
|
||||
if context_id in self._tts_contexts:
|
||||
frame.append_to_context = self._tts_contexts[context_id].append_to_context
|
||||
self._word_last_pts = frame.pts
|
||||
await self.push_frame(frame)
|
||||
pts = self._initial_word_timestamp + ts_ns
|
||||
# Build TTSTextFrame(s) for this word token, advancing the active
|
||||
# slot's tracker and flushing any skipped frames now unblocked.
|
||||
for f in self._aggregated_frame_sequencer.process_word(word, pts, context_id):
|
||||
if isinstance(f, TTSTextFrame):
|
||||
self._word_last_pts = f.pts
|
||||
await self.push_frame(f)
|
||||
|
||||
#
|
||||
# Audio context methods (active when using websocket-based TTS with context management)
|
||||
@@ -1382,6 +1429,18 @@ class TTSService(AIService):
|
||||
frame.pts = self._word_last_pts
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _apply_force_complete(self):
|
||||
"""Force-complete all incomplete spoken slots and push any unblocked skipped frames.
|
||||
|
||||
Called at end-of-context to handle TTS providers that silently drop word-timestamp
|
||||
events. Emits a TTSTextFrame for any remaining unspoken text, then flushes skipped
|
||||
frames that were blocked by those incomplete slots.
|
||||
"""
|
||||
for f in self._aggregated_frame_sequencer.force_complete(self._word_last_pts):
|
||||
if isinstance(f, TTSTextFrame):
|
||||
self._word_last_pts = f.pts
|
||||
await self.push_frame(f)
|
||||
|
||||
async def _handle_audio_context(self, context_id: str):
|
||||
"""Process items from an audio context queue until it is exhausted."""
|
||||
queue = self._audio_contexts[context_id]
|
||||
@@ -1402,7 +1461,6 @@ class TTSService(AIService):
|
||||
await self._add_word_timestamps(
|
||||
[(frame.word, frame.timestamp)],
|
||||
frame.context_id,
|
||||
includes_inter_frame_spaces=frame.includes_inter_frame_spaces,
|
||||
)
|
||||
continue
|
||||
elif isinstance(frame, TTSAudioRawFrame):
|
||||
@@ -1416,6 +1474,9 @@ class TTSService(AIService):
|
||||
if isinstance(frame, TTSStartedFrame):
|
||||
should_push_stop_frame = self._push_stop_frames
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
# Checking if we have any remaining spoken slots before pushing the TTSStoppedFrame
|
||||
await self._apply_force_complete()
|
||||
|
||||
should_push_stop_frame = False
|
||||
# Setting the last word timestamp as the TTSStoppedFrame PTS
|
||||
if not frame.pts:
|
||||
@@ -1433,8 +1494,11 @@ class TTSService(AIService):
|
||||
should_push_stop_frame = False
|
||||
break
|
||||
|
||||
await self._apply_force_complete()
|
||||
|
||||
if should_push_stop_frame and self._push_stop_frames:
|
||||
await self.push_frame(TTSStoppedFrame(context_id=context_id))
|
||||
|
||||
await self._maybe_reset_word_timestamps()
|
||||
|
||||
async def on_audio_context_interrupted(self, context_id: str):
|
||||
|
||||
490
src/pipecat/utils/context/word_completion_tracker.py
Normal file
490
src/pipecat/utils/context/word_completion_tracker.py
Normal file
@@ -0,0 +1,490 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Word completion tracker for TTS context ordering."""
|
||||
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class WordCompletionTracker:
|
||||
"""Tracks whether all words from a source AggregatedTextFrame have been spoken.
|
||||
|
||||
Compares normalized alphanumeric character counts between the TTS text and
|
||||
accumulated spoken words, making the check robust to punctuation, spacing,
|
||||
and XML/HTML tags (e.g. SSML tags like ``<spell>...</spell>`` returned by some
|
||||
TTS providers in word-timestamp events).
|
||||
|
||||
When ``llm_text`` is provided (e.g. the original pattern-matched text including
|
||||
delimiters like ``<card>4111 1111 1111 1111</card>``), the tracker additionally
|
||||
maps each spoken word back to its corresponding span in that LLM text. This
|
||||
lets callers attach the original text to ``TTSTextFrame`` entries so the
|
||||
conversation context receives properly-tagged content rather than the cleaned
|
||||
words received from the TTS provider.
|
||||
|
||||
Background: TTS providers apply their own SSML tags to the text before
|
||||
synthesis and return word-timestamp events containing the raw spoken words
|
||||
(e.g. ``"4111"``, ``"1111"``). Without LLM-text tracking, the conversation
|
||||
context would only see those cleaned words and lose the original structure
|
||||
(e.g. ``<card>4111 1111 1111 1111</card>``). By mapping normalized char counts
|
||||
back to positions in ``llm_text``, each TTSTextFrame can carry the exact span
|
||||
of original text it represents.
|
||||
|
||||
Overflow handling: TTS providers sometimes return a single word token that
|
||||
spans the boundary between two AggregatedTextFrames (e.g. ``"1111</spell>And"``
|
||||
when one frame ends with ``1111</card>`` and the next begins with ``And``). The
|
||||
tracker detects this and exposes the raw overflow suffix via ``get_overflow_word()``,
|
||||
so callers can feed the remainder into the next frame's tracker and emit a
|
||||
correctly-attributed TTSTextFrame for each part.
|
||||
|
||||
Example::
|
||||
|
||||
tracker = WordCompletionTracker("Hello, world!")
|
||||
tracker.add_word_and_check_complete("Hello") # False
|
||||
tracker.add_word_and_check_complete("world") # True — normalized "helloworld" >= "helloworld"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tts_text: str,
|
||||
llm_text: str | None = None,
|
||||
):
|
||||
"""Initialize the tracker with the text of the frame being spoken.
|
||||
|
||||
Args:
|
||||
tts_text: Full text of the AggregatedTextFrame sent to TTS (may include
|
||||
TTS-specific SSML tags). Used for normalized char-count completion
|
||||
tracking and as the cursor reference for the TTS word stream.
|
||||
llm_text: Original LLM-produced text including pattern delimiters (e.g.
|
||||
``<card>4111 1111 1111 1111</card>``). When provided, each
|
||||
``add_word_and_check_complete`` call also returns the corresponding
|
||||
LLM span via ``get_llm_consumed()``. Both texts normalize to the
|
||||
same alphanumeric sequence, so the same char-count cursor drives
|
||||
position tracking in both.
|
||||
"""
|
||||
self._tts_normalized = self._normalize(tts_text)
|
||||
self._received = ""
|
||||
|
||||
# _tts_text is the original tts_text before normalization.
|
||||
# _tts_pos is a cursor into it, advanced by the same alnum count
|
||||
# as the TTS word stream, so the force-complete path can emit the remaining
|
||||
# unspoken text as a TTSTextFrame instead of silently dropping it.
|
||||
self._tts_text = tts_text
|
||||
self._tts_pos = 0
|
||||
|
||||
# _llm_text is the original LLM-produced text (with pattern delimiters like
|
||||
# <card>...</card>). We track _llm_pos as a cursor into it, advancing
|
||||
# by the same number of alphanumeric chars consumed from the TTS word stream.
|
||||
self._llm_text = llm_text
|
||||
self._llm_pos = 0
|
||||
|
||||
self._overflow_word: str | None = None
|
||||
self._llm_consumed: str | None = None
|
||||
self._frame_word: str | None = None
|
||||
logger.debug(f"WordCompletionTracker: {self._tts_normalized}")
|
||||
|
||||
@staticmethod
|
||||
def _normalize(text: str) -> str:
|
||||
"""Strip XML/HTML tags then keep only lowercase alphanumeric characters.
|
||||
|
||||
Accented letters (e.g. ã, é) are reduced to their base letter so TTS output
|
||||
can be matched against LLM text even when the provider strips diacritics.
|
||||
Non-Latin scripts (CJK, Hangul) are kept as-is — each original character
|
||||
contributes exactly one char to the result, keeping normalized length in sync
|
||||
with raw alnum counts used by _advance_by_alnums.
|
||||
"""
|
||||
text = re.sub(r"<[^>]+>", "", text)
|
||||
result = []
|
||||
for char in text:
|
||||
# Ignore punctuation, spaces, emojis, etc.
|
||||
# Keep only letters and numbers.
|
||||
if not char.isalnum():
|
||||
continue
|
||||
# NFD decomposes accented characters into:
|
||||
# é -> e + ◌́
|
||||
# ã -> a + ◌̃
|
||||
#
|
||||
# Non-accented characters usually stay unchanged.
|
||||
nfd = unicodedata.normalize("NFD", char)
|
||||
# Unicode category "Mn" means:
|
||||
# Mark, Nonspacing
|
||||
#
|
||||
# These are combining accent marks that modify
|
||||
# the previous character but are not standalone.
|
||||
#
|
||||
# Example:
|
||||
# "é" becomes:
|
||||
# nfd[0] = "e"
|
||||
# nfd[1] = "◌́" (category = "Mn")
|
||||
#
|
||||
# If the second character is a combining accent,
|
||||
# keep only the base letter.
|
||||
if len(nfd) >= 2 and unicodedata.category(nfd[1]) == "Mn":
|
||||
# Accented letter: keep the base character only (drops the combining mark).
|
||||
result.append(nfd[0].lower())
|
||||
else:
|
||||
# Regular ASCII, numbers, CJK, Hangul, etc.
|
||||
# are kept unchanged (except lowercase conversion).
|
||||
result.append(char.lower())
|
||||
return "".join(result)
|
||||
|
||||
# Typographic variants that LLMs commonly emit but TTS services normalize away.
|
||||
_TYPOGRAPHY_FOLD = str.maketrans(
|
||||
{
|
||||
"‘": "'", # ' LEFT SINGLE QUOTATION MARK
|
||||
"’": "'", # ' RIGHT SINGLE QUOTATION MARK
|
||||
"ʼ": "'", # ʼ MODIFIER LETTER APOSTROPHE
|
||||
"“": '"', # " LEFT DOUBLE QUOTATION MARK
|
||||
"”": '"', # " RIGHT DOUBLE QUOTATION MARK
|
||||
"–": "-", # – EN DASH
|
||||
"—": "-", # — EM DASH
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _fold_typography(text: str) -> str:
|
||||
"""Replace typographic punctuation variants with their ASCII equivalents."""
|
||||
return text.translate(WordCompletionTracker._TYPOGRAPHY_FOLD)
|
||||
|
||||
@staticmethod
|
||||
def _remove_trailing_punctuation(text: str) -> str:
|
||||
"""Remove punctuation only at the very end of the given text."""
|
||||
i = len(text)
|
||||
while i > 0 and unicodedata.category(text[i - 1]).startswith("P"):
|
||||
i -= 1
|
||||
return text[:i]
|
||||
|
||||
@staticmethod
|
||||
def _advance_by_alnums(text: str, start_pos: int, n: int) -> int:
|
||||
"""Return the position in *text* after advancing past *n* alphanumeric chars.
|
||||
|
||||
Moves through the text one character at a time, counting only alphanumeric
|
||||
characters. XML/HTML tags (``<...>``) are skipped entirely — their content
|
||||
is not counted against the budget, so the returned span includes the full tag.
|
||||
Other non-alphanumeric characters (spaces, punctuation) are also passed over
|
||||
without decrementing the budget.
|
||||
|
||||
After the *n* alnum chars are consumed, advances further past any immediately
|
||||
following punctuation (e.g. the ``,`` in ``"questions,"`` or the ``.`` in
|
||||
``"done."``), stopping before the next space, alnum char, or XML tag.
|
||||
|
||||
Args:
|
||||
text: The source text to scan.
|
||||
start_pos: Starting position in *text*.
|
||||
n: Number of alphanumeric characters to consume.
|
||||
"""
|
||||
pos = start_pos
|
||||
count = 0
|
||||
while pos < len(text) and count < n:
|
||||
if text[pos] == "<":
|
||||
end = text.find(">", pos)
|
||||
pos = end + 1 if end != -1 else pos + 1
|
||||
elif text[pos].isalnum():
|
||||
count += 1
|
||||
pos += 1
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
while pos < len(text):
|
||||
if text[pos] == "<":
|
||||
break
|
||||
if text[pos].isalnum() or text[pos].isspace():
|
||||
break
|
||||
pos += 1
|
||||
|
||||
return pos
|
||||
|
||||
def add_word_and_check_complete(self, word: str) -> bool:
|
||||
"""Record a spoken word from a word-timestamp event.
|
||||
|
||||
Normalizes ``word``, appends it to the running total, and checks whether
|
||||
all expected alphanumeric characters have been covered.
|
||||
|
||||
Before advancing, checks whether the word belongs to this frame via
|
||||
``word_belongs_here``. If it does not (e.g. the TTS provider silently
|
||||
dropped a word-timestamp), the slot is force-completed: the remaining
|
||||
unspoken text from ``tts_text`` is stored in ``_frame_word`` so a
|
||||
TTSTextFrame can still be emitted for the dropped portion, all remaining
|
||||
``llm_text`` is consumed, and the entire incoming word is set as overflow
|
||||
so the caller's overflow path routes it to the next slot unchanged.
|
||||
|
||||
If ``llm_text`` was provided at construction time, also advances the LLM
|
||||
cursor by the same number of alphanumeric chars consumed from this word and
|
||||
stores the corresponding LLM span in ``_llm_consumed``. When this word
|
||||
completes the frame, the entire remaining LLM text (including any closing
|
||||
tags) is consumed so nothing is lost.
|
||||
|
||||
If the word overshoots the expected length (overflow), the raw suffix of
|
||||
the word (everything after the last char belonging to this frame) is stored
|
||||
in ``_overflow_word``, so the caller can attribute it to the next
|
||||
AggregatedTextFrame.
|
||||
|
||||
Args:
|
||||
word: A single word token returned by the TTS service. TTS services that
|
||||
emit spaces and punctuation as separate tokens (e.g. Inworld) must
|
||||
pre-merge those tokens into the preceding word before calling this
|
||||
method (see ``TTSService._merge_punct_tokens``).
|
||||
|
||||
Returns:
|
||||
True when all expected content has been covered.
|
||||
"""
|
||||
normalized = self._normalize(word)
|
||||
|
||||
prev_len = len(self._received)
|
||||
expected_len = len(self._tts_normalized)
|
||||
|
||||
self._overflow_word = None
|
||||
self._llm_consumed = None
|
||||
self._frame_word = None
|
||||
|
||||
if prev_len > expected_len:
|
||||
logger.warning(f"{self}, trying to add a word in an already complete frame")
|
||||
return True
|
||||
|
||||
# If the word doesn't match the next expected chars, the TTS provider
|
||||
# likely dropped a word-timestamp event. Force-complete this slot: emit the
|
||||
# remaining TTS text as _frame_word so a TTSTextFrame is still produced
|
||||
# for the unspoken portion, consume all remaining llm_text, and route the
|
||||
# entire incoming word as overflow for the next slot.
|
||||
if not self.word_belongs_here(word):
|
||||
self._frame_word = self._tts_text[self._tts_pos :]
|
||||
if self._llm_text is not None:
|
||||
self._llm_consumed = self._llm_text[self._llm_pos :]
|
||||
self._llm_pos = len(self._llm_text)
|
||||
# This should not happen: force-complete sweeps all remaining
|
||||
# llm_text, so the span must contain the frame word. If it
|
||||
# doesn't, tts_text and llm_text are out of sync in an
|
||||
# unexpected way — discard rather than returning a corrupt span.
|
||||
# Also removing punctuation from the frame word to match the
|
||||
# expected text, since some TTS services may add punctuation to
|
||||
# the raw text.
|
||||
word_without_punctuation = self._remove_trailing_punctuation(self._frame_word)
|
||||
if word_without_punctuation and word_without_punctuation not in self._llm_consumed:
|
||||
logger.warning(
|
||||
f"WordCompletionTracker: force-complete llm_consumed {repr(self._llm_consumed)!s} "
|
||||
f"does not contain frame_word {repr(self._frame_word)!s}, discarding"
|
||||
)
|
||||
self._llm_consumed = None
|
||||
self._received = self._tts_normalized # force-complete
|
||||
self._overflow_word = word
|
||||
return True
|
||||
|
||||
self._received += normalized
|
||||
|
||||
# How many normalized chars from this word belong to the current frame.
|
||||
chars_for_frame = min(len(normalized), expected_len - prev_len)
|
||||
|
||||
if prev_len + len(normalized) > expected_len:
|
||||
# This word straddles the frame boundary. Split into:
|
||||
# - _frame_word: the prefix of `word` up to the split point, used
|
||||
# for the TTSTextFrame of the current slot.
|
||||
# - raw overflow word: the raw suffix after the split point, used
|
||||
# to build a TTSTextFrame attributed to the next AggregatedTextFrame.
|
||||
split_pos = self._advance_by_alnums(word, 0, chars_for_frame)
|
||||
self._frame_word = word[:split_pos]
|
||||
self._overflow_word = word[split_pos:]
|
||||
else:
|
||||
# Word fits entirely in this frame.
|
||||
self._frame_word = word
|
||||
|
||||
# Advance the TTS cursor by the same alnum count so the force-complete
|
||||
# path knows where in _tts_text to start from.
|
||||
self._tts_pos = self._advance_by_alnums(self._tts_text, self._tts_pos, chars_for_frame)
|
||||
|
||||
if self._llm_text is not None:
|
||||
if self.is_complete:
|
||||
# Consume ALL remaining LLM text: closing tags (e.g. </card>)
|
||||
# and any trailing punctuation that the TTS will not send separately.
|
||||
self._llm_consumed = self._llm_text[self._llm_pos :]
|
||||
self._llm_pos = len(self._llm_text)
|
||||
else:
|
||||
if chars_for_frame == 0:
|
||||
# Consume exactly the raw word in llm_text, skipping any
|
||||
# leading spaces that belong to the previous token's span.
|
||||
start = self._llm_pos
|
||||
while start < len(self._llm_text) and self._llm_text[start].isspace():
|
||||
start += 1
|
||||
end = start + len(word)
|
||||
self._llm_consumed = self._llm_text[start:end]
|
||||
self._llm_pos = end
|
||||
else:
|
||||
# Advance through llm_text by exactly chars_for_frame alphanumeric
|
||||
# chars. Non-alnum chars (spaces, opening tags) are included in the
|
||||
# slice, preserving the original formatting for the context.
|
||||
new_pos = self._advance_by_alnums(
|
||||
self._llm_text, self._llm_pos, chars_for_frame
|
||||
)
|
||||
self._llm_consumed = self._llm_text[self._llm_pos : new_pos]
|
||||
self._llm_pos = new_pos
|
||||
# This should not happen: the LLM cursor is driven by the same
|
||||
# alnum count as the word stream, so the consumed span must contain
|
||||
# the frame word. If it doesn't, the cursors drifted out of sync
|
||||
# in an unexpected way — discard rather than returning a corrupt span.
|
||||
# Also removing punctuation from the frame word to match the
|
||||
# expected text, since some TTS services may add punctuation to
|
||||
# the raw text.
|
||||
word_without_punctuation = self._remove_trailing_punctuation(self._frame_word)
|
||||
if word_without_punctuation and self._fold_typography(
|
||||
word_without_punctuation
|
||||
) not in self._fold_typography(self._llm_consumed):
|
||||
logger.warning(
|
||||
f"WordCompletionTracker: llm_consumed {repr(self._llm_consumed)!s} "
|
||||
f"does not contain frame_word {repr(self._frame_word)!s}, discarding"
|
||||
)
|
||||
self._llm_consumed = None
|
||||
|
||||
return self.is_complete
|
||||
|
||||
def word_belongs_here(self, word: str) -> bool:
|
||||
"""Return True if this word plausibly belongs to the remaining TTS text.
|
||||
|
||||
Dispatches to one of two checks depending on whether the word contains
|
||||
any alphanumeric characters after normalization:
|
||||
|
||||
- Alnum words: prefix-match against the remaining expected chars.
|
||||
- Symbol/punctuation words (empty after normalization): literal substring
|
||||
search in the remaining raw TTS text, with a fallback for TTS providers
|
||||
that substitute Unicode symbols with ASCII punctuation.
|
||||
|
||||
Used to detect when the TTS provider silently dropped a word-timestamp
|
||||
event: if the incoming word does not match this slot's remaining content,
|
||||
the caller should force-complete this slot and route the word to the next.
|
||||
"""
|
||||
normalized = self._normalize(word)
|
||||
if normalized:
|
||||
return self._alnum_word_belongs_here(normalized)
|
||||
else:
|
||||
return self._symbol_word_belongs_here(word)
|
||||
|
||||
def _alnum_word_belongs_here(self, normalized: str) -> bool:
|
||||
"""Return True if an alnum-containing word matches this frame's remaining expected chars.
|
||||
|
||||
Accepts both full words and partial tokens — the word belongs here as long
|
||||
as its normalized characters are a prefix of what is still expected. This
|
||||
also handles the overflow case where the word is longer than the remaining
|
||||
content (the excess is detected and split in ``add_word_and_check_complete``).
|
||||
"""
|
||||
remaining = self._tts_normalized[len(self._received) :]
|
||||
if not remaining:
|
||||
return False
|
||||
check_len = min(len(normalized), len(remaining))
|
||||
return remaining.startswith(normalized[:check_len])
|
||||
|
||||
def _symbol_word_belongs_here(self, word: str) -> bool:
|
||||
"""Return True if a non-alnum word (emoji, punctuation, symbol) belongs to this frame.
|
||||
|
||||
Two checks are applied in order:
|
||||
|
||||
1. **Literal substring**: search for the raw word in the remaining TTS text.
|
||||
``_advance_by_alnums`` may have already moved ``_tts_pos`` past some trailing
|
||||
punctuation, so the search window is backed up to include those characters.
|
||||
|
||||
2. **Symbol substitution fallback**: some TTS providers substitute Unicode symbols
|
||||
with ASCII punctuation in word-timestamp events (e.g. ElevenLabs reports ``→``
|
||||
as ``-``), so check 1 always fails even though the word belongs here. If alnum
|
||||
content still remains unconsumed and the next non-space character in the TTS
|
||||
text is itself a non-alnum symbol, accept the word as a substitution.
|
||||
"""
|
||||
search_start = self._tts_pos
|
||||
while search_start > 0:
|
||||
ch = self._tts_text[search_start - 1]
|
||||
if ch.isalnum() or ch.isspace() or ch == ">":
|
||||
break
|
||||
search_start -= 1
|
||||
if word in self._tts_text[search_start:]:
|
||||
return True
|
||||
|
||||
if len(self._received) >= len(self._tts_normalized):
|
||||
return False
|
||||
|
||||
pos = self._tts_pos
|
||||
while pos < len(self._tts_text) and self._tts_text[pos].isspace():
|
||||
pos += 1
|
||||
return pos < len(self._tts_text) and not self._tts_text[pos].isalnum()
|
||||
|
||||
def get_word_for_frame(self) -> str | None:
|
||||
"""Return the portion of the last word that belongs to this frame.
|
||||
|
||||
- Normal word (no overflow): the full word.
|
||||
- Straddling word: the prefix up to the frame boundary (e.g. ``"1111"``
|
||||
from ``"1111 And"``).
|
||||
- Force-completed (word didn't belong): the remaining unspoken text from
|
||||
``tts_text`` so a TTSTextFrame can still be emitted for the dropped
|
||||
portion. The incoming word is routed as overflow to the next slot.
|
||||
"""
|
||||
return self._frame_word.strip() if self._frame_word else self._frame_word
|
||||
|
||||
def get_overflow_word(self) -> str | None:
|
||||
"""Return the raw suffix of the last word that overflows into the next frame.
|
||||
|
||||
Preserves the original casing and any non-alphanumeric characters so the
|
||||
overflow TTSTextFrame has natural word text. Returns None when there is no
|
||||
overflow (the word fit entirely within this frame).
|
||||
"""
|
||||
return self._overflow_word.strip() if self._overflow_word else self._overflow_word
|
||||
|
||||
def get_llm_consumed(self) -> str | None:
|
||||
"""Return the LLM text span consumed for the last added word.
|
||||
|
||||
Returns None if no llm_text was provided at construction time.
|
||||
"""
|
||||
return self._llm_consumed.strip() if self._llm_consumed else self._llm_consumed
|
||||
|
||||
def get_accumulated_tts_text(self) -> str:
|
||||
"""Return all consumed text from tts_text up to the current cursor position.
|
||||
|
||||
Unlike ``get_word_for_frame()`` (which reflects only the last word), this returns
|
||||
everything that has been consumed since construction or the last ``reset()``.
|
||||
"""
|
||||
return self._tts_text[: self._tts_pos]
|
||||
|
||||
def get_accumulated_llm_text(self) -> str | None:
|
||||
"""Return all consumed text from llm_text up to the current cursor position.
|
||||
|
||||
Unlike ``get_llm_consumed()`` (which reflects only the last word), this returns
|
||||
everything that has been consumed since construction or the last ``reset()``.
|
||||
Returns None if no llm_text was provided at construction time.
|
||||
"""
|
||||
if self._llm_text is None:
|
||||
return None
|
||||
return self._llm_text[: self._llm_pos]
|
||||
|
||||
def get_remaining_tts_text(self) -> str:
|
||||
"""Return the unspoken portion of tts_text, stripped of leading/trailing whitespace.
|
||||
|
||||
This is the text that the TTS provider has not yet confirmed via word-timestamp
|
||||
events. Useful for force-completing a slot when the audio context ends before all
|
||||
word-timestamp events have arrived.
|
||||
"""
|
||||
return self._tts_text[self._tts_pos :].strip()
|
||||
|
||||
def get_remaining_llm_text(self) -> str | None:
|
||||
"""Return the unspoken portion of llm_text, stripped of leading/trailing whitespace.
|
||||
|
||||
Returns None if no llm_text was provided at construction time. Like
|
||||
``get_remaining_tts_text()``, intended for force-completing a slot so that the
|
||||
conversation context receives the full original text.
|
||||
"""
|
||||
if self._llm_text is None:
|
||||
return None
|
||||
remaining = self._llm_text[self._llm_pos :].strip()
|
||||
return remaining if remaining else None
|
||||
|
||||
@property
|
||||
def is_complete(self) -> bool:
|
||||
"""True when accumulated normalized chars >= expected normalized chars."""
|
||||
return len(self._received) >= len(self._tts_normalized)
|
||||
|
||||
def reset(self):
|
||||
"""Reset received word accumulation without changing the expected text."""
|
||||
self._received = ""
|
||||
self._tts_pos = 0
|
||||
self._llm_pos = 0
|
||||
self._overflow_word = None
|
||||
self._llm_consumed = None
|
||||
self._frame_word = None
|
||||
53
src/pipecat/utils/text/word_timestamp_utils.py
Normal file
53
src/pipecat/utils/text/word_timestamp_utils.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Utilities for normalizing word-timestamp streams from TTS services."""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def merge_punct_tokens(
|
||||
word_times: list[tuple[str, float]],
|
||||
) -> list[tuple[str, float]]:
|
||||
"""Merge punctuation/space-only tokens into the preceding word.
|
||||
|
||||
Some TTS services (e.g. Inworld) emit spaces and punctuation as separate
|
||||
word-timestamp tokens rather than attaching them to the adjacent word.
|
||||
This function collapses those tokens so downstream consumers always receive
|
||||
words with trailing punctuation already attached — identical to the format
|
||||
produced by ElevenLabs or Cartesia.
|
||||
|
||||
A token is considered punct/space-only when its text contains no alphanumeric
|
||||
characters after stripping XML/HTML tags. Such tokens are appended to the
|
||||
preceding word's text and their timestamp is discarded (the preceding word's
|
||||
timestamp is kept). Leading punct/space tokens with no preceding word are
|
||||
silently discarded. Every output token is stripped of leading and trailing
|
||||
whitespace (spaces, tabs, newlines).
|
||||
|
||||
Args:
|
||||
word_times: Raw list of ``(word, timestamp)`` pairs from the TTS service.
|
||||
|
||||
Returns:
|
||||
Merged list where every entry contains at least one alphanumeric character
|
||||
and has no leading or trailing whitespace.
|
||||
|
||||
Example::
|
||||
|
||||
merge_punct_tokens([("questions", 1.0), (", ", 1.2), ("explain", 1.4)])
|
||||
# → [("questions,", 1.0), ("explain", 1.4)]
|
||||
"""
|
||||
merged: list[tuple[str, float]] = []
|
||||
for word, ts in word_times:
|
||||
stripped = re.sub(r"<[^>]+>", "", word)
|
||||
has_alnum = any(c.isalnum() for c in stripped)
|
||||
if not has_alnum:
|
||||
if merged:
|
||||
prev_word, prev_ts = merged[-1]
|
||||
merged[-1] = (prev_word + word, prev_ts)
|
||||
# else: leading punct/space with no preceding word → discard
|
||||
else:
|
||||
merged.append((word, ts))
|
||||
return [(word.strip(), ts) for word, ts in merged]
|
||||
Reference in New Issue
Block a user