diff --git a/changelog/4330.changed.md b/changelog/4330.changed.md new file mode 100644 index 000000000..97eb2406c --- /dev/null +++ b/changelog/4330.changed.md @@ -0,0 +1,2 @@ +- Added an `includes_inter_frame_spaces` parameter to `add_word_timestamps` on `TTSService` and `WebsocketTTSService`, defaulting to `False` to preserve existing behavior. +- `InworldTTSService` and `InworldHttpTTSService` now set `includes_inter_frame_spaces` to `True` so that no extra spaces are inserted between word tokens from the server, which already include their spacing. \ No newline at end of file diff --git a/src/pipecat/services/inworld/tts.py b/src/pipecat/services/inworld/tts.py index 65922e3a1..c823d6c7e 100644 --- a/src/pipecat/services/inworld/tts.py +++ b/src/pipecat/services/inworld/tts.py @@ -410,7 +410,9 @@ class InworldHttpTTSService(TTSService): word_times, chunk_end_time = self._calculate_word_times(timestamp_info) if word_times: self._current_run_had_timestamps = True - await self.add_word_timestamps(word_times, context_id) + await self.add_word_timestamps( + word_times, context_id, includes_inter_frame_spaces=True + ) # Track the maximum end time across all chunks utterance_duration = max(utterance_duration, chunk_end_time) @@ -447,7 +449,9 @@ class InworldHttpTTSService(TTSService): word_times, chunk_end_time = self._calculate_word_times(timestamp_info) if word_times: self._current_run_had_timestamps = True - await self.add_word_timestamps(word_times, context_id) + await self.add_word_timestamps( + word_times, context_id, includes_inter_frame_spaces=True + ) utterance_duration = chunk_end_time audio_data = base64.b64decode(response_data["audioContent"]) @@ -1013,7 +1017,9 @@ class InworldTTSService(WebsocketTTSService): if word_times: if ctx_id: self._contexts_with_timestamps.add(ctx_id) - await self.add_word_timestamps(word_times, ctx_id) + await self.add_word_timestamps( + word_times, ctx_id, includes_inter_frame_spaces=True + ) # Handle flush completion, which indicates the end of a generation if "flushCompleted" in result: diff --git 
a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index fe4790cbb..cfc8d784f 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -96,6 +96,7 @@ class _WordTimestampEntry: word: str timestamp: float context_id: str + includes_inter_frame_spaces: bool = False class TTSService(AIService): @@ -1049,8 +1050,10 @@ class TTSService(AIService): if self._initial_word_times: cached = self._initial_word_times.copy() self._initial_word_times = [] - for word, timestamp_seconds, ctx_id in cached: - await self._add_word_timestamps([(word, timestamp_seconds)], ctx_id) + for word, timestamp_seconds, ctx_id, ifs in cached: + await self._add_word_timestamps( + [(word, timestamp_seconds)], ctx_id, includes_inter_frame_spaces=ifs + ) async def reset_word_timestamps(self): """Reset word timestamp tracking.""" @@ -1060,7 +1063,10 @@ class TTSService(AIService): self._initial_word_times = [] async def add_word_timestamps( - self, word_times: list[tuple[str, float]], context_id: str | None = None + self, + word_times: list[tuple[str, float]], + context_id: str | None = None, + includes_inter_frame_spaces: bool = False, ): """Add word timestamps for processing. @@ -1072,6 +1078,9 @@ class TTSService(AIService): Args: word_times: List of (word, timestamp) tuples where timestamp is in seconds. context_id: Unique identifier for the TTS context. + includes_inter_frame_spaces: When True, the tokens already embed inter-word + spacing (spaces and punctuation are part of the token text). Downstream + consumers must not inject additional spaces between tokens. 
""" if context_id and self.audio_context_available(context_id): for word, timestamp in word_times: @@ -1081,13 +1090,21 @@ class TTSService(AIService): word=word, timestamp=timestamp, context_id=context_id, + includes_inter_frame_spaces=includes_inter_frame_spaces, ), ) else: - await self._add_word_timestamps(word_times=word_times, context_id=context_id) + await self._add_word_timestamps( + word_times=word_times, + context_id=context_id, + includes_inter_frame_spaces=includes_inter_frame_spaces, + ) async def _add_word_timestamps( - self, word_times: list[tuple[str, float]], context_id: str | None = None + self, + word_times: list[tuple[str, float]], + context_id: str | None = None, + includes_inter_frame_spaces: bool = False, ): """Process word timestamps directly, building and pushing TTSTextFrames inline. @@ -1103,11 +1120,12 @@ class TTSService(AIService): ts_ns = seconds_to_nanoseconds(timestamp) if self._initial_word_timestamp == -1: # Cache until we have audio and can compute PTS. - self._initial_word_times.append((word, timestamp, context_id)) + self._initial_word_times.append( + (word, timestamp, context_id, includes_inter_frame_spaces) + ) else: - # Assumption: word-by-word text frames don't include spaces, so - # we can rely on the default includes_inter_frame_spaces=False frame = TTSTextFrame(word, aggregated_by=AggregationType.WORD) + frame.includes_inter_frame_spaces = includes_inter_frame_spaces frame.pts = self._initial_word_timestamp + ts_ns frame.context_id = context_id if context_id in self._tts_contexts: @@ -1310,7 +1328,9 @@ class TTSService(AIService): # Route word timestamps through _add_word_timestamps so they are # processed in playback order alongside audio frames. 
await self._add_word_timestamps( - [(frame.word, frame.timestamp)], frame.context_id + [(frame.word, frame.timestamp)], + frame.context_id, + includes_inter_frame_spaces=frame.includes_inter_frame_spaces, ) continue elif isinstance(frame, TTSAudioRawFrame): diff --git a/tests/test_tts_frame_ordering.py b/tests/test_tts_frame_ordering.py index bdcc588be..acc1cd2fd 100644 --- a/tests/test_tts_frame_ordering.py +++ b/tests/test_tts_frame_ordering.py @@ -208,6 +208,103 @@ class MockWebSocketPauseTTSService(TTSService): yield +class _MockWordTimestampHttpTTSService(TTSService): + """HTTP-style TTS: yields audio synchronously, calls add_word_timestamps first. + + ``word_times`` pins the exact tokens and their timestamps. When omitted the + service splits the input text on spaces, assigning 0.1 s gaps. + """ + + def __init__( + self, + includes_inter_frame_spaces: bool = False, + word_times: list[tuple[str, float]] | None = None, + **kwargs, + ): + super().__init__( + push_start_frame=True, + push_stop_frames=True, + push_text_frames=False, + sample_rate=_SAMPLE_RATE, + **kwargs, + ) + self._includes_inter_frame_spaces = includes_inter_frame_spaces + self._word_times = word_times + + def can_generate_metrics(self) -> bool: + return False + + async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]: + word_times = self._word_times or [(w, i * 0.1) for i, w in enumerate(text.split())] + await self.add_word_timestamps( + word_times, + context_id=context_id, + includes_inter_frame_spaces=self._includes_inter_frame_spaces, + ) + yield TTSAudioRawFrame( + audio=_FAKE_AUDIO, + sample_rate=_SAMPLE_RATE, + num_channels=1, + context_id=context_id, + ) + + +class _MockWordTimestampWSTTSService(TTSService): + """WebSocket-style TTS: delivers audio asynchronously via the audio context. + + Word timestamps are enqueued as ``_WordTimestampEntry`` items (audio context + already exists at call time) and processed by ``_handle_audio_context`` in + playback order. 
+ + ``word_times`` pins the exact tokens and their timestamps. When omitted the + service splits the input text on spaces, assigning 0.1 s gaps. + """ + + def __init__( + self, + includes_inter_frame_spaces: bool = False, + word_times: list[tuple[str, float]] | None = None, + **kwargs, + ): + super().__init__( + push_start_frame=True, + push_text_frames=False, + pause_frame_processing=False, + sample_rate=_SAMPLE_RATE, + **kwargs, + ) + self._includes_inter_frame_spaces = includes_inter_frame_spaces + self._word_times = word_times + + def can_generate_metrics(self) -> bool: + return False + + async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]: + async def _deliver(): + await asyncio.sleep(0.01) + word_times = self._word_times or [(w, i * 0.1) for i, w in enumerate(text.split())] + await self.add_word_timestamps( + word_times, + context_id=context_id, + includes_inter_frame_spaces=self._includes_inter_frame_spaces, + ) + await self.append_to_audio_context( + context_id, + TTSAudioRawFrame( + audio=_FAKE_AUDIO, + sample_rate=_SAMPLE_RATE, + num_channels=1, + context_id=context_id, + ), + ) + await self.append_to_audio_context(context_id, TTSStoppedFrame(context_id=context_id)) + await self.remove_audio_context(context_id) + + self.create_task(_deliver(), name=f"mock_ws_word_deliver_{context_id}") + if False: + yield + + # --------------------------------------------------------------------------- # Assertion helper # --------------------------------------------------------------------------- @@ -406,5 +503,159 @@ async def test_http_push_text_llm_response_end_after_tts_text(): ) +@pytest.mark.asyncio +async def test_http_word_timestamps_verbatim_tokens(): + """HTTP path: text, PTS order, flag, and text-before-audio are all verified. + + Word timestamps arrive in the audio context queue before the audio frame. 
+ _handle_audio_context caches them, then flushes when the first audio frame + arrives (start_word_timestamps), so TTSTextFrames must be emitted before + the TTSAudioRawFrame in the downstream sequence. + """ + word_times = [("hello", 0.0), ("world", 0.2)] + tts = _MockWordTimestampHttpTTSService( + includes_inter_frame_spaces=True, + word_times=word_times, + ) + frames_received = await run_test( + tts, + frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)], + ) + down = frames_received[0] + tts_text_frames = [f for f in down if isinstance(f, TTSTextFrame)] + audio_frames = [f for f in down if isinstance(f, TTSAudioRawFrame)] + + assert [f.text for f in tts_text_frames] == ["hello", "world"] + assert all(f.includes_inter_frame_spaces is True for f in tts_text_frames) + + pts_values = [f.pts for f in tts_text_frames] + assert pts_values == sorted(pts_values) and len(set(pts_values)) == len(pts_values), ( + f"PTS values must be strictly increasing, got {pts_values}" + ) + + # TTSTextFrames must precede the audio frame (they are flushed from cache + # at the moment the first audio chunk sets the timestamp baseline). + last_text_idx = max(down.index(f) for f in tts_text_frames) + first_audio_idx = down.index(audio_frames[0]) + assert last_text_idx < first_audio_idx, ( + "TTSTextFrames must appear before TTSAudioRawFrame in the downstream sequence" + ) + + +@pytest.mark.asyncio +async def test_http_word_timestamps_punctuation_tokens(): + """Verbatim punctuation tokens are preserved with flag=True; default flag is False. + + Models the Inworld API scenario: the TTS returns tokens exactly as sent. + Space placement rule: + - word-follows-word: space is the leading char of the next word (e.g. " world") + - word-follows-punctuation: space is the trailing char of the punctuation token + (e.g. "! "), so the following word token carries no leading space. + The flag must reach every frame and the text must not be modified. 
+ Also acts as a regression guard that flag=False is the default. + """ + verbatim_tokens = [ + ("hello", 0.0), + (" world", 0.15), + ("! ", 0.3), + ("How", 0.45), + (" are", 0.6), + (" you", 0.75), + ("?", 0.9), + ] + expected_texts = ["hello", " world", "! ", "How", " are", " you", "?"] + + # With flag=True: all tokens verbatim, all frames carry the flag. + tts_ifs = _MockWordTimestampHttpTTSService( + includes_inter_frame_spaces=True, + word_times=verbatim_tokens, + ) + frames_ifs = await run_test( + tts_ifs, + frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)], + ) + text_frames_ifs = [f for f in frames_ifs[0] if isinstance(f, TTSTextFrame)] + assert [f.text for f in text_frames_ifs] == expected_texts, ( + "Verbatim tokens must not be modified" + ) + assert all(f.includes_inter_frame_spaces is True for f in text_frames_ifs) + + # With flag=False (default): same tokens, flag must be False on every frame. + tts_plain = _MockWordTimestampHttpTTSService( + word_times=verbatim_tokens, + ) + frames_plain = await run_test( + tts_plain, + frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)], + ) + text_frames_plain = [f for f in frames_plain[0] if isinstance(f, TTSTextFrame)] + assert [f.text for f in text_frames_plain] == expected_texts + assert all(f.includes_inter_frame_spaces is False for f in text_frames_plain) + + +@pytest.mark.asyncio +async def test_websocket_word_timestamps_verbatim_tokens(): + """WebSocket path: _WordTimestampEntry carries verbatim text, PTS, and flag. + + Unlike the HTTP path the word timestamps are sent asynchronously from a + background task. They arrive before the audio frame and are cached until + start_word_timestamps() fires, so the same text-before-audio ordering + property must hold. 
+ """ + word_times = [("hello", 0.0), ("world", 0.2)] + tts = _MockWordTimestampWSTTSService( + includes_inter_frame_spaces=True, + word_times=word_times, + ) + frames_received = await run_test( + tts, + frames_to_send=[TTSSpeakFrame(text="hello world", append_to_context=False)], + ) + down = frames_received[0] + tts_text_frames = [f for f in down if isinstance(f, TTSTextFrame)] + audio_frames = [f for f in down if isinstance(f, TTSAudioRawFrame)] + + assert [f.text for f in tts_text_frames] == ["hello", "world"] + assert all(f.includes_inter_frame_spaces is True for f in tts_text_frames) + + pts_values = [f.pts for f in tts_text_frames] + assert pts_values == sorted(pts_values) and len(set(pts_values)) == len(pts_values), ( + f"PTS values must be strictly increasing, got {pts_values}" + ) + + last_text_idx = max(down.index(f) for f in tts_text_frames) + first_audio_idx = down.index(audio_frames[0]) + assert last_text_idx < first_audio_idx, ( + "TTSTextFrames must appear before TTSAudioRawFrame in the downstream sequence" + ) + + +@pytest.mark.asyncio +async def test_websocket_word_timestamps_punctuation_tokens(): + """WebSocket path: verbatim punctuation tokens reach TTSTextFrame unchanged.""" + verbatim_tokens = [ + ("hello", 0.0), + (" world", 0.15), + ("! ", 0.3), + ("How", 0.45), + (" are", 0.6), + (" you", 0.75), + ("?", 0.9), + ] + tts = _MockWordTimestampWSTTSService( + includes_inter_frame_spaces=True, + word_times=verbatim_tokens, + ) + frames_received = await run_test( + tts, + frames_to_send=[TTSSpeakFrame(text="hello world! How are you?", append_to_context=False)], + ) + text_frames = [f for f in frames_received[0] if isinstance(f, TTSTextFrame)] + assert [f.text for f in text_frames] == ["hello", " world", "! ", "How", " are", " you", "?"], ( + "Verbatim tokens must not be modified" + ) + assert all(f.includes_inter_frame_spaces is True for f in text_frames) + + if __name__ == "__main__": unittest.main()