From 685f951ae218dfd7756229631e997dc09ef8951b Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Tue, 29 Apr 2025 15:40:18 -0400 Subject: [PATCH] Fix: GeminiMultimodalLiveLLMService was appending tokens to the context --- CHANGELOG.md | 3 ++ .../processors/aggregators/llm_response.py | 5 +-- .../services/gemini_multimodal_live/gemini.py | 34 +++++++------------ 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 08e9d30ed..c3f8fc3fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -80,6 +80,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed an issue with `GeminiMultimodalLiveLLMService` where the context + contained tokens instead of words. + - Fixed an issue with HTTP Smart Turn handling, where the service returns a 500 error. Previously, this would cause an unhandled exception. Now, a 500 error is treated as an incomplete response. diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index fa163ac0a..5d8b6a32c 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -36,6 +36,7 @@ from pipecat.frames.frames import ( StartInterruptionFrame, TextFrame, TranscriptionFrame, + TTSTextFrame, UserImageRawFrame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame, @@ -493,7 +494,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator): await self._handle_llm_start(frame) elif isinstance(frame, LLMFullResponseEndFrame): await self._handle_llm_end(frame) - elif isinstance(frame, TextFrame): + elif isinstance(frame, TTSTextFrame): await self._handle_text(frame) elif isinstance(frame, LLMMessagesAppendFrame): self.add_messages(frame.messages) @@ -620,7 +621,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator): self._started -= 1 await self.push_aggregation() - async def _handle_text(self, frame: TextFrame): + async def _handle_text(self, frame: TTSTextFrame): if not self._started: return diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py index 3881f7c7e..79cc556f5 100644 --- a/src/pipecat/services/gemini_multimodal_live/gemini.py +++ b/src/pipecat/services/gemini_multimodal_live/gemini.py @@ -344,7 +344,6 @@ class GeminiMultimodalLiveLLMService(LLMService): self._bot_is_speaking = False self._user_audio_buffer = bytearray() self._bot_audio_buffer = bytearray() - self._bot_text_buffer = "" self._sample_rate = 24000 @@ -427,7 +426,9 @@ class GeminiMultimodalLiveLLMService(LLMService): # async def _handle_interruption(self): - pass + self._bot_is_speaking = False + await self.push_frame(TTSStoppedFrame()) + await self.push_frame(LLMFullResponseEndFrame()) async def _handle_user_started_speaking(self, frame): self._user_is_speaking = True @@ -839,14 +840,6 @@ class GeminiMultimodalLiveLLMService(LLMService): if not part: return - text = part.text - if text: - if not self._bot_text_buffer: - await self.push_frame(LLMFullResponseStartFrame()) - - self._bot_text_buffer += text - await self.push_frame(LLMTextFrame(text=text)) - inline_data = part.inlineData if not inline_data: return @@ -861,6 +854,7 @@ class GeminiMultimodalLiveLLMService(LLMService): if not self._bot_is_speaking: self._bot_is_speaking = True await self.push_frame(TTSStartedFrame()) + await self.push_frame(LLMFullResponseStartFrame()) self._bot_audio_buffer.extend(audio) frame = TTSAudioRawFrame( @@ -886,24 +880,20 @@ class GeminiMultimodalLiveLLMService(LLMService): async def _handle_evt_turn_complete(self, evt): self._bot_is_speaking = False - text = self._bot_text_buffer - self._bot_text_buffer = "" - - if text: - await self.push_frame(LLMFullResponseEndFrame()) - await self.push_frame(TTSStoppedFrame()) + await self.push_frame(LLMFullResponseEndFrame()) async def _handle_evt_output_transcription(self, evt): if not evt.serverContent.outputTranscription: return text = evt.serverContent.outputTranscription.text - if text: - await self.push_frame(LLMFullResponseStartFrame()) - await self.push_frame(LLMTextFrame(text=text)) - await self.push_frame(TTSTextFrame(text=text)) - await self.push_frame(LLMFullResponseEndFrame()) + + if not text: + return + + await self.push_frame(LLMTextFrame(text=text)) + await self.push_frame(TTSTextFrame(text=text)) def create_context_aggregator( self, @@ -934,6 +924,6 @@ class GeminiMultimodalLiveLLMService(LLMService): GeminiMultimodalLiveContext.upgrade(context) user = GeminiMultimodalLiveUserContextAggregator(context, params=user_params) - assistant_params.expect_stripped_words = True + assistant_params.expect_stripped_words = False assistant = GeminiMultimodalLiveAssistantContextAggregator(context, params=assistant_params) return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant)