feat(tts): add includes_inter_frame_spaces flag to word-timestamp API

Some TTS providers (e.g. Inworld) return verbatim tokens where spaces and
punctuation are already embedded in the token text. When downstream consumers
join these tokens with an extra space they produce "hello , world" instead of
"hello, world".

Add an opt-in `includes_inter_frame_spaces: bool = False` parameter to
`add_word_timestamps` / `_add_word_timestamps`. The flag is threaded through
`_WordTimestampEntry` and stamped onto every emitted `TTSTextFrame`.
Defaults to `False` — no behaviour change for existing services.

`InworldTTSService` passes `includes_inter_frame_spaces=True` and stops
pre-processing tokens in `_calculate_word_times`, returning them verbatim.

Tests added to `test_tts_frame_ordering.py` covering both HTTP and WebSocket
delivery paths: verbatim text preservation, PTS ordering, text-before-audio
ordering, and the Inworld punctuation-token scenario.

Made-with: Cursor
This commit is contained in:
Ian Lee
2026-04-18 12:03:32 -07:00
parent fc1c3b48dc
commit b435ddfa44
4 changed files with 291 additions and 12 deletions

View File

@@ -0,0 +1,2 @@
- added `includes_inter_frame_spaces` to `TTSService` and `WebsocketTTSService`, with default to `false` to preserve existing behaviors.
- `InworldTTSService` to set `includes_inter_frame_spaces` to `true` to stop inserting spaces between words from server, which already include spaces.

View File

@@ -410,7 +410,9 @@ class InworldHttpTTSService(TTSService):
word_times, chunk_end_time = self._calculate_word_times(timestamp_info)
if word_times:
self._current_run_had_timestamps = True
await self.add_word_timestamps(word_times, context_id)
await self.add_word_timestamps(
word_times, context_id, includes_inter_frame_spaces=True
)
# Track the maximum end time across all chunks
utterance_duration = max(utterance_duration, chunk_end_time)
@@ -447,7 +449,9 @@ class InworldHttpTTSService(TTSService):
word_times, chunk_end_time = self._calculate_word_times(timestamp_info)
if word_times:
self._current_run_had_timestamps = True
await self.add_word_timestamps(word_times, context_id)
await self.add_word_timestamps(
word_times, context_id, includes_inter_frame_spaces=True
)
utterance_duration = chunk_end_time
audio_data = base64.b64decode(response_data["audioContent"])
@@ -1013,7 +1017,9 @@ class InworldTTSService(WebsocketTTSService):
if word_times:
if ctx_id:
self._contexts_with_timestamps.add(ctx_id)
await self.add_word_timestamps(word_times, ctx_id)
await self.add_word_timestamps(
word_times, ctx_id, includes_inter_frame_spaces=True
)
# Handle flush completion, which indicates the end of a generation
if "flushCompleted" in result:

View File

