Merge pull request #1301 from pipecat-ai/aleix/example-22d-fix-llm-aggregator

This commit is contained in:
Aleix Conchillo Flaqué
2025-02-26 22:39:48 -08:00
committed by GitHub

View File

@@ -23,7 +23,6 @@ from pipecat.frames.frames import (
FunctionCallInProgressFrame,
FunctionCallResultFrame,
InputAudioRawFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
StartFrame,
StartInterruptionFrame,
@@ -37,7 +36,7 @@ from pipecat.pipeline.parallel_pipeline import ParallelPipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_response import LLMResponseAggregator
from pipecat.processors.aggregators.llm_response import LLMAssistantResponseAggregator
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
@@ -432,7 +431,11 @@ class CompletenessCheck(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, UserStartedSpeakingFrame):
if isinstance(frame, (EndFrame, CancelFrame)):
if self._idle_task:
await self.cancel_task(self._idle_task)
self._idle_task = None
elif isinstance(frame, UserStartedSpeakingFrame):
if self._idle_task:
await self.cancel_task(self._idle_task)
elif isinstance(frame, TextFrame) and frame.text.startswith("YES"):
@@ -474,19 +477,11 @@ class CompletenessCheck(FrameProcessor):
self._idle_task = None
class UserAggregatorBuffer(LLMResponseAggregator):
class LLMAggregatorBuffer(LLMAssistantResponseAggregator):
"""Buffers the output of the transcription LLM. Used by the bot output gate."""
def __init__(self, **kwargs):
super().__init__(
messages=None,
role=None,
start_frame=LLMFullResponseStartFrame,
end_frame=LLMFullResponseEndFrame,
accumulator_frame=TextFrame,
handle_interruptions=True,
expect_stripped_words=False,
)
super().__init__(expect_stripped_words=False)
self._transcription = ""
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -544,7 +539,7 @@ class OutputGate(FrameProcessor):
self,
notifier: BaseNotifier,
context: OpenAILLMContext,
user_transcription_buffer: "UserAggregatorBuffer",
llm_transcription_buffer: LLMAggregatorBuffer,
**kwargs,
):
super().__init__(**kwargs)
@@ -552,7 +547,7 @@ class OutputGate(FrameProcessor):
self._frames_buffer = []
self._notifier = notifier
self._context = context
self._transcription_buffer = user_transcription_buffer
self._transcription_buffer = llm_transcription_buffer
self._gate_task = None
def close_gate(self):
@@ -699,10 +694,10 @@ async def main():
conversation_audio_context_assembler = ConversationAudioContextAssembler(context=context)
user_aggregator_buffer = UserAggregatorBuffer()
llm_aggregator_buffer = LLMAggregatorBuffer()
bot_output_gate = OutputGate(
notifier=notifier, context=context, user_transcription_buffer=user_aggregator_buffer
notifier=notifier, context=context, llm_transcription_buffer=llm_aggregator_buffer
)
pipeline = Pipeline(
@@ -723,7 +718,7 @@ async def main():
],
[
tx_llm,
user_aggregator_buffer,
llm_aggregator_buffer,
],
)
],