From e3c965f4d50ca667b0f2e20be463641ac30d6987 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Mon, 28 Apr 2025 11:02:09 -0700 Subject: [PATCH] TTSService: do not push LLMFullResponseEndFrame if not needed --- CHANGELOG.md | 3 +++ src/pipecat/services/cartesia/tts.py | 4 +--- src/pipecat/services/elevenlabs/tts.py | 4 ++-- src/pipecat/services/rime/tts.py | 2 +- src/pipecat/services/tts_service.py | 16 +++++++++++----- 5 files changed, 18 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c14be235..920c944b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -71,6 +71,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed a TTS services issue that could cause assistant output not to be + aggregated to the context when also using `TTSSpeakFrame`s. + - Fixed an issue where the `SmartTurnMetricsData` was reporting 0ms for inference and processing time when using the `FalSmartTurnAnalyzer`. diff --git a/src/pipecat/services/cartesia/tts.py b/src/pipecat/services/cartesia/tts.py index 2acabd541..3f7df063b 100644 --- a/src/pipecat/services/cartesia/tts.py +++ b/src/pipecat/services/cartesia/tts.py @@ -250,9 +250,7 @@ class CartesiaTTSService(AudioContextWordTTSService): continue if msg["type"] == "done": await self.stop_ttfb_metrics() - await self.add_word_timestamps( - [("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)] - ) + await self.add_word_timestamps([("TTSStoppedFrame", 0), ("Reset", 0)]) await self.remove_audio_context(msg["context_id"]) elif msg["type"] == "timestamps": await self.add_word_timestamps( diff --git a/src/pipecat/services/elevenlabs/tts.py b/src/pipecat/services/elevenlabs/tts.py index b9a22b388..4362fcdc9 100644 --- a/src/pipecat/services/elevenlabs/tts.py +++ b/src/pipecat/services/elevenlabs/tts.py @@ -287,7 +287,7 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)): self._started = False if isinstance(frame, TTSStoppedFrame): - await self.add_word_timestamps([("LLMFullResponseEndFrame", 0), ("Reset", 0)]) + await self.add_word_timestamps([("Reset", 0)]) async def _connect(self): await self._connect_websocket() @@ -526,7 +526,7 @@ class ElevenLabsHttpTTSService(WordTTSService): self._reset_state() if isinstance(frame, TTSStoppedFrame): - await self.add_word_timestamps([("LLMFullResponseEndFrame", 0), ("Reset", 0)]) + await self.add_word_timestamps([("Reset", 0)]) elif isinstance(frame, LLMFullResponseEndFrame): # End of turn - reset previous text diff --git a/src/pipecat/services/rime/tts.py b/src/pipecat/services/rime/tts.py index d3e8e5adb..83fd3fa62 100644 --- a/src/pipecat/services/rime/tts.py +++ b/src/pipecat/services/rime/tts.py @@ -304,7 +304,7 @@ class RimeTTSService(AudioContextWordTTSService): await super().push_frame(frame, direction) if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)): if isinstance(frame, TTSStoppedFrame): - await self.add_word_timestamps([("LLMFullResponseEndFrame", 0), ("Reset", 0)]) + await self.add_word_timestamps([("Reset", 0)]) async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: """Generate speech from text. diff --git a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index 1da3c6814..a2b8e3f90 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -19,6 +19,7 @@ from pipecat.frames.frames import ( Frame, InterimTranscriptionFrame, LLMFullResponseEndFrame, + LLMFullResponseStartFrame, StartFrame, StartInterruptionFrame, TextFrame, @@ -308,6 +309,7 @@ class WordTTSService(TTSService): self._initial_word_timestamp = -1 self._words_queue = asyncio.Queue() self._words_task = None + self._llm_response_started: bool = False def start_word_timestamps(self): if self._initial_word_timestamp == -1: @@ -335,11 +337,14 @@ class WordTTSService(TTSService): async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) - if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)): + if isinstance(frame, LLMFullResponseStartFrame): + self._llm_response_started = True + elif isinstance(frame, (LLMFullResponseEndFrame, EndFrame)): await self.flush_audio() async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection): await super()._handle_interruption(frame, direction) + self._llm_response_started = False self.reset_word_timestamps() def _create_words_task(self): @@ -354,13 +359,14 @@ class WordTTSService(TTSService): async def _words_task_handler(self): last_pts = 0 while True: + frame = None (word, timestamp) = await self._words_queue.get() if word == "Reset" and timestamp == 0: self.reset_word_timestamps() - frame = None - elif word == "LLMFullResponseEndFrame" and timestamp == 0: - frame = LLMFullResponseEndFrame() - frame.pts = last_pts + if self._llm_response_started: + self._llm_response_started = False + frame = LLMFullResponseEndFrame() + frame.pts = last_pts elif word == "TTSStoppedFrame" and timestamp == 0: frame = TTSStoppedFrame() frame.pts = last_pts