Merge pull request #4380 from pipecat-ai/filipi/smart_text
Smart Text Handling
This commit is contained in:
1
changelog/4380.fixed.2.md
Normal file
1
changelog/4380.fixed.2.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed `BaseOutputTransport` reordering frames that share the same presentation timestamp. Frames with equal PTS values are now emitted in insertion order, preventing subtle audio/text sequencing bugs when multiple frames arrive at the same time.
|
||||
1
changelog/4380.fixed.3.md
Normal file
1
changelog/4380.fixed.3.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed Cartesia word timestamps leaking SSML tag text (e.g. `<spell>`, `<emotion>`, `<break>`) into word entries. Tags are now stripped before processing, so word-to-text attribution remains accurate when SSML markup is present in the TTS input.
|
||||
1
changelog/4380.fixed.4.md
Normal file
1
changelog/4380.fixed.4.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed `TTSTextFrame` entries losing their original text structure when word timestamps are enabled. Each `TTSTextFrame` now carries a `raw_text` field containing the corresponding span of the original LLM-produced text (including pattern delimiters such as `<card>4111 1111 1111 1111</card>`), so the assistant context receives properly-tagged content rather than the cleaned words returned by the TTS provider. Also handles words that straddle two sentence boundaries by splitting them and attributing each part to its correct source frame.
|
||||
1
changelog/4380.fixed.md
Normal file
1
changelog/4380.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed skipped TTS frames (e.g. code blocks filtered via `skip_aggregator_types`) being emitted to the assistant context immediately instead of waiting for preceding spoken frames to finish. They now hold their position in the frame sequence and are flushed only after all earlier spoken sentences are complete, keeping context ordering correct.
|
||||
@@ -383,10 +383,14 @@ class AggregatedTextFrame(TextFrame):
|
||||
Parameters:
|
||||
aggregated_by: Method used to aggregate the text frames.
|
||||
context_id: Unique identifier for the TTS context that generated this text.
|
||||
raw_text: The full matched text including start/end pattern delimiters, set when
|
||||
this frame was produced from a PatternMatch (e.g. a ``<code>...</code>`` block).
|
||||
None for ordinary sentence aggregations.
|
||||
"""
|
||||
|
||||
aggregated_by: AggregationType | str
|
||||
context_id: str | None = None
|
||||
raw_text: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -25,6 +25,7 @@ from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer
|
||||
from pipecat.audio.vad.vad_controller import VADController
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
AssistantImageRawFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
@@ -1496,9 +1497,14 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
if len(frame.text) == 0:
|
||||
return
|
||||
|
||||
text = (
|
||||
frame.raw_text
|
||||
if isinstance(frame, AggregatedTextFrame) and frame.raw_text
|
||||
else frame.text
|
||||
)
|
||||
self._aggregation.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch
|
||||
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
|
||||
|
||||
|
||||
@@ -85,7 +86,11 @@ class LLMTextProcessor(FrameProcessor):
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=aggregation.text,
|
||||
aggregated_by=aggregation.type,
|
||||
raw_text=aggregation.full_match
|
||||
if isinstance(aggregation, PatternMatch)
|
||||
else aggregation.text,
|
||||
)
|
||||
out_frame.append_to_context = True
|
||||
out_frame.skip_tts = in_frame.skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
|
||||
@@ -96,6 +101,9 @@ class LLMTextProcessor(FrameProcessor):
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=remaining.text,
|
||||
aggregated_by=remaining.type,
|
||||
raw_text=remaining.full_match
|
||||
if isinstance(remaining, PatternMatch)
|
||||
else remaining.text,
|
||||
)
|
||||
out_frame.skip_tts = skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
|
||||
@@ -528,6 +528,9 @@ class RTVIObserver(BaseObserver):
|
||||
text = await transform(text, agg_type)
|
||||
|
||||
isTTS = isinstance(frame, TTSTextFrame)
|
||||
if agg_type is not AggregationType.WORD:
|
||||
logger.debug(f"{self} Aggregated LLM text: {text}, {agg_type} spoken:{isTTS}")
|
||||
|
||||
if self._params.bot_output_enabled:
|
||||
message = RTVI.BotOutputMessage(
|
||||
data=RTVI.BotOutputMessageData(text=text, spoken=isTTS, aggregated_by=agg_type)
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
@@ -431,10 +432,20 @@ class CartesiaTTSService(WebsocketTTSService):
|
||||
base_lang = language.split("-")[0].lower()
|
||||
return base_lang in {"zh", "ja"}
|
||||
|
||||
def _process_word_timestamps_for_language(
|
||||
_CARTESIA_TAG_RE = re.compile(r"</?(?:spell|emotion|break|volume|speed)\b[^>]*>", re.IGNORECASE)
|
||||
|
||||
def _strip_cartesia_tags(self, text: str) -> str:
|
||||
text = self._CARTESIA_TAG_RE.sub(" ", text)
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
return text.strip()
|
||||
|
||||
def _normalize_word_timestamps(
|
||||
self, words: list[str], starts: list[float]
|
||||
) -> list[tuple[str, float]]:
|
||||
"""Process word timestamps based on the current language.
|
||||
"""Normalize raw word timestamps from Cartesia before further processing.
|
||||
|
||||
Strips Cartesia SSML tags (spell, emotion, break, volume, speed) from each word
|
||||
and drops entries that become empty after stripping.
|
||||
|
||||
For Chinese and Japanese, Cartesia groups related characters in the same timestamp
|
||||
message.
|
||||
@@ -458,14 +469,18 @@ class CartesiaTTSService(WebsocketTTSService):
|
||||
# For Chinese/Japanese, combine all characters in this message into one word
|
||||
# using the first character's start time.
|
||||
if words and starts:
|
||||
combined_word = "".join(words)
|
||||
combined_word = "".join(self._strip_cartesia_tags(w) for w in words)
|
||||
first_start = starts[0]
|
||||
return [(combined_word, first_start)]
|
||||
return [(combined_word, first_start)] if combined_word else []
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
# For non-CJK languages, use as-is
|
||||
return list(zip(words, starts))
|
||||
result = []
|
||||
for word, start in zip(words, starts):
|
||||
cleaned = self._strip_cartesia_tags(word)
|
||||
if cleaned:
|
||||
result.append((cleaned, start))
|
||||
return result
|
||||
|
||||
def _word_timestamps_include_inter_frame_spaces(self) -> bool:
|
||||
"""Whether timestamp text should be treated as carrying its own spacing."""
|
||||
@@ -662,7 +677,7 @@ class CartesiaTTSService(WebsocketTTSService):
|
||||
await self.remove_audio_context(ctx_id)
|
||||
elif msg["type"] == "timestamps":
|
||||
# Process the timestamps based on language before adding them
|
||||
processed_timestamps = self._process_word_timestamps_for_language(
|
||||
processed_timestamps = self._normalize_word_timestamps(
|
||||
msg["word_timestamps"]["words"], msg["word_timestamps"]["start"]
|
||||
)
|
||||
await self.add_word_timestamps(
|
||||
|
||||
@@ -50,9 +50,13 @@ from pipecat.services.ai_service import AIService
|
||||
from pipecat.services.settings import TTSSettings, is_given
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.context.aggregated_frame_sequencer import AggregatedFrameSequencer
|
||||
from pipecat.utils.context.word_completion_tracker import WordCompletionTracker
|
||||
from pipecat.utils.frame_queue import FrameQueue
|
||||
from pipecat.utils.text.base_text_filter import BaseTextFilter
|
||||
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch
|
||||
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
|
||||
from pipecat.utils.text.word_timestamp_utils import merge_punct_tokens
|
||||
from pipecat.utils.time import seconds_to_nanoseconds
|
||||
|
||||
|
||||
@@ -97,7 +101,6 @@ class _WordTimestampEntry:
|
||||
word: str
|
||||
timestamp: float
|
||||
context_id: str
|
||||
includes_inter_frame_spaces: bool | None = None
|
||||
|
||||
|
||||
class TTSService(AIService):
|
||||
@@ -289,6 +292,16 @@ class TTSService(AIService):
|
||||
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
|
||||
self._transport_destination: str | None = transport_destination
|
||||
|
||||
# Ordered sequence of every AggregatedTextFrame slot that passes through
|
||||
# _push_tts_frames (both spoken and skipped). Skipped frames are held here
|
||||
# until all preceding spoken slots are complete, then flushed downstream so
|
||||
# their append_to_context=True arrives at the assistant aggregator in the
|
||||
# correct order relative to the TTSTextFrames from spoken sentences.
|
||||
# Tracks all AggregatedTextFrame slots (spoken and skipped) in order.
|
||||
# Skipped frames are held until preceding spoken slots complete, ensuring
|
||||
# append_to_context=True reaches the assistant aggregator in the right order.
|
||||
self._aggregated_frame_sequencer = AggregatedFrameSequencer(name=str(self))
|
||||
|
||||
self._resampler = create_stream_resampler()
|
||||
|
||||
self._processing_text: bool = False
|
||||
@@ -690,7 +703,15 @@ class TTSService(AIService):
|
||||
# Stop the aggregation metric (no-op if already stopped on first sentence).
|
||||
await self.stop_text_aggregation_metrics()
|
||||
if remaining:
|
||||
await self._push_tts_frames(AggregatedTextFrame(remaining.text, remaining.type))
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(
|
||||
remaining.text,
|
||||
remaining.type,
|
||||
raw_text=remaining.full_match
|
||||
if isinstance(remaining, PatternMatch)
|
||||
else remaining.text,
|
||||
)
|
||||
)
|
||||
|
||||
# We pause processing incoming frames if the LLM response included
|
||||
# text (it might be that it's only a function calling response). We
|
||||
@@ -733,7 +754,7 @@ class TTSService(AIService):
|
||||
push_assistant_aggregation = frame.append_to_context and not self._llm_response_started
|
||||
# Assumption: text in TTSSpeakFrame does not include inter-frame spaces
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(frame.text, AggregationType.SENTENCE),
|
||||
AggregatedTextFrame(frame.text, AggregationType.SENTENCE, raw_text=frame.text),
|
||||
append_tts_text_to_context=frame.append_to_context,
|
||||
push_assistant_aggregation=push_assistant_aggregation,
|
||||
)
|
||||
@@ -887,6 +908,7 @@ class TTSService(AIService):
|
||||
self._llm_response_started = False
|
||||
self._streamed_text = ""
|
||||
self._text_aggregation_metrics_started = False
|
||||
self._aggregated_frame_sequencer.clear() # discard all pending slots on interruption
|
||||
await self.reset_word_timestamps()
|
||||
|
||||
await self._stop_audio_context_task()
|
||||
@@ -930,9 +952,23 @@ class TTSService(AIService):
|
||||
if aggregate.type != AggregationType.TOKEN:
|
||||
# Stop the aggregation metric on the first sentence only.
|
||||
await self.stop_text_aggregation_metrics()
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(aggregate.text, aggregate.type), includes_inter_frame_spaces
|
||||
raw_text = (
|
||||
aggregate.full_match if isinstance(aggregate, PatternMatch) else aggregate.text
|
||||
)
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(aggregate.text, aggregate.type, raw_text=raw_text),
|
||||
includes_inter_frame_spaces,
|
||||
)
|
||||
|
||||
async def _push_frame_respecting_previous_aggregated_frame(
|
||||
self, frame: AggregatedTextFrame, context_id: str
|
||||
):
|
||||
# Enqueue the skipped frame; returns it immediately if no spoken slot
|
||||
# precedes it, or holds it until the sequencer can flush it in order.
|
||||
for f in self._aggregated_frame_sequencer.register_skipped(
|
||||
frame, context_id, self._transport_destination
|
||||
):
|
||||
await self.push_frame(f)
|
||||
|
||||
async def _push_tts_frames(
|
||||
self,
|
||||
@@ -944,10 +980,13 @@ class TTSService(AIService):
|
||||
type = src_frame.aggregated_by
|
||||
text = src_frame.text
|
||||
|
||||
# Create context ID and store metadata
|
||||
context_id = self.create_context_id()
|
||||
|
||||
# Skip sending to TTS if the aggregation type is in the skip list. Simply
|
||||
# push the original frame downstream.
|
||||
if type in self._skip_aggregator_types:
|
||||
await self.push_frame(src_frame)
|
||||
await self._push_frame_respecting_previous_aggregated_frame(src_frame, context_id)
|
||||
return
|
||||
|
||||
# Whitespace gating depends on aggregation mode:
|
||||
@@ -998,9 +1037,6 @@ class TTSService(AIService):
|
||||
await self.stop_processing_metrics()
|
||||
return
|
||||
|
||||
# Create context ID and store metadata
|
||||
context_id = self.create_context_id()
|
||||
|
||||
# To support use cases that may want to know the text before it's spoken, we
|
||||
# push the AggregatedTextFrame version before transforming and sending to TTS.
|
||||
# However, we do not want to add this text to the assistant context until it
|
||||
@@ -1045,6 +1081,21 @@ class TTSService(AIService):
|
||||
await self.start_ttfb_metrics()
|
||||
await self.append_to_audio_context(context_id, TTSStartedFrame(context_id=context_id))
|
||||
|
||||
# Register this spoken frame so the sequencer can track its completion
|
||||
# and unblock any skipped frames queued behind it. Word-timestamp services
|
||||
# complete the slot via process_word; push_text_frames services complete it
|
||||
# below after the TTSTextFrame is appended to the audio context.
|
||||
self._aggregated_frame_sequencer.register_spoken(
|
||||
src_frame,
|
||||
context_id,
|
||||
tracker=WordCompletionTracker(
|
||||
prepared_text, llm_text=src_frame.raw_text or src_frame.text
|
||||
)
|
||||
if not self._push_text_frames
|
||||
else None,
|
||||
append_to_context=self._tts_contexts[context_id].append_to_context,
|
||||
)
|
||||
|
||||
await self.tts_process_generator(context_id, self.run_tts(prepared_text, context_id))
|
||||
|
||||
if not self._is_streaming_tokens:
|
||||
@@ -1066,6 +1117,10 @@ class TTSService(AIService):
|
||||
frame.append_to_context = append_tts_text_to_context
|
||||
# Appending to the context, so it preserves the ordering.
|
||||
await self.append_to_audio_context(context_id, frame)
|
||||
# TTSTextFrame is queued; mark the spoken slot complete so any skipped
|
||||
# frames (e.g. code blocks) waiting behind it can be flushed in order.
|
||||
for f in self._aggregated_frame_sequencer.complete_spoken_slot():
|
||||
await self.push_frame(f)
|
||||
|
||||
async def tts_process_generator(
|
||||
self, context_id: str, generator: AsyncGenerator[Frame | None, None]
|
||||
@@ -1114,10 +1169,8 @@ class TTSService(AIService):
|
||||
if self._initial_word_times:
|
||||
cached = self._initial_word_times.copy()
|
||||
self._initial_word_times = []
|
||||
for word, timestamp_seconds, ctx_id, ifs in cached:
|
||||
await self._add_word_timestamps(
|
||||
[(word, timestamp_seconds)], ctx_id, includes_inter_frame_spaces=ifs
|
||||
)
|
||||
for word, timestamp_seconds, ctx_id in cached:
|
||||
await self._add_word_timestamps([(word, timestamp_seconds)], ctx_id)
|
||||
|
||||
async def reset_word_timestamps(self):
|
||||
"""Reset word timestamp tracking."""
|
||||
@@ -1139,6 +1192,11 @@ class TTSService(AIService):
|
||||
playback order by _handle_audio_context. Otherwise they are processed immediately
|
||||
via _add_word_timestamps.
|
||||
|
||||
When ``includes_inter_frame_spaces`` is True (e.g. Inworld TTS), punctuation and
|
||||
space-only tokens are merged into the preceding word via ``_merge_punct_tokens``
|
||||
before queuing, so the tracker always receives words with trailing punctuation
|
||||
already attached. ``includes_inter_frame_spaces`` is reset to None after merging.
|
||||
|
||||
Args:
|
||||
word_times: List of (word, timestamp) tuples where timestamp is in seconds.
|
||||
context_id: Unique identifier for the TTS context.
|
||||
@@ -1147,29 +1205,22 @@ class TTSService(AIService):
|
||||
consumers must not inject additional spaces between tokens. None leaves
|
||||
the frame's own default unchanged.
|
||||
"""
|
||||
if includes_inter_frame_spaces:
|
||||
word_times = merge_punct_tokens(word_times)
|
||||
|
||||
if context_id and self.audio_context_available(context_id):
|
||||
for word, timestamp in word_times:
|
||||
await self.append_to_audio_context(
|
||||
context_id,
|
||||
_WordTimestampEntry(
|
||||
word=word,
|
||||
timestamp=timestamp,
|
||||
context_id=context_id,
|
||||
includes_inter_frame_spaces=includes_inter_frame_spaces,
|
||||
),
|
||||
_WordTimestampEntry(word=word, timestamp=timestamp, context_id=context_id),
|
||||
)
|
||||
else:
|
||||
await self._add_word_timestamps(
|
||||
word_times=word_times,
|
||||
context_id=context_id,
|
||||
includes_inter_frame_spaces=includes_inter_frame_spaces,
|
||||
)
|
||||
await self._add_word_timestamps(word_times=word_times, context_id=context_id)
|
||||
|
||||
async def _add_word_timestamps(
|
||||
self,
|
||||
word_times: list[tuple[str, float]],
|
||||
context_id: str | None = None,
|
||||
includes_inter_frame_spaces: bool | None = None,
|
||||
):
|
||||
"""Process word timestamps directly, building and pushing TTSTextFrames inline.
|
||||
|
||||
@@ -1185,19 +1236,15 @@ class TTSService(AIService):
|
||||
ts_ns = seconds_to_nanoseconds(timestamp)
|
||||
if self._initial_word_timestamp == -1:
|
||||
# Cache until we have audio and can compute PTS.
|
||||
self._initial_word_times.append(
|
||||
(word, timestamp, context_id, includes_inter_frame_spaces)
|
||||
)
|
||||
self._initial_word_times.append((word, timestamp, context_id))
|
||||
else:
|
||||
frame = TTSTextFrame(word, aggregated_by=AggregationType.WORD)
|
||||
if includes_inter_frame_spaces is not None:
|
||||
frame.includes_inter_frame_spaces = includes_inter_frame_spaces
|
||||
frame.pts = self._initial_word_timestamp + ts_ns
|
||||
frame.context_id = context_id
|
||||
if context_id in self._tts_contexts:
|
||||
frame.append_to_context = self._tts_contexts[context_id].append_to_context
|
||||
self._word_last_pts = frame.pts
|
||||
await self.push_frame(frame)
|
||||
pts = self._initial_word_timestamp + ts_ns
|
||||
# Build TTSTextFrame(s) for this word token, advancing the active
|
||||
# slot's tracker and flushing any skipped frames now unblocked.
|
||||
for f in self._aggregated_frame_sequencer.process_word(word, pts, context_id):
|
||||
if isinstance(f, TTSTextFrame):
|
||||
self._word_last_pts = f.pts
|
||||
await self.push_frame(f)
|
||||
|
||||
#
|
||||
# Audio context methods (active when using websocket-based TTS with context management)
|
||||
@@ -1382,6 +1429,18 @@ class TTSService(AIService):
|
||||
frame.pts = self._word_last_pts
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _apply_force_complete(self):
|
||||
"""Force-complete all incomplete spoken slots and push any unblocked skipped frames.
|
||||
|
||||
Called at end-of-context to handle TTS providers that silently drop word-timestamp
|
||||
events. Emits a TTSTextFrame for any remaining unspoken text, then flushes skipped
|
||||
frames that were blocked by those incomplete slots.
|
||||
"""
|
||||
for f in self._aggregated_frame_sequencer.force_complete(self._word_last_pts):
|
||||
if isinstance(f, TTSTextFrame):
|
||||
self._word_last_pts = f.pts
|
||||
await self.push_frame(f)
|
||||
|
||||
async def _handle_audio_context(self, context_id: str):
|
||||
"""Process items from an audio context queue until it is exhausted."""
|
||||
queue = self._audio_contexts[context_id]
|
||||
@@ -1402,7 +1461,6 @@ class TTSService(AIService):
|
||||
await self._add_word_timestamps(
|
||||
[(frame.word, frame.timestamp)],
|
||||
frame.context_id,
|
||||
includes_inter_frame_spaces=frame.includes_inter_frame_spaces,
|
||||
)
|
||||
continue
|
||||
elif isinstance(frame, TTSAudioRawFrame):
|
||||
@@ -1416,6 +1474,9 @@ class TTSService(AIService):
|
||||
if isinstance(frame, TTSStartedFrame):
|
||||
should_push_stop_frame = self._push_stop_frames
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
# Checking if we have any remaining spoken slots before pushing the TTSStoppedFrame
|
||||
await self._apply_force_complete()
|
||||
|
||||
should_push_stop_frame = False
|
||||
# Setting the last word timestamp as the TTSStoppedFrame PTS
|
||||
if not frame.pts:
|
||||
@@ -1433,8 +1494,11 @@ class TTSService(AIService):
|
||||
should_push_stop_frame = False
|
||||
break
|
||||
|
||||
await self._apply_force_complete()
|
||||
|
||||
if should_push_stop_frame and self._push_stop_frames:
|
||||
await self.push_frame(TTSStoppedFrame(context_id=context_id))
|
||||
|
||||
await self._maybe_reset_word_timestamps()
|
||||
|
||||
async def on_audio_context_interrupted(self, context_id: str):
|
||||
|
||||
@@ -448,6 +448,9 @@ class BaseOutputTransport(FrameProcessor):
|
||||
self._video_task: asyncio.Task | None = None
|
||||
self._clock_task: asyncio.Task | None = None
|
||||
|
||||
# If timestamps are equal, use this count to preserve the insertion order
|
||||
self._clock_queue_counter = itertools.count()
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
"""Get the audio sample rate.
|
||||
@@ -498,7 +501,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
frame: The end frame signaling sender shutdown.
|
||||
"""
|
||||
# Let the sink tasks process the queue until they reach this EndFrame.
|
||||
await self._clock_queue.put((float("inf"), frame.id, frame))
|
||||
await self._clock_queue.put((float("inf"), next(self._clock_queue_counter), frame))
|
||||
await self._audio_queue.put(frame)
|
||||
|
||||
# At this point we have enqueued an EndFrame and we need to wait for
|
||||
@@ -610,7 +613,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
Args:
|
||||
frame: The frame with timing information to handle.
|
||||
"""
|
||||
await self._clock_queue.put((frame.pts, frame.id, frame))
|
||||
await self._clock_queue.put((frame.pts, next(self._clock_queue_counter), frame))
|
||||
|
||||
async def handle_sync_frame(self, frame: Frame):
|
||||
"""Handle frames that need synchronized processing.
|
||||
|
||||
354
src/pipecat/utils/context/aggregated_frame_sequencer.py
Normal file
354
src/pipecat/utils/context/aggregated_frame_sequencer.py
Normal file
@@ -0,0 +1,354 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Ordered sequencer for AggregatedTextFrame slots through TTS processing."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
AggregationType,
|
||||
Frame,
|
||||
TTSTextFrame,
|
||||
)
|
||||
from pipecat.utils.context.word_completion_tracker import WordCompletionTracker
|
||||
|
||||
|
||||
@dataclass
|
||||
class _AggregatedFrameSlot:
|
||||
"""Ordered slot tracking one AggregatedTextFrame through TTS processing.
|
||||
|
||||
Every frame that passes through _push_tts_frames — whether spoken or skipped —
|
||||
occupies a slot in the sequencer. Skipped frames wait at their position and are
|
||||
emitted downstream only after all preceding spoken slots are complete, preserving
|
||||
correct context ordering.
|
||||
"""
|
||||
|
||||
frame: AggregatedTextFrame
|
||||
context_id: str
|
||||
spoken: bool
|
||||
tracker: WordCompletionTracker | None = None
|
||||
transport_destination: str | None = None
|
||||
complete: bool = False
|
||||
|
||||
|
||||
class AggregatedFrameSequencer:
|
||||
"""Sequences AggregatedTextFrame slots to preserve TTS context ordering.
|
||||
|
||||
Manages an ordered queue of spoken and skipped TTS slots. Spoken slots are tracked
|
||||
via a :class:`WordCompletionTracker`; skipped slots (e.g. code blocks excluded from
|
||||
TTS synthesis) wait in-place until all preceding spoken slots are complete, then are
|
||||
flushed downstream with ``append_to_context=True``.
|
||||
|
||||
All methods are synchronous and return lists of frames the caller should push
|
||||
downstream, making the sequencer fully testable without any async machinery.
|
||||
|
||||
Example::
|
||||
|
||||
sequencer = AggregatedFrameSequencer()
|
||||
sequencer.register_spoken(frame, ctx_id, tracker, append_to_context=True)
|
||||
for f in sequencer.process_word("hello", pts=1000, context_id=ctx_id):
|
||||
await self.push_frame(f)
|
||||
"""
|
||||
|
||||
def __init__(self, name: str = "AggregatedFrameSequencer"):
|
||||
"""Initialize the sequencer.
|
||||
|
||||
Args:
|
||||
name: Label used in log messages (typically the owning TTS service name).
|
||||
"""
|
||||
self._name = name
|
||||
self._slots: list[_AggregatedFrameSlot] = []
|
||||
self._context_append_to_context: dict[str, bool] = {}
|
||||
|
||||
def register_spoken(
|
||||
self,
|
||||
frame: AggregatedTextFrame,
|
||||
context_id: str,
|
||||
tracker: WordCompletionTracker | None,
|
||||
append_to_context: bool,
|
||||
) -> None:
|
||||
"""Register a spoken AggregatedTextFrame slot.
|
||||
|
||||
Called from _push_tts_frames for frames sent to the TTS service. The slot is
|
||||
marked complete either via :meth:`process_word` (word-timestamp services) or
|
||||
:meth:`complete_spoken_slot` (push_text_frames=True services).
|
||||
|
||||
Args:
|
||||
frame: The AggregatedTextFrame being spoken.
|
||||
context_id: The TTS context ID assigned to this frame.
|
||||
tracker: WordCompletionTracker for word-timestamp services; None for
|
||||
push_text_frames=True services (they complete via complete_spoken_slot).
|
||||
append_to_context: Whether word frames built for this context should carry
|
||||
append_to_context=True.
|
||||
"""
|
||||
self._context_append_to_context[context_id] = append_to_context
|
||||
self._slots.append(
|
||||
_AggregatedFrameSlot(
|
||||
frame=frame,
|
||||
context_id=context_id,
|
||||
spoken=True,
|
||||
tracker=tracker,
|
||||
)
|
||||
)
|
||||
|
||||
def register_skipped(
|
||||
self,
|
||||
frame: AggregatedTextFrame,
|
||||
context_id: str,
|
||||
transport_destination: str | None,
|
||||
) -> list[Frame]:
|
||||
"""Register a skipped AggregatedTextFrame and attempt an immediate flush.
|
||||
|
||||
The frame is appended as a skipped slot. If no incomplete spoken slot precedes
|
||||
it, the frame is returned right away; otherwise it waits until a later
|
||||
:meth:`flush` unblocks it.
|
||||
|
||||
Args:
|
||||
frame: The skipped AggregatedTextFrame (e.g. a code block).
|
||||
context_id: The context ID assigned in _push_tts_frames.
|
||||
transport_destination: Transport routing value to attach at flush time.
|
||||
|
||||
Returns:
|
||||
Frames to push downstream (empty when blocked by a preceding spoken slot).
|
||||
"""
|
||||
frame.context_id = context_id
|
||||
self._slots.append(
|
||||
_AggregatedFrameSlot(
|
||||
frame=frame,
|
||||
context_id=context_id,
|
||||
spoken=False,
|
||||
transport_destination=transport_destination,
|
||||
)
|
||||
)
|
||||
return self.flush()
|
||||
|
||||
def process_word(
|
||||
self,
|
||||
word: str,
|
||||
pts: int,
|
||||
context_id: str | None,
|
||||
) -> list[Frame]:
|
||||
"""Process one word-timestamp event and return frames to push downstream.
|
||||
|
||||
Locates the active (first incomplete spoken) slot with a tracker, advances it
|
||||
by the incoming word, and builds a :class:`TTSTextFrame`. Handles:
|
||||
|
||||
- Normal words that fit entirely within the active slot.
|
||||
- Overflow words straddling two slot boundaries.
|
||||
- Force-complete when the TTS drops an event (word belongs to the next slot).
|
||||
- Passthrough for words not recognised by any slot.
|
||||
- Flushes any skipped slots unblocked by slot completion.
|
||||
|
||||
Args:
|
||||
word: A word token from the TTS service word-timestamp stream.
|
||||
pts: Presentation timestamp (nanoseconds) to assign to the frame.
|
||||
context_id: TTS context ID from the word-timestamp event.
|
||||
|
||||
Returns:
|
||||
Ordered list of frames (TTSTextFrame and/or AggregatedTextFrame) to push.
|
||||
"""
|
||||
active = self._get_active_slot()
|
||||
is_complete = False
|
||||
raw_overflow_word = None
|
||||
|
||||
if active and active.tracker:
|
||||
if not active.tracker.word_belongs_here(word):
|
||||
next_slot = self._get_next_active_slot(active)
|
||||
word_fits_next = (
|
||||
next_slot is not None
|
||||
and next_slot.tracker is not None
|
||||
and next_slot.tracker.word_belongs_here(word)
|
||||
)
|
||||
if not word_fits_next:
|
||||
logger.warning(
|
||||
f"{self._name} Word '{word}' not recognised by any slot, "
|
||||
"emitting as passthrough"
|
||||
)
|
||||
return [self._build_word_frame(word, pts, context_id)]
|
||||
|
||||
is_complete = active.tracker.add_word_and_check_complete(word)
|
||||
raw_overflow_word = active.tracker.get_overflow_word()
|
||||
|
||||
frame_text = (
|
||||
active.tracker.get_word_for_frame() if (active and active.tracker) else word
|
||||
) or word
|
||||
raw_text = active.tracker.get_llm_consumed() if (active and active.tracker) else None
|
||||
emit_context_id = active.context_id if active else context_id
|
||||
|
||||
# logger.debug(f"{self._name} Word '{word}' → frame_text='{frame_text}', raw='{raw_text}'")
|
||||
frames: list[Frame] = [
|
||||
self._build_word_frame(frame_text, pts, emit_context_id, raw_text=raw_text)
|
||||
]
|
||||
|
||||
if is_complete and active:
|
||||
active.complete = True
|
||||
frames.extend(self.flush(last_word_pts=pts))
|
||||
if raw_overflow_word:
|
||||
logger.debug(f"{self._name} Emitting overflow word '{raw_overflow_word}'")
|
||||
frames.extend(self._process_overflow(raw_overflow_word, pts))
|
||||
|
||||
return frames
|
||||
|
||||
def complete_spoken_slot(self) -> list[Frame]:
|
||||
"""Mark the first pending spoken slot complete and flush unblocked skipped frames.
|
||||
|
||||
Used by push_text_frames=True services: after the TTSTextFrame has been appended
|
||||
to the audio context, this marks the spoken slot done and releases any skipped
|
||||
frames waiting behind it.
|
||||
|
||||
Returns:
|
||||
AggregatedTextFrame(s) that are now unblocked and should be pushed.
|
||||
"""
|
||||
slot = next((s for s in self._slots if s.spoken and not s.complete), None)
|
||||
if slot:
|
||||
slot.complete = True
|
||||
return self.flush()
|
||||
|
||||
def flush(self, last_word_pts: int | None = None) -> list[Frame]:
|
||||
"""Walk the slot queue and return all skipped frames that are now unblocked.
|
||||
|
||||
Removes complete spoken slots from the head of the queue, then emits (and
|
||||
removes) skipped slots whose preceding spoken slots are all done. Stops at
|
||||
the first incomplete spoken slot.
|
||||
|
||||
Args:
|
||||
last_word_pts: When provided, skipped frames receive this PTS so they
|
||||
appear immediately after the last spoken word in the timeline.
|
||||
|
||||
Returns:
|
||||
AggregatedTextFrame(s) ready to be pushed downstream.
|
||||
"""
|
||||
frames: list[Frame] = []
|
||||
while self._slots:
|
||||
slot = self._slots[0]
|
||||
if slot.spoken and slot.complete:
|
||||
self._slots.pop(0)
|
||||
elif not slot.spoken and not slot.complete:
|
||||
slot.frame.append_to_context = True
|
||||
slot.frame.transport_destination = slot.transport_destination
|
||||
if last_word_pts:
|
||||
slot.frame.pts = last_word_pts
|
||||
logger.debug(f"{self._name}: Flushing Aggregated Frame {slot.frame}")
|
||||
frames.append(slot.frame)
|
||||
slot.complete = True
|
||||
self._slots.pop(0)
|
||||
else:
|
||||
break # spoken but not yet complete — wait
|
||||
return frames
|
||||
|
||||
def force_complete(self, last_word_pts: int) -> list[Frame]:
|
||||
"""Force-complete all incomplete spoken slots and flush skipped frames.
|
||||
|
||||
Called at the end of an audio context to handle TTS providers that silently drop
|
||||
word-timestamp events. Emits a TTSTextFrame for any remaining unspoken text in
|
||||
each incomplete slot, marks it complete, then flushes all now-unblocked skipped
|
||||
frames.
|
||||
|
||||
Args:
|
||||
last_word_pts: PTS of the last received word frame, used as the PTS for
|
||||
force-completed frames and forwarded to :meth:`flush`.
|
||||
|
||||
Returns:
|
||||
Combined list of TTSTextFrames (for incomplete spoken slots) and
|
||||
AggregatedTextFrames (skipped slots now unblocked), in emission order.
|
||||
"""
|
||||
frames: list[Frame] = []
|
||||
for slot in self._slots:
|
||||
if slot.spoken and not slot.complete:
|
||||
if slot.tracker:
|
||||
remaining_text = slot.tracker.get_remaining_tts_text()
|
||||
raw_remaining = slot.tracker.get_remaining_llm_text()
|
||||
if raw_remaining and remaining_text and remaining_text not in raw_remaining:
|
||||
logger.warning(
|
||||
f"{self._name} force-complete: raw_remaining {repr(raw_remaining)} "
|
||||
f"does not contain remaining_text {repr(remaining_text)}, discarding"
|
||||
)
|
||||
raw_remaining = None
|
||||
if remaining_text:
|
||||
logger.debug(
|
||||
f"{self._name} force-completing slot with remaining text "
|
||||
f"{repr(remaining_text)}"
|
||||
)
|
||||
frames.append(
|
||||
self._build_word_frame(
|
||||
remaining_text,
|
||||
last_word_pts,
|
||||
slot.context_id,
|
||||
raw_text=raw_remaining,
|
||||
)
|
||||
)
|
||||
slot.complete = True
|
||||
frames.extend(self.flush(last_word_pts=last_word_pts))
|
||||
return frames
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all slots and context metadata (called on interruption/reset)."""
|
||||
self._slots.clear()
|
||||
self._context_append_to_context.clear()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _get_active_slot(self) -> _AggregatedFrameSlot | None:
|
||||
"""Return the first incomplete spoken slot that has a tracker."""
|
||||
return next(
|
||||
(s for s in self._slots if s.spoken and not s.complete and s.tracker is not None),
|
||||
None,
|
||||
)
|
||||
|
||||
def _get_next_active_slot(self, current: _AggregatedFrameSlot) -> _AggregatedFrameSlot | None:
|
||||
"""Return the first incomplete spoken slot with a tracker after *current*."""
|
||||
found = False
|
||||
for s in self._slots:
|
||||
if s is current:
|
||||
found = True
|
||||
continue
|
||||
if found and s.spoken and not s.complete and s.tracker is not None:
|
||||
return s
|
||||
return None
|
||||
|
||||
def _build_word_frame(
|
||||
self,
|
||||
text: str,
|
||||
pts: int,
|
||||
context_id: str | None,
|
||||
raw_text: str | None = None,
|
||||
) -> Frame:
|
||||
"""Build a TTSTextFrame with all standard word-timestamp attributes set."""
|
||||
frame = TTSTextFrame(text, aggregated_by=AggregationType.WORD)
|
||||
frame.pts = pts
|
||||
frame.context_id = context_id
|
||||
frame.append_to_context = (
|
||||
self._context_append_to_context.get(context_id, True)
|
||||
if context_id is not None
|
||||
else True
|
||||
)
|
||||
frame.raw_text = raw_text
|
||||
return frame
|
||||
|
||||
def _process_overflow(self, raw_overflow_word: str, pts: int) -> list[Frame]:
|
||||
"""Feed an overflow suffix into the next active slot and return resulting frames."""
|
||||
frames: list[Frame] = []
|
||||
next_active = self._get_active_slot()
|
||||
if not next_active or not next_active.tracker:
|
||||
return frames
|
||||
overflow_complete = next_active.tracker.add_word_and_check_complete(raw_overflow_word)
|
||||
frames.append(
|
||||
self._build_word_frame(
|
||||
raw_overflow_word,
|
||||
pts,
|
||||
next_active.context_id,
|
||||
raw_text=next_active.tracker.get_llm_consumed(),
|
||||
)
|
||||
)
|
||||
if overflow_complete:
|
||||
next_active.complete = True
|
||||
frames.extend(self.flush(last_word_pts=pts))
|
||||
return frames
|
||||
490
src/pipecat/utils/context/word_completion_tracker.py
Normal file
490
src/pipecat/utils/context/word_completion_tracker.py
Normal file
@@ -0,0 +1,490 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Word completion tracker for TTS context ordering."""
|
||||
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class WordCompletionTracker:
|
||||
"""Tracks whether all words from a source AggregatedTextFrame have been spoken.
|
||||
|
||||
Compares normalized alphanumeric character counts between the TTS text and
|
||||
accumulated spoken words, making the check robust to punctuation, spacing,
|
||||
and XML/HTML tags (e.g. SSML tags like ``<spell>...</spell>`` returned by some
|
||||
TTS providers in word-timestamp events).
|
||||
|
||||
When ``llm_text`` is provided (e.g. the original pattern-matched text including
|
||||
delimiters like ``<card>4111 1111 1111 1111</card>``), the tracker additionally
|
||||
maps each spoken word back to its corresponding span in that LLM text. This
|
||||
lets callers attach the original text to ``TTSTextFrame`` entries so the
|
||||
conversation context receives properly-tagged content rather than the cleaned
|
||||
words received from the TTS provider.
|
||||
|
||||
Background: TTS providers apply their own SSML tags to the text before
|
||||
synthesis and return word-timestamp events containing the raw spoken words
|
||||
(e.g. ``"4111"``, ``"1111"``). Without LLM-text tracking, the conversation
|
||||
context would only see those cleaned words and lose the original structure
|
||||
(e.g. ``<card>4111 1111 1111 1111</card>``). By mapping normalized char counts
|
||||
back to positions in ``llm_text``, each TTSTextFrame can carry the exact span
|
||||
of original text it represents.
|
||||
|
||||
Overflow handling: TTS providers sometimes return a single word token that
|
||||
spans the boundary between two AggregatedTextFrames (e.g. ``"1111</spell>And"``
|
||||
when one frame ends with ``1111</card>`` and the next begins with ``And``). The
|
||||
tracker detects this and exposes the raw overflow suffix via ``get_overflow_word()``,
|
||||
so callers can feed the remainder into the next frame's tracker and emit a
|
||||
correctly-attributed TTSTextFrame for each part.
|
||||
|
||||
Example::
|
||||
|
||||
tracker = WordCompletionTracker("Hello, world!")
|
||||
tracker.add_word_and_check_complete("Hello") # False
|
||||
tracker.add_word_and_check_complete("world") # True — normalized "helloworld" >= "helloworld"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tts_text: str,
|
||||
llm_text: str | None = None,
|
||||
):
|
||||
"""Initialize the tracker with the text of the frame being spoken.
|
||||
|
||||
Args:
|
||||
tts_text: Full text of the AggregatedTextFrame sent to TTS (may include
|
||||
TTS-specific SSML tags). Used for normalized char-count completion
|
||||
tracking and as the cursor reference for the TTS word stream.
|
||||
llm_text: Original LLM-produced text including pattern delimiters (e.g.
|
||||
``<card>4111 1111 1111 1111</card>``). When provided, each
|
||||
``add_word_and_check_complete`` call also returns the corresponding
|
||||
LLM span via ``get_llm_consumed()``. Both texts normalize to the
|
||||
same alphanumeric sequence, so the same char-count cursor drives
|
||||
position tracking in both.
|
||||
"""
|
||||
self._tts_normalized = self._normalize(tts_text)
|
||||
self._received = ""
|
||||
|
||||
# _tts_text is the original tts_text before normalization.
|
||||
# _tts_pos is a cursor into it, advanced by the same alnum count
|
||||
# as the TTS word stream, so the force-complete path can emit the remaining
|
||||
# unspoken text as a TTSTextFrame instead of silently dropping it.
|
||||
self._tts_text = tts_text
|
||||
self._tts_pos = 0
|
||||
|
||||
# _llm_text is the original LLM-produced text (with pattern delimiters like
|
||||
# <card>...</card>). We track _llm_pos as a cursor into it, advancing
|
||||
# by the same number of alphanumeric chars consumed from the TTS word stream.
|
||||
self._llm_text = llm_text
|
||||
self._llm_pos = 0
|
||||
|
||||
self._overflow_word: str | None = None
|
||||
self._llm_consumed: str | None = None
|
||||
self._frame_word: str | None = None
|
||||
logger.debug(f"WordCompletionTracker: {self._tts_normalized}")
|
||||
|
||||
@staticmethod
|
||||
def _normalize(text: str) -> str:
|
||||
"""Strip XML/HTML tags then keep only lowercase alphanumeric characters.
|
||||
|
||||
Accented letters (e.g. ã, é) are reduced to their base letter so TTS output
|
||||
can be matched against LLM text even when the provider strips diacritics.
|
||||
Non-Latin scripts (CJK, Hangul) are kept as-is — each original character
|
||||
contributes exactly one char to the result, keeping normalized length in sync
|
||||
with raw alnum counts used by _advance_by_alnums.
|
||||
"""
|
||||
text = re.sub(r"<[^>]+>", "", text)
|
||||
result = []
|
||||
for char in text:
|
||||
# Ignore punctuation, spaces, emojis, etc.
|
||||
# Keep only letters and numbers.
|
||||
if not char.isalnum():
|
||||
continue
|
||||
# NFD decomposes accented characters into:
|
||||
# é -> e + ◌́
|
||||
# ã -> a + ◌̃
|
||||
#
|
||||
# Non-accented characters usually stay unchanged.
|
||||
nfd = unicodedata.normalize("NFD", char)
|
||||
# Unicode category "Mn" means:
|
||||
# Mark, Nonspacing
|
||||
#
|
||||
# These are combining accent marks that modify
|
||||
# the previous character but are not standalone.
|
||||
#
|
||||
# Example:
|
||||
# "é" becomes:
|
||||
# nfd[0] = "e"
|
||||
# nfd[1] = "◌́" (category = "Mn")
|
||||
#
|
||||
# If the second character is a combining accent,
|
||||
# keep only the base letter.
|
||||
if len(nfd) >= 2 and unicodedata.category(nfd[1]) == "Mn":
|
||||
# Accented letter: keep the base character only (drops the combining mark).
|
||||
result.append(nfd[0].lower())
|
||||
else:
|
||||
# Regular ASCII, numbers, CJK, Hangul, etc.
|
||||
# are kept unchanged (except lowercase conversion).
|
||||
result.append(char.lower())
|
||||
return "".join(result)
|
||||
|
||||
# Typographic variants that LLMs commonly emit but TTS services normalize away.
|
||||
_TYPOGRAPHY_FOLD = str.maketrans(
|
||||
{
|
||||
"‘": "'", # ' LEFT SINGLE QUOTATION MARK
|
||||
"’": "'", # ' RIGHT SINGLE QUOTATION MARK
|
||||
"ʼ": "'", # ʼ MODIFIER LETTER APOSTROPHE
|
||||
"“": '"', # " LEFT DOUBLE QUOTATION MARK
|
||||
"”": '"', # " RIGHT DOUBLE QUOTATION MARK
|
||||
"–": "-", # – EN DASH
|
||||
"—": "-", # — EM DASH
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _fold_typography(text: str) -> str:
|
||||
"""Replace typographic punctuation variants with their ASCII equivalents."""
|
||||
return text.translate(WordCompletionTracker._TYPOGRAPHY_FOLD)
|
||||
|
||||
@staticmethod
|
||||
def _remove_trailing_punctuation(text: str) -> str:
|
||||
"""Remove punctuation only at the very end of the given text."""
|
||||
i = len(text)
|
||||
while i > 0 and unicodedata.category(text[i - 1]).startswith("P"):
|
||||
i -= 1
|
||||
return text[:i]
|
||||
|
||||
@staticmethod
|
||||
def _advance_by_alnums(text: str, start_pos: int, n: int) -> int:
|
||||
"""Return the position in *text* after advancing past *n* alphanumeric chars.
|
||||
|
||||
Moves through the text one character at a time, counting only alphanumeric
|
||||
characters. XML/HTML tags (``<...>``) are skipped entirely — their content
|
||||
is not counted against the budget, so the returned span includes the full tag.
|
||||
Other non-alphanumeric characters (spaces, punctuation) are also passed over
|
||||
without decrementing the budget.
|
||||
|
||||
After the *n* alnum chars are consumed, advances further past any immediately
|
||||
following punctuation (e.g. the ``,`` in ``"questions,"`` or the ``.`` in
|
||||
``"done."``), stopping before the next space, alnum char, or XML tag.
|
||||
|
||||
Args:
|
||||
text: The source text to scan.
|
||||
start_pos: Starting position in *text*.
|
||||
n: Number of alphanumeric characters to consume.
|
||||
"""
|
||||
pos = start_pos
|
||||
count = 0
|
||||
while pos < len(text) and count < n:
|
||||
if text[pos] == "<":
|
||||
end = text.find(">", pos)
|
||||
pos = end + 1 if end != -1 else pos + 1
|
||||
elif text[pos].isalnum():
|
||||
count += 1
|
||||
pos += 1
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
while pos < len(text):
|
||||
if text[pos] == "<":
|
||||
break
|
||||
if text[pos].isalnum() or text[pos].isspace():
|
||||
break
|
||||
pos += 1
|
||||
|
||||
return pos
|
||||
|
||||
def add_word_and_check_complete(self, word: str) -> bool:
|
||||
"""Record a spoken word from a word-timestamp event.
|
||||
|
||||
Normalizes ``word``, appends it to the running total, and checks whether
|
||||
all expected alphanumeric characters have been covered.
|
||||
|
||||
Before advancing, checks whether the word belongs to this frame via
|
||||
``word_belongs_here``. If it does not (e.g. the TTS provider silently
|
||||
dropped a word-timestamp), the slot is force-completed: the remaining
|
||||
unspoken text from ``tts_text`` is stored in ``_frame_word`` so a
|
||||
TTSTextFrame can still be emitted for the dropped portion, all remaining
|
||||
``llm_text`` is consumed, and the entire incoming word is set as overflow
|
||||
so the caller's overflow path routes it to the next slot unchanged.
|
||||
|
||||
If ``llm_text`` was provided at construction time, also advances the LLM
|
||||
cursor by the same number of alphanumeric chars consumed from this word and
|
||||
stores the corresponding LLM span in ``_llm_consumed``. When this word
|
||||
completes the frame, the entire remaining LLM text (including any closing
|
||||
tags) is consumed so nothing is lost.
|
||||
|
||||
If the word overshoots the expected length (overflow), the raw suffix of
|
||||
the word (everything after the last char belonging to this frame) is stored
|
||||
in ``_overflow_word``, so the caller can attribute it to the next
|
||||
AggregatedTextFrame.
|
||||
|
||||
Args:
|
||||
word: A single word token returned by the TTS service. TTS services that
|
||||
emit spaces and punctuation as separate tokens (e.g. Inworld) must
|
||||
pre-merge those tokens into the preceding word before calling this
|
||||
method (see ``TTSService._merge_punct_tokens``).
|
||||
|
||||
Returns:
|
||||
True when all expected content has been covered.
|
||||
"""
|
||||
normalized = self._normalize(word)
|
||||
|
||||
prev_len = len(self._received)
|
||||
expected_len = len(self._tts_normalized)
|
||||
|
||||
self._overflow_word = None
|
||||
self._llm_consumed = None
|
||||
self._frame_word = None
|
||||
|
||||
if prev_len > expected_len:
|
||||
logger.warning(f"{self}, trying to add a word in an already complete frame")
|
||||
return True
|
||||
|
||||
# If the word doesn't match the next expected chars, the TTS provider
|
||||
# likely dropped a word-timestamp event. Force-complete this slot: emit the
|
||||
# remaining TTS text as _frame_word so a TTSTextFrame is still produced
|
||||
# for the unspoken portion, consume all remaining llm_text, and route the
|
||||
# entire incoming word as overflow for the next slot.
|
||||
if not self.word_belongs_here(word):
|
||||
self._frame_word = self._tts_text[self._tts_pos :]
|
||||
if self._llm_text is not None:
|
||||
self._llm_consumed = self._llm_text[self._llm_pos :]
|
||||
self._llm_pos = len(self._llm_text)
|
||||
# This should not happen: force-complete sweeps all remaining
|
||||
# llm_text, so the span must contain the frame word. If it
|
||||
# doesn't, tts_text and llm_text are out of sync in an
|
||||
# unexpected way — discard rather than returning a corrupt span.
|
||||
# Also removing punctuation from the frame word to match the
|
||||
# expected text, since some TTS services may add punctuation to
|
||||
# the raw text.
|
||||
word_without_punctuation = self._remove_trailing_punctuation(self._frame_word)
|
||||
if word_without_punctuation and word_without_punctuation not in self._llm_consumed:
|
||||
logger.warning(
|
||||
f"WordCompletionTracker: force-complete llm_consumed {repr(self._llm_consumed)!s} "
|
||||
f"does not contain frame_word {repr(self._frame_word)!s}, discarding"
|
||||
)
|
||||
self._llm_consumed = None
|
||||
self._received = self._tts_normalized # force-complete
|
||||
self._overflow_word = word
|
||||
return True
|
||||
|
||||
self._received += normalized
|
||||
|
||||
# How many normalized chars from this word belong to the current frame.
|
||||
chars_for_frame = min(len(normalized), expected_len - prev_len)
|
||||
|
||||
if prev_len + len(normalized) > expected_len:
|
||||
# This word straddles the frame boundary. Split into:
|
||||
# - _frame_word: the prefix of `word` up to the split point, used
|
||||
# for the TTSTextFrame of the current slot.
|
||||
# - raw overflow word: the raw suffix after the split point, used
|
||||
# to build a TTSTextFrame attributed to the next AggregatedTextFrame.
|
||||
split_pos = self._advance_by_alnums(word, 0, chars_for_frame)
|
||||
self._frame_word = word[:split_pos]
|
||||
self._overflow_word = word[split_pos:]
|
||||
else:
|
||||
# Word fits entirely in this frame.
|
||||
self._frame_word = word
|
||||
|
||||
# Advance the TTS cursor by the same alnum count so the force-complete
|
||||
# path knows where in _tts_text to start from.
|
||||
self._tts_pos = self._advance_by_alnums(self._tts_text, self._tts_pos, chars_for_frame)
|
||||
|
||||
if self._llm_text is not None:
|
||||
if self.is_complete:
|
||||
# Consume ALL remaining LLM text: closing tags (e.g. </card>)
|
||||
# and any trailing punctuation that the TTS will not send separately.
|
||||
self._llm_consumed = self._llm_text[self._llm_pos :]
|
||||
self._llm_pos = len(self._llm_text)
|
||||
else:
|
||||
if chars_for_frame == 0:
|
||||
# Consume exactly the raw word in llm_text, skipping any
|
||||
# leading spaces that belong to the previous token's span.
|
||||
start = self._llm_pos
|
||||
while start < len(self._llm_text) and self._llm_text[start].isspace():
|
||||
start += 1
|
||||
end = start + len(word)
|
||||
self._llm_consumed = self._llm_text[start:end]
|
||||
self._llm_pos = end
|
||||
else:
|
||||
# Advance through llm_text by exactly chars_for_frame alphanumeric
|
||||
# chars. Non-alnum chars (spaces, opening tags) are included in the
|
||||
# slice, preserving the original formatting for the context.
|
||||
new_pos = self._advance_by_alnums(
|
||||
self._llm_text, self._llm_pos, chars_for_frame
|
||||
)
|
||||
self._llm_consumed = self._llm_text[self._llm_pos : new_pos]
|
||||
self._llm_pos = new_pos
|
||||
# This should not happen: the LLM cursor is driven by the same
|
||||
# alnum count as the word stream, so the consumed span must contain
|
||||
# the frame word. If it doesn't, the cursors drifted out of sync
|
||||
# in an unexpected way — discard rather than returning a corrupt span.
|
||||
# Also removing punctuation from the frame word to match the
|
||||
# expected text, since some TTS services may add punctuation to
|
||||
# the raw text.
|
||||
word_without_punctuation = self._remove_trailing_punctuation(self._frame_word)
|
||||
if word_without_punctuation and self._fold_typography(
|
||||
word_without_punctuation
|
||||
) not in self._fold_typography(self._llm_consumed):
|
||||
logger.warning(
|
||||
f"WordCompletionTracker: llm_consumed {repr(self._llm_consumed)!s} "
|
||||
f"does not contain frame_word {repr(self._frame_word)!s}, discarding"
|
||||
)
|
||||
self._llm_consumed = None
|
||||
|
||||
return self.is_complete
|
||||
|
||||
def word_belongs_here(self, word: str) -> bool:
|
||||
"""Return True if this word plausibly belongs to the remaining TTS text.
|
||||
|
||||
Dispatches to one of two checks depending on whether the word contains
|
||||
any alphanumeric characters after normalization:
|
||||
|
||||
- Alnum words: prefix-match against the remaining expected chars.
|
||||
- Symbol/punctuation words (empty after normalization): literal substring
|
||||
search in the remaining raw TTS text, with a fallback for TTS providers
|
||||
that substitute Unicode symbols with ASCII punctuation.
|
||||
|
||||
Used to detect when the TTS provider silently dropped a word-timestamp
|
||||
event: if the incoming word does not match this slot's remaining content,
|
||||
the caller should force-complete this slot and route the word to the next.
|
||||
"""
|
||||
normalized = self._normalize(word)
|
||||
if normalized:
|
||||
return self._alnum_word_belongs_here(normalized)
|
||||
else:
|
||||
return self._symbol_word_belongs_here(word)
|
||||
|
||||
def _alnum_word_belongs_here(self, normalized: str) -> bool:
|
||||
"""Return True if an alnum-containing word matches this frame's remaining expected chars.
|
||||
|
||||
Accepts both full words and partial tokens — the word belongs here as long
|
||||
as its normalized characters are a prefix of what is still expected. This
|
||||
also handles the overflow case where the word is longer than the remaining
|
||||
content (the excess is detected and split in ``add_word_and_check_complete``).
|
||||
"""
|
||||
remaining = self._tts_normalized[len(self._received) :]
|
||||
if not remaining:
|
||||
return False
|
||||
check_len = min(len(normalized), len(remaining))
|
||||
return remaining.startswith(normalized[:check_len])
|
||||
|
||||
def _symbol_word_belongs_here(self, word: str) -> bool:
|
||||
"""Return True if a non-alnum word (emoji, punctuation, symbol) belongs to this frame.
|
||||
|
||||
Two checks are applied in order:
|
||||
|
||||
1. **Literal substring**: search for the raw word in the remaining TTS text.
|
||||
``_advance_by_alnums`` may have already moved ``_tts_pos`` past some trailing
|
||||
punctuation, so the search window is backed up to include those characters.
|
||||
|
||||
2. **Symbol substitution fallback**: some TTS providers substitute Unicode symbols
|
||||
with ASCII punctuation in word-timestamp events (e.g. ElevenLabs reports ``→``
|
||||
as ``-``), so check 1 always fails even though the word belongs here. If alnum
|
||||
content still remains unconsumed and the next non-space character in the TTS
|
||||
text is itself a non-alnum symbol, accept the word as a substitution.
|
||||
"""
|
||||
search_start = self._tts_pos
|
||||
while search_start > 0:
|
||||
ch = self._tts_text[search_start - 1]
|
||||
if ch.isalnum() or ch.isspace() or ch == ">":
|
||||
break
|
||||
search_start -= 1
|
||||
if word in self._tts_text[search_start:]:
|
||||
return True
|
||||
|
||||
if len(self._received) >= len(self._tts_normalized):
|
||||
return False
|
||||
|
||||
pos = self._tts_pos
|
||||
while pos < len(self._tts_text) and self._tts_text[pos].isspace():
|
||||
pos += 1
|
||||
return pos < len(self._tts_text) and not self._tts_text[pos].isalnum()
|
||||
|
||||
def get_word_for_frame(self) -> str | None:
|
||||
"""Return the portion of the last word that belongs to this frame.
|
||||
|
||||
- Normal word (no overflow): the full word.
|
||||
- Straddling word: the prefix up to the frame boundary (e.g. ``"1111"``
|
||||
from ``"1111 And"``).
|
||||
- Force-completed (word didn't belong): the remaining unspoken text from
|
||||
``tts_text`` so a TTSTextFrame can still be emitted for the dropped
|
||||
portion. The incoming word is routed as overflow to the next slot.
|
||||
"""
|
||||
return self._frame_word.strip() if self._frame_word else self._frame_word
|
||||
|
||||
def get_overflow_word(self) -> str | None:
|
||||
"""Return the raw suffix of the last word that overflows into the next frame.
|
||||
|
||||
Preserves the original casing and any non-alphanumeric characters so the
|
||||
overflow TTSTextFrame has natural word text. Returns None when there is no
|
||||
overflow (the word fit entirely within this frame).
|
||||
"""
|
||||
return self._overflow_word.strip() if self._overflow_word else self._overflow_word
|
||||
|
||||
def get_llm_consumed(self) -> str | None:
|
||||
"""Return the LLM text span consumed for the last added word.
|
||||
|
||||
Returns None if no llm_text was provided at construction time.
|
||||
"""
|
||||
return self._llm_consumed.strip() if self._llm_consumed else self._llm_consumed
|
||||
|
||||
def get_accumulated_tts_text(self) -> str:
|
||||
"""Return all consumed text from tts_text up to the current cursor position.
|
||||
|
||||
Unlike ``get_word_for_frame()`` (which reflects only the last word), this returns
|
||||
everything that has been consumed since construction or the last ``reset()``.
|
||||
"""
|
||||
return self._tts_text[: self._tts_pos]
|
||||
|
||||
def get_accumulated_llm_text(self) -> str | None:
|
||||
"""Return all consumed text from llm_text up to the current cursor position.
|
||||
|
||||
Unlike ``get_llm_consumed()`` (which reflects only the last word), this returns
|
||||
everything that has been consumed since construction or the last ``reset()``.
|
||||
Returns None if no llm_text was provided at construction time.
|
||||
"""
|
||||
if self._llm_text is None:
|
||||
return None
|
||||
return self._llm_text[: self._llm_pos]
|
||||
|
||||
def get_remaining_tts_text(self) -> str:
|
||||
"""Return the unspoken portion of tts_text, stripped of leading/trailing whitespace.
|
||||
|
||||
This is the text that the TTS provider has not yet confirmed via word-timestamp
|
||||
events. Useful for force-completing a slot when the audio context ends before all
|
||||
word-timestamp events have arrived.
|
||||
"""
|
||||
return self._tts_text[self._tts_pos :].strip()
|
||||
|
||||
def get_remaining_llm_text(self) -> str | None:
|
||||
"""Return the unspoken portion of llm_text, stripped of leading/trailing whitespace.
|
||||
|
||||
Returns None if no llm_text was provided at construction time. Like
|
||||
``get_remaining_tts_text()``, intended for force-completing a slot so that the
|
||||
conversation context receives the full original text.
|
||||
"""
|
||||
if self._llm_text is None:
|
||||
return None
|
||||
remaining = self._llm_text[self._llm_pos :].strip()
|
||||
return remaining if remaining else None
|
||||
|
||||
@property
|
||||
def is_complete(self) -> bool:
|
||||
"""True when accumulated normalized chars >= expected normalized chars."""
|
||||
return len(self._received) >= len(self._tts_normalized)
|
||||
|
||||
def reset(self):
|
||||
"""Reset received word accumulation without changing the expected text."""
|
||||
self._received = ""
|
||||
self._tts_pos = 0
|
||||
self._llm_pos = 0
|
||||
self._overflow_word = None
|
||||
self._llm_consumed = None
|
||||
self._frame_word = None
|
||||
53
src/pipecat/utils/text/word_timestamp_utils.py
Normal file
53
src/pipecat/utils/text/word_timestamp_utils.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Utilities for normalizing word-timestamp streams from TTS services."""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def merge_punct_tokens(
|
||||
word_times: list[tuple[str, float]],
|
||||
) -> list[tuple[str, float]]:
|
||||
"""Merge punctuation/space-only tokens into the preceding word.
|
||||
|
||||
Some TTS services (e.g. Inworld) emit spaces and punctuation as separate
|
||||
word-timestamp tokens rather than attaching them to the adjacent word.
|
||||
This function collapses those tokens so downstream consumers always receive
|
||||
words with trailing punctuation already attached — identical to the format
|
||||
produced by ElevenLabs or Cartesia.
|
||||
|
||||
A token is considered punct/space-only when its text contains no alphanumeric
|
||||
characters after stripping XML/HTML tags. Such tokens are appended to the
|
||||
preceding word's text and their timestamp is discarded (the preceding word's
|
||||
timestamp is kept). Leading punct/space tokens with no preceding word are
|
||||
silently discarded. Every output token is stripped of leading and trailing
|
||||
whitespace (spaces, tabs, newlines).
|
||||
|
||||
Args:
|
||||
word_times: Raw list of ``(word, timestamp)`` pairs from the TTS service.
|
||||
|
||||
Returns:
|
||||
Merged list where every entry contains at least one alphanumeric character
|
||||
and has no leading or trailing whitespace.
|
||||
|
||||
Example::
|
||||
|
||||
merge_punct_tokens([("questions", 1.0), (", ", 1.2), ("explain", 1.4)])
|
||||
# → [("questions,", 1.0), ("explain", 1.4)]
|
||||
"""
|
||||
merged: list[tuple[str, float]] = []
|
||||
for word, ts in word_times:
|
||||
stripped = re.sub(r"<[^>]+>", "", word)
|
||||
has_alnum = any(c.isalnum() for c in stripped)
|
||||
if not has_alnum:
|
||||
if merged:
|
||||
prev_word, prev_ts = merged[-1]
|
||||
merged[-1] = (prev_word + word, prev_ts)
|
||||
# else: leading punct/space with no preceding word → discard
|
||||
else:
|
||||
merged.append((word, ts))
|
||||
return [(word.strip(), ts) for word, ts in merged]
|
||||
612
tests/test_aggregated_frame_sequencer.py
Normal file
612
tests/test_aggregated_frame_sequencer.py
Normal file
@@ -0,0 +1,612 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Tests for AggregatedFrameSequencer.
|
||||
|
||||
All methods on the sequencer are synchronous and return lists of frames,
|
||||
so no async machinery is needed here.
|
||||
|
||||
Test groups:
|
||||
- register_skipped: immediate flush vs. blocked by a preceding spoken slot
|
||||
- register_spoken / complete_spoken_slot: push_text_frames=True path
|
||||
- flush: pts propagation, transport_destination, stops at incomplete spoken slot
|
||||
- process_word: normal, completing, passthrough, raw_text propagation
|
||||
- process_word overflow: single token spanning two slot boundaries
|
||||
- process_word force-complete via belongs_here failure
|
||||
- force_complete: remaining text emission, raw_text, corrupt raw discard, slot ordering
|
||||
- clear: resets all state
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import AggregatedTextFrame, AggregationType, TTSTextFrame
|
||||
from pipecat.utils.context.aggregated_frame_sequencer import AggregatedFrameSequencer
|
||||
from pipecat.utils.context.word_completion_tracker import WordCompletionTracker
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _seq() -> AggregatedFrameSequencer:
|
||||
return AggregatedFrameSequencer(name="test")
|
||||
|
||||
|
||||
def _spoken_frame(text: str) -> AggregatedTextFrame:
|
||||
return AggregatedTextFrame(text, AggregationType.SENTENCE)
|
||||
|
||||
|
||||
def _skipped_frame(text: str) -> AggregatedTextFrame:
|
||||
return AggregatedTextFrame(text, "code")
|
||||
|
||||
|
||||
def _tracker(tts_text: str, llm_text: str | None = None) -> WordCompletionTracker:
|
||||
return WordCompletionTracker(tts_text, llm_text=llm_text)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# register_skipped
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegisterSkipped(unittest.TestCase):
|
||||
def test_emits_immediately_with_empty_queue(self):
|
||||
seq = _seq()
|
||||
frame = _skipped_frame("code block")
|
||||
result = seq.register_skipped(frame, "ctx1", None)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIs(result[0], frame)
|
||||
|
||||
def test_sets_append_to_context_true(self):
|
||||
seq = _seq()
|
||||
frame = _skipped_frame("code")
|
||||
seq.register_skipped(frame, "ctx1", None)
|
||||
self.assertTrue(frame.append_to_context)
|
||||
|
||||
def test_sets_context_id_on_frame(self):
|
||||
seq = _seq()
|
||||
frame = _skipped_frame("code")
|
||||
seq.register_skipped(frame, "ctx42", None)
|
||||
self.assertEqual(frame.context_id, "ctx42")
|
||||
|
||||
def test_sets_transport_destination(self):
|
||||
seq = _seq()
|
||||
frame = _skipped_frame("code")
|
||||
result = seq.register_skipped(frame, "ctx1", "dest-A")
|
||||
self.assertEqual(result[0].transport_destination, "dest-A")
|
||||
|
||||
def test_blocked_by_incomplete_spoken_slot(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True)
|
||||
result = seq.register_skipped(_skipped_frame("code"), "ctx2", None)
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def test_emits_immediately_after_already_complete_spoken_slot(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hi"), "ctx1", tracker=None, append_to_context=True)
|
||||
seq.complete_spoken_slot()
|
||||
result = seq.register_skipped(_skipped_frame("code"), "ctx2", None)
|
||||
self.assertEqual(len(result), 1)
|
||||
|
||||
def test_multiple_skipped_before_any_spoken_all_emit(self):
|
||||
seq = _seq()
|
||||
r1 = seq.register_skipped(_skipped_frame("code1"), "ctx1", None)
|
||||
r2 = seq.register_skipped(_skipped_frame("code2"), "ctx2", None)
|
||||
self.assertEqual(len(r1), 1)
|
||||
self.assertEqual(len(r2), 1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# register_spoken / complete_spoken_slot (push_text_frames=True path)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompleteSpokenSlot(unittest.TestCase):
|
||||
def test_noop_with_empty_queue(self):
|
||||
seq = _seq()
|
||||
self.assertEqual(seq.complete_spoken_slot(), [])
|
||||
|
||||
def test_marks_slot_complete_and_flushes_skipped(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx2", None) # blocked
|
||||
|
||||
result = seq.complete_spoken_slot()
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIs(result[0], skipped)
|
||||
self.assertTrue(skipped.append_to_context)
|
||||
|
||||
def test_only_first_pending_slot_is_marked(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("one"), "ctx1", tracker=None, append_to_context=True)
|
||||
seq.register_spoken(_spoken_frame("two"), "ctx2", tracker=None, append_to_context=True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx3", None)
|
||||
|
||||
# ctx2 still blocks the skipped frame
|
||||
result = seq.complete_spoken_slot()
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def test_skipped_flushes_after_all_preceding_spoken_complete(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("one"), "ctx1", tracker=None, append_to_context=True)
|
||||
seq.register_spoken(_spoken_frame("two"), "ctx2", tracker=None, append_to_context=True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx3", None)
|
||||
|
||||
seq.complete_spoken_slot() # completes ctx1
|
||||
result = seq.complete_spoken_slot() # completes ctx2 → flush skipped
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIs(result[0], skipped)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# flush
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlush(unittest.TestCase):
|
||||
def test_empty_queue_returns_empty(self):
|
||||
self.assertEqual(_seq().flush(), [])
|
||||
|
||||
def test_stops_at_incomplete_spoken_slot(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True)
|
||||
seq.register_skipped(_skipped_frame("code"), "ctx2", None)
|
||||
self.assertEqual(seq.flush(), [])
|
||||
|
||||
def test_last_word_pts_assigned_to_skipped_frame(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
|
||||
# process_word("hello") completes the spoken slot and calls flush(last_word_pts=77)
|
||||
result = seq.process_word("hello", pts=77, context_id="ctx1")
|
||||
flushed = [f for f in result if isinstance(f, AggregatedTextFrame) and f.text == "code"]
|
||||
self.assertEqual(len(flushed), 1)
|
||||
self.assertEqual(flushed[0].pts, 77)
|
||||
|
||||
def test_complete_spoken_slots_are_swept(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True)
|
||||
seq.complete_spoken_slot()
|
||||
# Queue should be empty after sweeping the complete spoken slot
|
||||
self.assertEqual(seq._slots, [])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_word — basic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessWordBasic(unittest.TestCase):
|
||||
def _seq_with_spoken(self, text: str, ctx: str = "ctx1", append: bool = True):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame(text), ctx, _tracker(text), append)
|
||||
return seq
|
||||
|
||||
def test_returns_tts_text_frame(self):
|
||||
seq = self._seq_with_spoken("hello")
|
||||
result = seq.process_word("hello", pts=100, context_id="ctx1")
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsInstance(result[0], TTSTextFrame)
|
||||
|
||||
def test_frame_text_and_pts(self):
|
||||
seq = self._seq_with_spoken("hello")
|
||||
result = seq.process_word("hello", pts=100, context_id="ctx1")
|
||||
self.assertEqual(result[0].text, "hello")
|
||||
self.assertEqual(result[0].pts, 100)
|
||||
|
||||
def test_frame_context_id(self):
|
||||
seq = self._seq_with_spoken("hello", ctx="ctx99")
|
||||
result = seq.process_word("hello", pts=1, context_id="ctx99")
|
||||
self.assertEqual(result[0].context_id, "ctx99")
|
||||
|
||||
def test_append_to_context_true(self):
|
||||
seq = self._seq_with_spoken("hello", append=True)
|
||||
result = seq.process_word("hello", pts=1, context_id="ctx1")
|
||||
self.assertTrue(result[0].append_to_context)
|
||||
|
||||
def test_append_to_context_false(self):
|
||||
seq = self._seq_with_spoken("hello", append=False)
|
||||
result = seq.process_word("hello", pts=1, context_id="ctx1")
|
||||
self.assertFalse(result[0].append_to_context)
|
||||
|
||||
def test_non_completing_word_does_not_flush_skipped(self):
|
||||
seq = self._seq_with_spoken("hello world")
|
||||
seq.register_skipped(_skipped_frame("code"), "ctx2", None)
|
||||
result = seq.process_word("hello", pts=10, context_id="ctx1")
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsInstance(result[0], TTSTextFrame)
|
||||
|
||||
def test_completing_word_flushes_blocked_skipped_frame(self):
|
||||
seq = self._seq_with_spoken("hello")
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
result = seq.process_word("hello", pts=50, context_id="ctx1")
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertIsInstance(result[0], TTSTextFrame)
|
||||
self.assertIs(result[1], skipped)
|
||||
|
||||
def test_last_of_multiple_words_flushes_skipped(self):
|
||||
seq = self._seq_with_spoken("hello world")
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
seq.process_word("hello", pts=10, context_id="ctx1")
|
||||
result = seq.process_word("world", pts=20, context_id="ctx1")
|
||||
self.assertTrue(any(f is skipped for f in result))
|
||||
|
||||
def test_no_active_slot_emits_passthrough(self):
|
||||
seq = _seq()
|
||||
result = seq.process_word("hello", pts=1, context_id="ctx-unknown")
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsInstance(result[0], TTSTextFrame)
|
||||
self.assertEqual(result[0].text, "hello")
|
||||
self.assertEqual(result[0].context_id, "ctx-unknown")
|
||||
|
||||
def test_passthrough_uses_default_append_to_context_true(self):
|
||||
seq = _seq()
|
||||
result = seq.process_word("hello", pts=1, context_id="ctx-unknown")
|
||||
self.assertTrue(result[0].append_to_context)
|
||||
|
||||
def test_unrecognised_word_emits_passthrough(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True)
|
||||
# "zzz" doesn't belong to "hello world" and there is no next slot
|
||||
result = seq.process_word("zzz", pts=5, context_id="ctx1")
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].text, "zzz")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_word — raw_text propagation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessWordRawText(unittest.TestCase):
|
||||
def test_raw_text_split_across_word_frames(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(
|
||||
_spoken_frame("4111 1111"),
|
||||
"ctx1",
|
||||
WordCompletionTracker("4111 1111", llm_text="<card>4111 1111</card>"),
|
||||
append_to_context=True,
|
||||
)
|
||||
r1 = seq.process_word("4111", pts=10, context_id="ctx1")
|
||||
r2 = seq.process_word("1111", pts=20, context_id="ctx1")
|
||||
self.assertEqual(r1[0].raw_text, "<card>4111")
|
||||
last_word_frames = [f for f in r2 if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(last_word_frames[0].raw_text, "1111</card>")
|
||||
|
||||
def test_raw_text_none_when_no_llm_text(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
result = seq.process_word("hello", pts=1, context_id="ctx1")
|
||||
self.assertIsNone(result[0].raw_text)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_word — overflow (single token spanning two slots)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessWordOverflow(unittest.TestCase):
|
||||
def test_overflow_produces_two_tts_text_frames(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True)
|
||||
seq.register_spoken(_spoken_frame("def"), "ctx2", _tracker("def"), True)
|
||||
|
||||
result = seq.process_word("abcdef", pts=100, context_id="ctx1")
|
||||
word_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(len(word_frames), 2)
|
||||
self.assertEqual(word_frames[0].text, "abc")
|
||||
self.assertEqual(word_frames[1].text, "def")
|
||||
|
||||
def test_overflow_assigns_correct_context_ids(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True)
|
||||
seq.register_spoken(_spoken_frame("def"), "ctx2", _tracker("def"), True)
|
||||
|
||||
result = seq.process_word("abcdef", pts=100, context_id="ctx1")
|
||||
word_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(word_frames[0].context_id, "ctx1")
|
||||
self.assertEqual(word_frames[1].context_id, "ctx2")
|
||||
|
||||
def test_overflow_completing_next_slot_flushes_skipped(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True)
|
||||
seq.register_spoken(_spoken_frame("def"), "ctx2", _tracker("def"), True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx3", None) # blocked behind ctx2
|
||||
|
||||
result = seq.process_word("abcdef", pts=100, context_id="ctx1")
|
||||
self.assertTrue(any(f is skipped for f in result))
|
||||
|
||||
def test_overflow_not_completing_next_slot_does_not_flush_skipped(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("abc"), "ctx1", _tracker("abc"), True)
|
||||
seq.register_spoken(_spoken_frame("def ghi"), "ctx2", _tracker("def ghi"), True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx3", None)
|
||||
|
||||
# "abcdef" overflows: "def" goes to ctx2, but ctx2 still expects " ghi"
|
||||
result = seq.process_word("abcdef", pts=100, context_id="ctx1")
|
||||
self.assertFalse(any(f is skipped for f in result))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_word — force-complete via word_belongs_here failure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessWordForcesComplete(unittest.TestCase):
|
||||
def test_word_for_next_slot_force_completes_current(self):
|
||||
"""When a word belongs to the next slot but not the current, the current
|
||||
slot is force-completed and the word is routed to the next slot."""
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
seq.register_spoken(_spoken_frame("world"), "ctx2", _tracker("world"), True)
|
||||
|
||||
# "world" doesn't belong to ctx1 but belongs to ctx2
|
||||
result = seq.process_word("world", pts=50, context_id="ctx2")
|
||||
word_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
texts = {f.text for f in word_frames}
|
||||
self.assertIn("world", texts)
|
||||
|
||||
def test_force_complete_then_overflow_flushes_skipped(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
seq.register_spoken(_spoken_frame("world"), "ctx2", _tracker("world"), True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx3", None)
|
||||
|
||||
# "world" force-completes ctx1 and completes ctx2 via overflow
|
||||
result = seq.process_word("world", pts=50, context_id="ctx2")
|
||||
self.assertTrue(any(f is skipped for f in result))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# force_complete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestForceComplete(unittest.TestCase):
|
||||
def test_emits_remaining_text_when_word_dropped(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True)
|
||||
seq.process_word("hello", pts=10, context_id="ctx1") # "world" never arrives
|
||||
|
||||
result = seq.force_complete(last_word_pts=50)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(len(tts_frames), 1)
|
||||
self.assertEqual(tts_frames[0].text, "world")
|
||||
self.assertEqual(tts_frames[0].pts, 50)
|
||||
|
||||
def test_emits_full_text_when_no_words_arrived(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello world"), "ctx1", _tracker("hello world"), True)
|
||||
|
||||
result = seq.force_complete(last_word_pts=0)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(len(tts_frames), 1)
|
||||
self.assertEqual(tts_frames[0].text, "hello world")
|
||||
|
||||
def test_already_complete_slot_emits_nothing(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hi"), "ctx1", _tracker("hi"), True)
|
||||
seq.process_word("hi", pts=5, context_id="ctx1") # completes normally
|
||||
|
||||
result = seq.force_complete(last_word_pts=10)
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def test_flushes_skipped_frames_after_completing(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
|
||||
result = seq.force_complete(last_word_pts=20)
|
||||
self.assertTrue(any(f is skipped for f in result))
|
||||
self.assertTrue(skipped.append_to_context)
|
||||
|
||||
def test_propagates_raw_text(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(
|
||||
_spoken_frame("4111 1111"),
|
||||
"ctx1",
|
||||
WordCompletionTracker("4111 1111", llm_text="<card>4111 1111</card>"),
|
||||
append_to_context=True,
|
||||
)
|
||||
seq.process_word("4111", pts=10, context_id="ctx1") # "1111" never arrives
|
||||
|
||||
result = seq.force_complete(last_word_pts=20)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(tts_frames[0].text, "1111")
|
||||
self.assertEqual(tts_frames[0].raw_text, "1111</card>")
|
||||
|
||||
def test_discards_corrupt_raw_remaining(self):
|
||||
"""raw_remaining is discarded when it does not contain remaining_text."""
|
||||
seq = _seq()
|
||||
# "abc" normalized ≠ "xyz" normalized — any remaining won't be in raw_remaining
|
||||
seq.register_spoken(
|
||||
_spoken_frame("abc"),
|
||||
"ctx1",
|
||||
WordCompletionTracker("abc", llm_text="xyz"),
|
||||
append_to_context=True,
|
||||
)
|
||||
result = seq.force_complete(last_word_pts=0)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(len(tts_frames), 1)
|
||||
self.assertEqual(tts_frames[0].text, "abc")
|
||||
self.assertIsNone(tts_frames[0].raw_text) # discarded due to corruption
|
||||
|
||||
def test_slot_without_tracker_just_marks_complete_and_flushes(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", tracker=None, append_to_context=True)
|
||||
skipped = _skipped_frame("code")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
|
||||
result = seq.force_complete(last_word_pts=0)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(tts_frames, []) # no tracker → no word frame
|
||||
self.assertTrue(any(f is skipped for f in result))
|
||||
|
||||
def test_multiple_incomplete_slots_all_emitted(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
seq.register_spoken(_spoken_frame("world"), "ctx2", _tracker("world"), True)
|
||||
|
||||
result = seq.force_complete(last_word_pts=0)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
texts = {f.text for f in tts_frames}
|
||||
self.assertIn("hello", texts)
|
||||
self.assertIn("world", texts)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# clear
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClear(unittest.TestCase):
|
||||
def test_clears_slots(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
seq.register_skipped(_skipped_frame("code"), "ctx2", None)
|
||||
seq.clear()
|
||||
self.assertEqual(seq._slots, [])
|
||||
|
||||
def test_clears_context_map(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
seq.clear()
|
||||
self.assertEqual(seq._context_append_to_context, {})
|
||||
|
||||
def test_after_clear_skipped_emits_immediately(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
seq.clear()
|
||||
frame = _skipped_frame("code")
|
||||
result = seq.register_skipped(frame, "ctx2", None)
|
||||
self.assertEqual(len(result), 1)
|
||||
|
||||
def test_after_clear_process_word_uses_passthrough(self):
|
||||
seq = _seq()
|
||||
seq.register_spoken(_spoken_frame("hello"), "ctx1", _tracker("hello"), True)
|
||||
seq.clear()
|
||||
result = seq.process_word("hello", pts=1, context_id="ctx1")
|
||||
# No active slot after clear → passthrough
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].text, "hello")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CJK languages — Korean, Japanese, Chinese
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCJKLanguages(unittest.TestCase):
|
||||
"""Sequencer behaviour for CJK language scenarios.
|
||||
|
||||
Korean: Cartesia returns each word as a separate timestamp event (one word
|
||||
per process_word call). Japanese/Chinese: Cartesia merges all characters
|
||||
in one timestamp message into a single combined token before calling
|
||||
process_word.
|
||||
"""
|
||||
|
||||
# --- Korean ---
|
||||
|
||||
def test_korean_word_by_word_completes_slot_and_flushes_skipped(self):
|
||||
"""Korean words fed one at a time complete the spoken slot and unblock a skipped frame."""
|
||||
seq = _seq()
|
||||
sentence = "저는 여러분의 AI 어시스턴트입니다."
|
||||
words = ["저는", "여러분의", "AI", "어시스턴트입니다."]
|
||||
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
|
||||
skipped = _skipped_frame("[code]")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
|
||||
# Skipped stays blocked until the last word arrives
|
||||
for word in words[:-1]:
|
||||
partial = seq.process_word(word, pts=100, context_id="ctx1")
|
||||
self.assertFalse(any(f is skipped for f in partial))
|
||||
|
||||
result = seq.process_word(words[-1], pts=200, context_id="ctx1")
|
||||
self.assertTrue(any(f is skipped for f in result))
|
||||
|
||||
def test_korean_force_complete_emits_correct_remaining_text(self):
|
||||
"""After one Korean word, force_complete emits the correct unspoken suffix."""
|
||||
seq = _seq()
|
||||
sentence = "저는 여러분의 AI 어시스턴트입니다."
|
||||
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
|
||||
seq.process_word("저는", pts=10, context_id="ctx1")
|
||||
|
||||
result = seq.force_complete(last_word_pts=50)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(len(tts_frames), 1)
|
||||
self.assertEqual(tts_frames[0].text, "여러분의 AI 어시스턴트입니다.")
|
||||
self.assertEqual(tts_frames[0].pts, 50)
|
||||
|
||||
# --- Japanese ---
|
||||
|
||||
def test_japanese_combined_groups_complete_spoken_slot(self):
|
||||
"""Two Cartesia-style combined Japanese groups complete the slot and flush skipped."""
|
||||
seq = _seq()
|
||||
sentence = "こんにちは、私はあなたの"
|
||||
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
|
||||
skipped = _skipped_frame("[skipped]")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
|
||||
r1 = seq.process_word("こんにちは、私", pts=100, context_id="ctx1")
|
||||
self.assertFalse(any(f is skipped for f in r1))
|
||||
|
||||
r2 = seq.process_word("はあなたの", pts=200, context_id="ctx1")
|
||||
self.assertTrue(any(f is skipped for f in r2))
|
||||
|
||||
def test_japanese_force_complete_emits_remaining_chars(self):
|
||||
"""After the first Japanese combined group, force_complete emits the rest."""
|
||||
seq = _seq()
|
||||
sentence = "こんにちは、私はあなたの"
|
||||
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
|
||||
seq.process_word("こんにちは、私", pts=10, context_id="ctx1")
|
||||
|
||||
result = seq.force_complete(last_word_pts=50)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(len(tts_frames), 1)
|
||||
self.assertEqual(tts_frames[0].text, "はあなたの")
|
||||
|
||||
# --- Chinese ---
|
||||
|
||||
def test_chinese_combined_groups_complete_spoken_slot(self):
|
||||
"""Two Cartesia-style combined Chinese groups complete the slot and flush skipped."""
|
||||
seq = _seq()
|
||||
sentence = "你好,我是你的智能"
|
||||
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
|
||||
skipped = _skipped_frame("[skipped]")
|
||||
seq.register_skipped(skipped, "ctx2", None)
|
||||
|
||||
r1 = seq.process_word("你好,我是", pts=100, context_id="ctx1")
|
||||
self.assertFalse(any(f is skipped for f in r1))
|
||||
|
||||
r2 = seq.process_word("你的智能", pts=200, context_id="ctx1")
|
||||
self.assertTrue(any(f is skipped for f in r2))
|
||||
|
||||
def test_chinese_force_complete_emits_remaining_chars(self):
|
||||
"""After the first Chinese combined group, force_complete emits the rest."""
|
||||
seq = _seq()
|
||||
sentence = "你好,我是你的智能"
|
||||
seq.register_spoken(_spoken_frame(sentence), "ctx1", _tracker(sentence), True)
|
||||
seq.process_word("你好,我是", pts=10, context_id="ctx1")
|
||||
|
||||
result = seq.force_complete(last_word_pts=50)
|
||||
tts_frames = [f for f in result if isinstance(f, TTSTextFrame)]
|
||||
self.assertEqual(len(tts_frames), 1)
|
||||
self.assertEqual(tts_frames[0].text, "你的智能")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -18,7 +18,7 @@ def _service(language: str) -> CartesiaTTSService:
|
||||
def _process_word_timestamps(
|
||||
words: list[str], starts: list[float], language: str
|
||||
) -> list[tuple[str, float]]:
|
||||
return _service(language)._process_word_timestamps_for_language(words, starts)
|
||||
return _service(language)._normalize_word_timestamps(words, starts)
|
||||
|
||||
|
||||
def _concatenate_processed_timestamps(
|
||||
@@ -27,7 +27,7 @@ def _concatenate_processed_timestamps(
|
||||
service = _service(language)
|
||||
text_parts = []
|
||||
for words, starts in timestamp_groups:
|
||||
processed_timestamps = service._process_word_timestamps_for_language(words, starts)
|
||||
processed_timestamps = service._normalize_word_timestamps(words, starts)
|
||||
includes_inter_frame_spaces = service._word_timestamps_include_inter_frame_spaces()
|
||||
text_parts.extend(
|
||||
TextPartForConcatenation(
|
||||
|
||||
@@ -21,6 +21,13 @@ repeated for each TTSSpeakFrame, with no cross-group contamination.
|
||||
Also covers LLM response flow with push_text_frames=True (non-word-timestamp TTS):
|
||||
verifies TTSTextFrame ordering relative to LLMFullResponseEndFrame.
|
||||
|
||||
Also covers smart-text / WordCompletionTracker features:
|
||||
- Skipped frames (skip_aggregator_types) held until preceding spoken slots complete.
|
||||
- raw_text on AggregatedTextFrame propagated as spans to TTSTextFrames.
|
||||
- Overflow: a single TTS word straddling two AggregatedTextFrame boundaries produces
|
||||
two correctly-attributed TTSTextFrames.
|
||||
- Force-complete safety net: skipped frames flush even when TTS drops word timestamps.
|
||||
|
||||
Also covers the interruption-during-pause deadlock scenario (see test_no_deadlock_on_interrupt_*).
|
||||
"""
|
||||
|
||||
@@ -50,6 +57,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
from pipecat.utils.text.base_text_aggregator import AggregationType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test-only frame
|
||||
@@ -422,7 +430,7 @@ def _assert_group_ordering(
|
||||
# All frames between TTSStartedFrame and TTSStoppedFrame must be audio.
|
||||
mid_types = types[started_idx + 1 : stopped_idx]
|
||||
for t in mid_types:
|
||||
assert t is TTSAudioRawFrame, (
|
||||
assert t in (TTSAudioRawFrame, TTSTextFrame), (
|
||||
f"Group {foo_label!r}: unexpected frame {t.__name__!r} between "
|
||||
f"TTSStartedFrame and TTSStoppedFrame. Got: {type_names}"
|
||||
)
|
||||
@@ -551,7 +559,7 @@ async def test_http_push_text_llm_response_end_after_tts_text():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_word_timestamps_verbatim_tokens():
|
||||
"""HTTP path: text, PTS order, flag, and text-before-audio are all verified.
|
||||
"""HTTP path: text, PTS order, and text-before-audio are all verified.
|
||||
|
||||
Word timestamps arrive in the audio context queue before the audio frame.
|
||||
_handle_audio_context caches them, then flushes when the first audio frame
|
||||
@@ -572,7 +580,6 @@ async def test_http_word_timestamps_verbatim_tokens():
|
||||
audio_frames = [f for f in down if isinstance(f, TTSAudioRawFrame)]
|
||||
|
||||
assert [f.text for f in tts_text_frames] == ["hello", "world"]
|
||||
assert all(f.includes_inter_frame_spaces is True for f in tts_text_frames)
|
||||
|
||||
pts_values = [f.pts for f in tts_text_frames]
|
||||
assert pts_values == sorted(pts_values) and len(set(pts_values)) == len(pts_values), (
|
||||
@@ -590,15 +597,14 @@ async def test_http_word_timestamps_verbatim_tokens():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_word_timestamps_punctuation_tokens():
|
||||
"""Verbatim punctuation tokens are preserved with flag=True; default flag is False.
|
||||
"""Punct-only tokens are merged into the preceding word when includes_inter_frame_spaces=True.
|
||||
|
||||
Models the Inworld API scenario: the TTS returns tokens exactly as sent.
|
||||
Space placement rule:
|
||||
- word-follows-word: space is the leading char of the next word (e.g. " world")
|
||||
- word-follows-punctuation: space is the trailing char of the punctuation token
|
||||
(e.g. "! "), so the following word token carries no leading space.
|
||||
The flag must reach every frame and the text must not be modified.
|
||||
Also acts as a regression guard that flag=False is the default.
|
||||
Models the Inworld API scenario: the TTS returns separate space and punctuation
|
||||
tokens. add_word_timestamps calls merge_punct_tokens when includes_inter_frame_spaces
|
||||
is True, collapsing those tokens into the preceding word before the tracker sees them.
|
||||
|
||||
With flag=False (default) tokens are forwarded as-is; the tracker strips leading/
|
||||
trailing whitespace from each frame word via get_word_for_frame().
|
||||
"""
|
||||
verbatim_tokens = [
|
||||
("hello", 0.0),
|
||||
@@ -609,9 +615,9 @@ async def test_http_word_timestamps_punctuation_tokens():
|
||||
(" you", 0.75),
|
||||
("?", 0.9),
|
||||
]
|
||||
expected_texts = ["hello", " world", "! ", "How", " are", " you", "?"]
|
||||
|
||||
# With flag=True: all tokens verbatim, all frames carry the flag.
|
||||
# With flag=True: punct-only tokens ("! " and "?") are merged into the preceding
|
||||
# words (" world" → " world! " and " you" → " you?"), then stripped by the tracker.
|
||||
tts_ifs = _MockWordTimestampHttpTTSService(
|
||||
includes_inter_frame_spaces=True,
|
||||
word_times=verbatim_tokens,
|
||||
@@ -621,12 +627,11 @@ async def test_http_word_timestamps_punctuation_tokens():
|
||||
frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)],
|
||||
)
|
||||
text_frames_ifs = [f for f in frames_ifs[0] if isinstance(f, TTSTextFrame)]
|
||||
assert [f.text for f in text_frames_ifs] == expected_texts, (
|
||||
"Verbatim tokens must not be modified"
|
||||
assert [f.text for f in text_frames_ifs] == ["hello", "world!", "How", "are", "you?"], (
|
||||
"Punct-only tokens must be merged into the preceding word"
|
||||
)
|
||||
assert all(f.includes_inter_frame_spaces is True for f in text_frames_ifs)
|
||||
|
||||
# With flag=False (default): same tokens, flag must be False on every frame.
|
||||
# With flag=False (default): no merging; tracker strips leading/trailing spaces.
|
||||
tts_plain = _MockWordTimestampHttpTTSService(
|
||||
word_times=verbatim_tokens,
|
||||
)
|
||||
@@ -635,13 +640,12 @@ async def test_http_word_timestamps_punctuation_tokens():
|
||||
frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)],
|
||||
)
|
||||
text_frames_plain = [f for f in frames_plain[0] if isinstance(f, TTSTextFrame)]
|
||||
assert [f.text for f in text_frames_plain] == expected_texts
|
||||
assert all(f.includes_inter_frame_spaces is False for f in text_frames_plain)
|
||||
assert [f.text for f in text_frames_plain] == ["hello", "world", "!", "How", "are", "you", "?"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_word_timestamps_verbatim_tokens():
|
||||
"""WebSocket path: _WordTimestampEntry carries verbatim text, PTS, and flag.
|
||||
"""WebSocket path: text, PTS order, and text-before-audio are all verified.
|
||||
|
||||
Unlike the HTTP path the word timestamps are sent asynchronously from a
|
||||
background task. They arrive before the audio frame and are cached until
|
||||
@@ -662,7 +666,6 @@ async def test_websocket_word_timestamps_verbatim_tokens():
|
||||
audio_frames = [f for f in down if isinstance(f, TTSAudioRawFrame)]
|
||||
|
||||
assert [f.text for f in tts_text_frames] == ["hello", "world"]
|
||||
assert all(f.includes_inter_frame_spaces is True for f in tts_text_frames)
|
||||
|
||||
pts_values = [f.pts for f in tts_text_frames]
|
||||
assert pts_values == sorted(pts_values) and len(set(pts_values)) == len(pts_values), (
|
||||
@@ -678,7 +681,7 @@ async def test_websocket_word_timestamps_verbatim_tokens():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_word_timestamps_punctuation_tokens():
|
||||
"""WebSocket path: verbatim punctuation tokens reach TTSTextFrame unchanged."""
|
||||
"""WebSocket path: punct-only tokens are merged into the preceding word."""
|
||||
verbatim_tokens = [
|
||||
("hello", 0.0),
|
||||
(" world", 0.15),
|
||||
@@ -697,10 +700,443 @@ async def test_websocket_word_timestamps_punctuation_tokens():
|
||||
frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)],
|
||||
)
|
||||
text_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
assert [f.text for f in text_frames] == ["hello", " world", "! ", "How", " are", " you", "?"], (
|
||||
"Verbatim tokens must not be modified"
|
||||
assert [f.text for f in text_frames] == ["hello", "world!", "How", "are", "you?"], (
|
||||
"Punct-only tokens must be merged into the preceding word"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-call word-timestamp mock (for overflow tests)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _MockPerCallWordTimestampHttpTTSService(TTSService):
|
||||
"""HTTP-style TTS where each run_tts() call consumes its own word-time list.
|
||||
|
||||
Designed for tests that need different word tokens per sentence. The
|
||||
``word_times_per_call`` list is consumed in order; an empty inner list means
|
||||
no word-timestamp events are emitted for that call.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
word_times_per_call: list[list[tuple[str, float]]],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
push_text_frames=False,
|
||||
sample_rate=_SAMPLE_RATE,
|
||||
**kwargs,
|
||||
)
|
||||
self._word_times_queue = list(word_times_per_call)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return False
|
||||
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
word_times = self._word_times_queue.pop(0) if self._word_times_queue else []
|
||||
if word_times:
|
||||
await self.add_word_timestamps(word_times, context_id=context_id)
|
||||
yield TTSAudioRawFrame(
|
||||
audio=_FAKE_AUDIO,
|
||||
sample_rate=_SAMPLE_RATE,
|
||||
num_channels=1,
|
||||
context_id=context_id,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: skipped frame ordering (skip_aggregator_types)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_skipped_frame_waits_for_spoken_words():
|
||||
"""Skipped frames are held until the preceding spoken slot's word timestamps
|
||||
are all processed, then flushed in order (HTTP / synchronous audio path).
|
||||
|
||||
Sequence sent:
|
||||
AggregatedTextFrame("hello world", SENTENCE) — spoken; yields 2 TTSTextFrames
|
||||
AggregatedTextFrame("some code", "code") — in skip_aggregator_types; must wait
|
||||
|
||||
Expected downstream order:
|
||||
TTSTextFrame("hello")
|
||||
TTSTextFrame("world")
|
||||
AggregatedTextFrame("some code", append_to_context=True)
|
||||
"""
|
||||
tts = _MockWordTimestampHttpTTSService(skip_aggregator_types=["code"])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[
|
||||
AggregatedTextFrame("hello world", AggregationType.SENTENCE),
|
||||
AggregatedTextFrame("some code", "code"),
|
||||
],
|
||||
)
|
||||
down = frames_received[0]
|
||||
|
||||
word_frames = [f for f in down if isinstance(f, TTSTextFrame)]
|
||||
skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"]
|
||||
|
||||
assert [f.text for f in word_frames] == ["hello", "world"]
|
||||
assert len(skipped) == 1
|
||||
assert skipped[0].append_to_context is True
|
||||
|
||||
last_word_idx = max(down.index(f) for f in word_frames)
|
||||
skipped_idx = down.index(skipped[0])
|
||||
assert skipped_idx > last_word_idx, (
|
||||
f"Skipped frame (pos {skipped_idx}) must appear after last word frame (pos {last_word_idx})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_skipped_frame_waits_for_spoken_words():
|
||||
"""Same ordering guarantee on the WebSocket / async audio delivery path.
|
||||
|
||||
Because audio is delivered from a background task after asyncio.sleep(), the
|
||||
skipped frame arrives at _push_frame_respecting_previous_aggregated_frame
|
||||
*before* the spoken slot's word timestamps have been processed, directly
|
||||
exercising the hold-and-flush path.
|
||||
"""
|
||||
tts = _MockWordTimestampWSTTSService(skip_aggregator_types=["code"])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[
|
||||
AggregatedTextFrame("hello world", AggregationType.SENTENCE),
|
||||
AggregatedTextFrame("some code", "code"),
|
||||
],
|
||||
)
|
||||
down = frames_received[0]
|
||||
|
||||
word_frames = [f for f in down if isinstance(f, TTSTextFrame)]
|
||||
skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"]
|
||||
|
||||
assert [f.text for f in word_frames] == ["hello", "world"]
|
||||
assert len(skipped) == 1
|
||||
assert skipped[0].append_to_context is True
|
||||
|
||||
last_word_idx = max(down.index(f) for f in word_frames)
|
||||
skipped_idx = down.index(skipped[0])
|
||||
assert skipped_idx > last_word_idx, (
|
||||
f"Skipped frame (pos {skipped_idx}) must appear after last word frame (pos {last_word_idx})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skipped_frame_before_spoken_emits_immediately():
|
||||
"""A skipped frame with no preceding spoken slot is emitted right away.
|
||||
|
||||
Sequence:
|
||||
AggregatedTextFrame("some code", "code") — no spoken slot before it → emits now
|
||||
AggregatedTextFrame("hello world", SENTENCE) — spoken; TTSTextFrames follow
|
||||
|
||||
Expected: AggregatedTextFrame("some code") appears *before* TTSTextFrame("hello").
|
||||
"""
|
||||
tts = _MockWordTimestampHttpTTSService(skip_aggregator_types=["code"])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[
|
||||
AggregatedTextFrame("some code", "code"),
|
||||
AggregatedTextFrame("hello world", AggregationType.SENTENCE),
|
||||
],
|
||||
)
|
||||
down = frames_received[0]
|
||||
|
||||
word_frames = [f for f in down if isinstance(f, TTSTextFrame)]
|
||||
skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"]
|
||||
|
||||
assert len(skipped) == 1
|
||||
assert skipped[0].append_to_context is True
|
||||
assert len(word_frames) >= 1
|
||||
|
||||
skipped_idx = down.index(skipped[0])
|
||||
first_word_idx = down.index(word_frames[0])
|
||||
assert skipped_idx < first_word_idx, (
|
||||
f"Skipped frame (pos {skipped_idx}) must appear before first word frame (pos {first_word_idx})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skipped_frame_flushed_when_word_timestamps_incomplete():
|
||||
"""Force-complete path: skipped frame still emits when the TTS drops word timestamps.
|
||||
|
||||
Only one of the two expected tokens ("hello") is returned. The spoken slot never
|
||||
reaches its expected character count through the normal path. When
|
||||
on_audio_context_done fires it force-completes any remaining spoken slots and
|
||||
flushes the waiting skipped frame.
|
||||
"""
|
||||
tts = _MockWordTimestampHttpTTSService(
|
||||
word_times=[("hello", 0.0)], # "world" is never sent
|
||||
skip_aggregator_types=["code"],
|
||||
)
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[
|
||||
AggregatedTextFrame("hello world", AggregationType.SENTENCE),
|
||||
AggregatedTextFrame("some code", "code"),
|
||||
],
|
||||
)
|
||||
down = frames_received[0]
|
||||
|
||||
skipped = [f for f in down if isinstance(f, AggregatedTextFrame) and f.text == "some code"]
|
||||
assert len(skipped) == 1, "Skipped frame must be flushed via force-complete safety net"
|
||||
assert skipped[0].append_to_context is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: raw_text propagation through WordCompletionTracker
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_text_propagated_to_tts_text_frames():
|
||||
"""raw_text on AggregatedTextFrame is split across TTSTextFrames by the tracker.
|
||||
|
||||
The frame carries raw_text="<card>4111 1111</card>" while the TTS-prepared
|
||||
text is "4111 1111". The WordCompletionTracker advances a cursor through the
|
||||
raw text in step with incoming word tokens, so each TTSTextFrame receives the
|
||||
exact raw span it represents.
|
||||
|
||||
Expected (trailing whitespace stripped because includes_inter_frame_spaces=False):
|
||||
TTSTextFrame("4111").raw_text == "<card>4111"
|
||||
TTSTextFrame("1111").raw_text == "1111</card>"
|
||||
"""
|
||||
tts = _MockWordTimestampHttpTTSService()
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[
|
||||
AggregatedTextFrame(
|
||||
"4111 1111", AggregationType.SENTENCE, raw_text="<card>4111 1111</card>"
|
||||
)
|
||||
],
|
||||
)
|
||||
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
|
||||
assert [f.text for f in word_frames] == ["4111", "1111"]
|
||||
# get_raw_consumed() strips trailing whitespace when includes_inter_frame_spaces=False
|
||||
assert word_frames[0].raw_text == "<card>4111"
|
||||
assert word_frames[1].raw_text == "1111</card>"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: overflow — TTS word spanning two AggregatedTextFrame boundaries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_overflow_word_spanning_two_aggregated_frames():
|
||||
"""A single TTS token straddling two AggregatedTextFrame boundaries produces
|
||||
two correctly-attributed TTSTextFrames.
|
||||
|
||||
Setup:
|
||||
Frame 1: AggregatedTextFrame("abc", SENTENCE)
|
||||
Frame 2: AggregatedTextFrame("def", SENTENCE)
|
||||
|
||||
The TTS for frame 1 returns the single token "abcdef", which overshoots
|
||||
frame 1 by three characters. _emit_overflow_word splits it:
|
||||
TTSTextFrame("abc") — frame 1's portion (context_id = ctx1)
|
||||
TTSTextFrame("def") — overflow attributed to frame 2 (context_id = ctx2)
|
||||
|
||||
Frame 2 receives no word-timestamp events because the overflow already
|
||||
consumed its expected text.
|
||||
"""
|
||||
tts = _MockPerCallWordTimestampHttpTTSService(
|
||||
word_times_per_call=[
|
||||
[("abcdef", 0.0)], # frame 1: single token spanning both frames
|
||||
[], # frame 2: no word timestamps (overflow already covered it)
|
||||
]
|
||||
)
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[
|
||||
AggregatedTextFrame("abc", AggregationType.SENTENCE),
|
||||
AggregatedTextFrame("def", AggregationType.SENTENCE),
|
||||
],
|
||||
)
|
||||
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
|
||||
assert [f.text for f in word_frames] == ["abc", "def"], (
|
||||
f"Expected ['abc', 'def'] but got {[f.text for f in word_frames]}"
|
||||
)
|
||||
assert word_frames[0].context_id != word_frames[1].context_id, (
|
||||
"Overflow TTSTextFrame must carry frame 2's context_id, not frame 1's"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-call word-timestamp mock for WebSocket path (for force-complete tests)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _MockPerCallWordTimestampWSTTSService(TTSService):
|
||||
"""WebSocket-style TTS where each run_tts() call consumes its own word-time list.
|
||||
|
||||
Mirrors _MockPerCallWordTimestampHttpTTSService but uses the async audio-context
|
||||
delivery pattern so it exercises _handle_audio_context (the WebSocket path).
|
||||
An empty inner list means no word-timestamp events are emitted for that call.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
word_times_per_call: list[list[tuple[str, float]]],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
push_start_frame=True,
|
||||
push_text_frames=False,
|
||||
pause_frame_processing=False,
|
||||
sample_rate=_SAMPLE_RATE,
|
||||
**kwargs,
|
||||
)
|
||||
self._word_times_queue = list(word_times_per_call)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return False
|
||||
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
word_times = self._word_times_queue.pop(0) if self._word_times_queue else []
|
||||
|
||||
async def _deliver():
|
||||
await asyncio.sleep(0.01)
|
||||
if word_times:
|
||||
await self.add_word_timestamps(word_times, context_id=context_id)
|
||||
await self.append_to_audio_context(
|
||||
context_id,
|
||||
TTSAudioRawFrame(
|
||||
audio=_FAKE_AUDIO,
|
||||
sample_rate=_SAMPLE_RATE,
|
||||
num_channels=1,
|
||||
context_id=context_id,
|
||||
),
|
||||
)
|
||||
await self.append_to_audio_context(context_id, TTSStoppedFrame(context_id=context_id))
|
||||
await self.remove_audio_context(context_id)
|
||||
|
||||
self.create_task(_deliver(), name=f"mock_ws_per_call_deliver_{context_id}")
|
||||
if False:
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _force_complete_spoken_slots — TTSTextFrame emission for dropped timestamps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_force_complete_partial_timestamps_emits_remaining_text():
|
||||
"""_force_complete_spoken_slots emits a TTSTextFrame for the unspoken word suffix.
|
||||
|
||||
Only the first token ("hello") is delivered as a word-timestamp event; "world"
|
||||
is never sent. When the audio context ends _force_complete_spoken_slots fires,
|
||||
reads get_remaining_text() from the tracker, and emits TTSTextFrame("world").
|
||||
|
||||
Expected TTSTextFrames in order: ["hello", "world"].
|
||||
"""
|
||||
tts = _MockWordTimestampHttpTTSService(word_times=[("hello", 0.0)])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)],
|
||||
)
|
||||
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
|
||||
assert [f.text for f in word_frames] == ["hello", "world"], (
|
||||
f"Expected ['hello', 'world'] but got {[f.text for f in word_frames]}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_force_complete_no_timestamps_emits_full_text():
|
||||
"""_force_complete_spoken_slots emits the full text when no word timestamps arrive.
|
||||
|
||||
No word-timestamp events are sent for "hello world". The slot remains incomplete
|
||||
when the audio context ends; force-complete reads the full remaining text from the
|
||||
tracker and emits TTSTextFrame("hello world").
|
||||
"""
|
||||
tts = _MockPerCallWordTimestampHttpTTSService(word_times_per_call=[[]])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)],
|
||||
)
|
||||
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
|
||||
assert len(word_frames) == 1, (
|
||||
f"Expected exactly 1 TTSTextFrame, got {len(word_frames)}: {[f.text for f in word_frames]}"
|
||||
)
|
||||
assert word_frames[0].text == "hello world", (
|
||||
f"Expected TTSTextFrame('hello world'), got {word_frames[0].text!r}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_force_complete_raw_text_propagated():
|
||||
"""force-complete carries the correct raw_text span on the emitted TTSTextFrame.
|
||||
|
||||
AggregatedTextFrame carries raw_text="<card>4111 1111</card>". Only "4111" arrives
|
||||
as a word-timestamp; "1111" is force-completed.
|
||||
|
||||
Expected:
|
||||
TTSTextFrame("4111").raw_text == "<card>4111" — from normal word path
|
||||
TTSTextFrame("1111").raw_text == "1111</card>" — from force-complete path
|
||||
"""
|
||||
tts = _MockPerCallWordTimestampHttpTTSService(word_times_per_call=[[("4111", 0.0)]])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[
|
||||
AggregatedTextFrame(
|
||||
"4111 1111", AggregationType.SENTENCE, raw_text="<card>4111 1111</card>"
|
||||
)
|
||||
],
|
||||
)
|
||||
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
|
||||
assert [f.text for f in word_frames] == ["4111", "1111"], (
|
||||
f"Expected ['4111', '1111'] but got {[f.text for f in word_frames]}"
|
||||
)
|
||||
assert word_frames[0].raw_text == "<card>4111", (
|
||||
f"Expected raw_text '<card>4111' on first frame, got {word_frames[0].raw_text!r}"
|
||||
)
|
||||
assert word_frames[1].raw_text == "1111</card>", (
|
||||
f"Expected raw_text '1111</card>' on force-complete frame, got {word_frames[1].raw_text!r}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_force_complete_partial_timestamps_emits_remaining_text():
|
||||
"""WebSocket path: _force_complete_spoken_slots emits TTSTextFrame for dropped token.
|
||||
|
||||
Mirrors test_http_force_complete_partial_timestamps_emits_remaining_text on the
|
||||
async audio delivery path to confirm force-complete fires correctly from
|
||||
_handle_audio_context when TTSStoppedFrame arrives before all word timestamps.
|
||||
"""
|
||||
tts = _MockWordTimestampWSTTSService(word_times=[("hello", 0.0)])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)],
|
||||
)
|
||||
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
|
||||
assert [f.text for f in word_frames] == ["hello", "world"], (
|
||||
f"Expected ['hello', 'world'] but got {[f.text for f in word_frames]}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_force_complete_no_timestamps_emits_full_text():
|
||||
"""WebSocket path: full text emitted as single TTSTextFrame when no timestamps arrive."""
|
||||
tts = _MockPerCallWordTimestampWSTTSService(word_times_per_call=[[]])
|
||||
frames_received = await run_test(
|
||||
tts,
|
||||
frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)],
|
||||
)
|
||||
word_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
|
||||
|
||||
assert len(word_frames) == 1, (
|
||||
f"Expected exactly 1 TTSTextFrame, got {len(word_frames)}: {[f.text for f in word_frames]}"
|
||||
)
|
||||
assert word_frames[0].text == "hello world", (
|
||||
f"Expected TTSTextFrame('hello world'), got {word_frames[0].text!r}"
|
||||
)
|
||||
assert all(f.includes_inter_frame_spaces is True for f in text_frames)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
1602
tests/test_word_completion_tracker.py
Normal file
1602
tests/test_word_completion_tracker.py
Normal file
File diff suppressed because it is too large
Load Diff
90
tests/test_word_timestamp_utils.py
Normal file
90
tests/test_word_timestamp_utils.py
Normal file
@@ -0,0 +1,90 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.utils.text.word_timestamp_utils import merge_punct_tokens
|
||||
|
||||
|
||||
class TestMergePunctTokens(unittest.TestCase):
|
||||
def test_empty_list(self):
|
||||
self.assertEqual(merge_punct_tokens([]), [])
|
||||
|
||||
def test_all_alnum_words_pass_through(self):
|
||||
input = [("hello", 0.0), ("world", 1.0)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("hello", 0.0), ("world", 1.0)])
|
||||
|
||||
def test_trailing_space_merged_and_stripped(self):
|
||||
input = [("I", 0.0), (" ", 0.2)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("I", 0.0)])
|
||||
|
||||
def test_comma_space_merged_and_stripped(self):
|
||||
input = [("questions", 1.0), (", ", 1.2), ("explain", 1.4)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("questions,", 1.0), ("explain", 1.4)])
|
||||
|
||||
def test_leading_space_with_no_preceding_word_discarded(self):
|
||||
input = [(" ", 0.0), ("hello", 0.5)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("hello", 0.5)])
|
||||
|
||||
def test_leading_empty_string_discarded(self):
|
||||
input = [("", 0.0), ("hello", 0.5)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("hello", 0.5)])
|
||||
|
||||
def test_multiple_consecutive_punct_tokens_merged_and_stripped(self):
|
||||
input = [("word", 0.0), (",", 0.1), (" ", 0.2), ("next", 0.3)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("word,", 0.0), ("next", 0.3)])
|
||||
|
||||
def test_timestamp_of_preceding_word_is_kept(self):
|
||||
"""Merged punct tokens adopt the preceding word's timestamp."""
|
||||
input = [("hello", 2.5), (",", 2.7)]
|
||||
result = merge_punct_tokens(input)
|
||||
self.assertEqual(result, [("hello,", 2.5)])
|
||||
|
||||
def test_xml_tag_only_token_is_treated_as_punct(self):
|
||||
"""A token that is only an XML tag (no alnum chars) merges into the preceding word."""
|
||||
input = [("word", 0.0), ("<break/>", 0.1), ("next", 0.3)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("word<break/>", 0.0), ("next", 0.3)])
|
||||
|
||||
def test_xml_tag_with_alnum_content_passes_through(self):
|
||||
"""A token like '<spell>123</spell>' has alnum chars after stripping tags."""
|
||||
input = [("<spell>123</spell>", 0.0), ("and", 0.5)]
|
||||
self.assertEqual(merge_punct_tokens(input), [("<spell>123</spell>", 0.0), ("and", 0.5)])
|
||||
|
||||
def test_inworld_style_full_stream(self):
|
||||
"""Full Inworld-style raw stream produces expected merged and stripped output."""
|
||||
raw = [
|
||||
("", 0.0),
|
||||
("I", 0.1),
|
||||
(" ", 0.2),
|
||||
("can", 0.3),
|
||||
(" ", 0.4),
|
||||
("answer", 0.5),
|
||||
(" ", 0.6),
|
||||
("questions", 0.7),
|
||||
(", ", 0.8),
|
||||
("explain", 0.9),
|
||||
(" ", 1.0),
|
||||
("things", 1.1),
|
||||
(".", 1.2),
|
||||
]
|
||||
expected = [
|
||||
("I", 0.1),
|
||||
("can", 0.3),
|
||||
("answer", 0.5),
|
||||
("questions,", 0.7),
|
||||
("explain", 0.9),
|
||||
("things.", 1.1),
|
||||
]
|
||||
self.assertEqual(merge_punct_tokens(raw), expected)
|
||||
|
||||
def test_only_punct_tokens_returns_empty(self):
|
||||
"""A list containing only punct/space tokens produces an empty result."""
|
||||
input = [(" ", 0.0), (",", 0.1), (".", 0.2)]
|
||||
self.assertEqual(merge_punct_tokens(input), [])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user