@@ -96,6 +96,7 @@ class _WordTimestampEntry:
word: str
timestamp: float
context_id: str
includes_inter_frame_spaces: bool = False
class TTSService(AIService):
@@ -1049,8 +1050,10 @@ class TTSService(AIService):
if self._initial_word_times:
cached = self._initial_word_times.copy()
self._initial_word_times = []
for word, timestamp_seconds, ctx_id in cached:
await self._add_word_timestamps([(word, timestamp_seconds)], ctx_id)
for word, timestamp_seconds, ctx_id, ifs in cached:
await self._add_word_timestamps(
[(word, timestamp_seconds)], ctx_id, includes_inter_frame_spaces=ifs
)
async def reset_word_timestamps(self):
"""Reset word timestamp tracking."""
@@ -1060,7 +1063,10 @@ class TTSService(AIService):
self._initial_word_times = []
async def add_word_timestamps(
self, word_times: list[tuple[str, float]], context_id: str | None = None
self,
word_times: list[tuple[str, float]],
context_id: str | None = None,
includes_inter_frame_spaces: bool = False,
):
"""Add word timestamps for processing.
@@ -1072,6 +1078,9 @@ class TTSService(AIService):
Args:
word_times: List of (word, timestamp) tuples where timestamp is in seconds.
context_id: Unique identifier for the TTS context.
includes_inter_frame_spaces: When True, the tokens already embed inter-word
spacing (spaces and punctuation are part of the token text). Downstream
consumers must not inject additional spaces between tokens.
"""
if context_id and self.audio_context_available(context_id):
for word, timestamp in word_times:
@@ -1081,13 +1090,21 @@ class TTSService(AIService):
word=word,
timestamp=timestamp,
context_id=context_id,
includes_inter_frame_spaces=includes_inter_frame_spaces,
),
)
else:
await self._add_word_timestamps(word_times=word_times, context_id=context_id)
await self._add_word_timestamps(
word_times=word_times,
context_id=context_id,
includes_inter_frame_spaces=includes_inter_frame_spaces,
)
async def _add_word_timestamps(
self, word_times: list[tuple[str, float]], context_id: str | None = None
self,
word_times: list[tuple[str, float]],
context_id: str | None = None,
includes_inter_frame_spaces: bool = False,
):
"""Process word timestamps directly, building and pushing TTSTextFrames inline.
@@ -1103,11 +1120,12 @@ class TTSService(AIService):
ts_ns = seconds_to_nanoseconds(timestamp)
if self._initial_word_timestamp == -1:
# Cache until we have audio and can compute PTS.
self._initial_word_times.append((word, timestamp, context_id))
self._initial_word_times.append(
(word, timestamp, context_id, includes_inter_frame_spaces)
)
else:
# Assumption: word-by-word text frames don't include spaces, so
# we can rely on the default includes_inter_frame_spaces=False
frame = TTSTextFrame(word, aggregated_by=AggregationType.WORD)
frame.includes_inter_frame_spaces = includes_inter_frame_spaces
frame.pts = self._initial_word_timestamp + ts_ns
frame.context_id = context_id
if context_id in self._tts_contexts:
@@ -1310,7 +1328,9 @@ class TTSService(AIService):
# Route word timestamps through _add_word_timestamps so they are
# processed in playback order alongside audio frames.
await self._add_word_timestamps(
[(frame.word, frame.timestamp)], frame.context_id
[(frame.word, frame.timestamp)],
frame.context_id,
includes_inter_frame_spaces=frame.includes_inter_frame_spaces,
)
continue
elif isinstance(frame, TTSAudioRawFrame):

View File

@@ -208,6 +208,103 @@ class MockWebSocketPauseTTSService(TTSService):
yield
class _MockWordTimestampHttpTTSService(TTSService):
"""HTTP-style TTS: yields audio synchronously, calls add_word_timestamps first.
``word_times`` pins the exact tokens and their timestamps. When omitted the
service splits the input text on spaces, assigning 0.1 s gaps.
"""
def __init__(
self,
includes_inter_frame_spaces: bool = False,
word_times: list[tuple[str, float]] | None = None,
**kwargs,
):
super().__init__(
push_start_frame=True,
push_stop_frames=True,
push_text_frames=False,
sample_rate=_SAMPLE_RATE,
**kwargs,
)
self._includes_inter_frame_spaces = includes_inter_frame_spaces
self._word_times = word_times
def can_generate_metrics(self) -> bool:
return False
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
word_times = self._word_times or [(w, i * 0.1) for i, w in enumerate(text.split())]
await self.add_word_timestamps(
word_times,
context_id=context_id,
includes_inter_frame_spaces=self._includes_inter_frame_spaces,
)
yield TTSAudioRawFrame(
audio=_FAKE_AUDIO,
sample_rate=_SAMPLE_RATE,
num_channels=1,
context_id=context_id,
)
class _MockWordTimestampWSTTSService(TTSService):
"""WebSocket-style TTS: delivers audio asynchronously via the audio context.
Word timestamps are enqueued as ``_WordTimestampEntry`` items (audio context
already exists at call time) and processed by ``_handle_audio_context`` in
playback order.
``word_times`` pins the exact tokens and their timestamps. When omitted the
service splits the input text on spaces, assigning 0.1 s gaps.
"""
def __init__(
self,
includes_inter_frame_spaces: bool = False,
word_times: list[tuple[str, float]] | None = None,
**kwargs,
):
super().__init__(
push_start_frame=True,
push_text_frames=False,
pause_frame_processing=False,
sample_rate=_SAMPLE_RATE,
**kwargs,
)
self._includes_inter_frame_spaces = includes_inter_frame_spaces
self._word_times = word_times
def can_generate_metrics(self) -> bool:
return False
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
async def _deliver():
await asyncio.sleep(0.01)
word_times = self._word_times or [(w, i * 0.1) for i, w in enumerate(text.split())]
await self.add_word_timestamps(
word_times,
context_id=context_id,
includes_inter_frame_spaces=self._includes_inter_frame_spaces,
)
await self.append_to_audio_context(
context_id,
TTSAudioRawFrame(
audio=_FAKE_AUDIO,
sample_rate=_SAMPLE_RATE,
num_channels=1,
context_id=context_id,
),
)
await self.append_to_audio_context(context_id, TTSStoppedFrame(context_id=context_id))
await self.remove_audio_context(context_id)
self.create_task(_deliver(), name=f"mock_ws_word_deliver_{context_id}")
if False:
yield
# ---------------------------------------------------------------------------
# Assertion helper
# ---------------------------------------------------------------------------
@@ -406,5 +503,159 @@ async def test_http_push_text_llm_response_end_after_tts_text():
)
@pytest.mark.asyncio
async def test_http_word_timestamps_verbatim_tokens():
"""HTTP path: text, PTS order, flag, and text-before-audio are all verified.
Word timestamps arrive in the audio context queue before the audio frame.
_handle_audio_context caches them, then flushes when the first audio frame
arrives (start_word_timestamps), so TTSTextFrames must be emitted before
the TTSAudioRawFrame in the downstream sequence.
"""
word_times = [("hello", 0.0), ("world", 0.2)]
tts = _MockWordTimestampHttpTTSService(
includes_inter_frame_spaces=True,
word_times=word_times,
)
frames_received = await run_test(
tts,
frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)],
)
down = frames_received[0]
tts_text_frames = [f for f in down if isinstance(f, TTSTextFrame)]
audio_frames = [f for f in down if isinstance(f, TTSAudioRawFrame)]
assert [f.text for f in tts_text_frames] == ["hello", "world"]
assert all(f.includes_inter_frame_spaces is True for f in tts_text_frames)
pts_values = [f.pts for f in tts_text_frames]
assert pts_values == sorted(pts_values) and len(set(pts_values)) == len(pts_values), (
f"PTS values must be strictly increasing, got {pts_values}"
)
# TTSTextFrames must precede the audio frame (they are flushed from cache
# at the moment the first audio chunk sets the timestamp baseline).
last_text_idx = max(down.index(f) for f in tts_text_frames)
first_audio_idx = down.index(audio_frames[0])
assert last_text_idx < first_audio_idx, (
"TTSTextFrames must appear before TTSAudioRawFrame in the downstream sequence"
)
@pytest.mark.asyncio
async def test_http_word_timestamps_punctuation_tokens():
"""Verbatim punctuation tokens are preserved with flag=True; default flag is False.
Models the Inworld API scenario: the TTS returns tokens exactly as sent.
Space placement rule:
- word-follows-word: space is the leading char of the next word (e.g. " world")
- word-follows-punctuation: space is the trailing char of the punctuation token
(e.g. "! "), so the following word token carries no leading space.
The flag must reach every frame and the text must not be modified.
Also acts as a regression guard that flag=False is the default.
"""
verbatim_tokens = [
("hello", 0.0),
(" world", 0.15),
("! ", 0.3),
("How", 0.45),
(" are", 0.6),
(" you", 0.75),
("?", 0.9),
]
expected_texts = ["hello", " world", "! ", "How", " are", " you", "?"]
# With flag=True: all tokens verbatim, all frames carry the flag.
tts_ifs = _MockWordTimestampHttpTTSService(
includes_inter_frame_spaces=True,
word_times=verbatim_tokens,
)
frames_ifs = await run_test(
tts_ifs,
frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)],
)
text_frames_ifs = [f for f in frames_ifs[0] if isinstance(f, TTSTextFrame)]
assert [f.text for f in text_frames_ifs] == expected_texts, (
"Verbatim tokens must not be modified"
)
assert all(f.includes_inter_frame_spaces is True for f in text_frames_ifs)
# With flag=False (default): same tokens, flag must be False on every frame.
tts_plain = _MockWordTimestampHttpTTSService(
word_times=verbatim_tokens,
)
frames_plain = await run_test(
tts_plain,
frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)],
)
text_frames_plain = [f for f in frames_plain[0] if isinstance(f, TTSTextFrame)]
assert [f.text for f in text_frames_plain] == expected_texts
assert all(f.includes_inter_frame_spaces is False for f in text_frames_plain)
@pytest.mark.asyncio
async def test_websocket_word_timestamps_verbatim_tokens():
"""WebSocket path: _WordTimestampEntry carries verbatim text, PTS, and flag.
Unlike the HTTP path the word timestamps are sent asynchronously from a
background task. They arrive before the audio frame and are cached until
start_word_timestamps() fires, so the same text-before-audio ordering
property must hold.
"""
word_times = [("hello", 0.0), ("world", 0.2)]
tts = _MockWordTimestampWSTTSService(
includes_inter_frame_spaces=True,
word_times=word_times,
)
frames_received = await run_test(
tts,
frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)],
)
down = frames_received[0]
tts_text_frames = [f for f in down if isinstance(f, TTSTextFrame)]
audio_frames = [f for f in down if isinstance(f, TTSAudioRawFrame)]
assert [f.text for f in tts_text_frames] == ["hello", "world"]
assert all(f.includes_inter_frame_spaces is True for f in tts_text_frames)
pts_values = [f.pts for f in tts_text_frames]
assert pts_values == sorted(pts_values) and len(set(pts_values)) == len(pts_values), (
f"PTS values must be strictly increasing, got {pts_values}"
)
last_text_idx = max(down.index(f) for f in tts_text_frames)
first_audio_idx = down.index(audio_frames[0])
assert last_text_idx < first_audio_idx, (
"TTSTextFrames must appear before TTSAudioRawFrame in the downstream sequence"
)
@pytest.mark.asyncio
async def test_websocket_word_timestamps_punctuation_tokens():
"""WebSocket path: verbatim punctuation tokens reach TTSTextFrame unchanged."""
verbatim_tokens = [
("hello", 0.0),
(" world", 0.15),
("! ", 0.3),
("How", 0.45),
(" are", 0.6),
(" you", 0.75),
("?", 0.9),
]
tts = _MockWordTimestampWSTTSService(
includes_inter_frame_spaces=True,
word_times=verbatim_tokens,
)
frames_received = await run_test(
tts,
frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)],
)
text_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)]
assert [f.text for f in text_frames] == ["hello", " world", "! ", "How", " are", " you", "?"], (
"Verbatim tokens must not be modified"
)
assert all(f.includes_inter_frame_spaces is True for f in text_frames)
if __name__ == "__main__":
unittest.main()