From d58f398bc4df57a15e4be6dd89368ffd08fb5f18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 26 Feb 2025 13:15:07 -0800 Subject: [PATCH] examples: fix for 22d-natural-conversation-gemini-audio.py --- .../22d-natural-conversation-gemini-audio.py | 31 ++++++++----------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/examples/foundational/22d-natural-conversation-gemini-audio.py b/examples/foundational/22d-natural-conversation-gemini-audio.py index 6313f951d..0041ff530 100644 --- a/examples/foundational/22d-natural-conversation-gemini-audio.py +++ b/examples/foundational/22d-natural-conversation-gemini-audio.py @@ -23,7 +23,6 @@ from pipecat.frames.frames import ( FunctionCallInProgressFrame, FunctionCallResultFrame, InputAudioRawFrame, - LLMFullResponseEndFrame, LLMFullResponseStartFrame, StartFrame, StartInterruptionFrame, @@ -37,7 +36,7 @@ from pipecat.pipeline.parallel_pipeline import ParallelPipeline from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask -from pipecat.processors.aggregators.llm_response import LLMResponseAggregator +from pipecat.processors.aggregators.llm_response import LLMAssistantResponseAggregator from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, @@ -432,7 +431,11 @@ class CompletenessCheck(FrameProcessor): async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) - if isinstance(frame, UserStartedSpeakingFrame): + if isinstance(frame, (EndFrame, CancelFrame)): + if self._idle_task: + await self.cancel_task(self._idle_task) + self._idle_task = None + elif isinstance(frame, UserStartedSpeakingFrame): if self._idle_task: await self.cancel_task(self._idle_task) elif isinstance(frame, TextFrame) and frame.text.startswith("YES"): @@ -474,19 +477,11 @@ class CompletenessCheck(FrameProcessor): self._idle_task = None -class UserAggregatorBuffer(LLMResponseAggregator): +class LLMAggregatorBuffer(LLMAssistantResponseAggregator): """Buffers the output of the transcription LLM. Used by the bot output gate.""" def __init__(self, **kwargs): - super().__init__( - messages=None, - role=None, - start_frame=LLMFullResponseStartFrame, - end_frame=LLMFullResponseEndFrame, - accumulator_frame=TextFrame, - handle_interruptions=True, - expect_stripped_words=False, - ) + super().__init__(expect_stripped_words=False) self._transcription = "" async def process_frame(self, frame: Frame, direction: FrameDirection): @@ -544,7 +539,7 @@ class OutputGate(FrameProcessor): self, notifier: BaseNotifier, context: OpenAILLMContext, - user_transcription_buffer: "UserAggregatorBuffer", + llm_transcription_buffer: LLMAggregatorBuffer, **kwargs, ): super().__init__(**kwargs) @@ -552,7 +547,7 @@ class OutputGate(FrameProcessor): self._frames_buffer = [] self._notifier = notifier self._context = context - self._transcription_buffer = user_transcription_buffer + self._transcription_buffer = llm_transcription_buffer self._gate_task = None def close_gate(self): @@ -699,10 +694,10 @@ async def main(): conversation_audio_context_assembler = ConversationAudioContextAssembler(context=context) - user_aggregator_buffer = UserAggregatorBuffer() + llm_aggregator_buffer = LLMAggregatorBuffer() bot_output_gate = OutputGate( - notifier=notifier, context=context, user_transcription_buffer=user_aggregator_buffer + notifier=notifier, context=context, llm_transcription_buffer=llm_aggregator_buffer ) pipeline = Pipeline( @@ -723,7 +718,7 @@ async def main(): ], [ tx_llm, - user_aggregator_buffer, + llm_aggregator_buffer, ], ) ],