Merge pull request #1301 from pipecat-ai/aleix/example-22d-fix-llm-aggregator
This commit is contained in:
@@ -23,7 +23,6 @@ from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
@@ -37,7 +36,7 @@ from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import LLMResponseAggregator
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantResponseAggregator
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
@@ -432,7 +431,11 @@ class CompletenessCheck(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
if isinstance(frame, (EndFrame, CancelFrame)):
|
||||
if self._idle_task:
|
||||
await self.cancel_task(self._idle_task)
|
||||
self._idle_task = None
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
if self._idle_task:
|
||||
await self.cancel_task(self._idle_task)
|
||||
elif isinstance(frame, TextFrame) and frame.text.startswith("YES"):
|
||||
@@ -474,19 +477,11 @@ class CompletenessCheck(FrameProcessor):
|
||||
self._idle_task = None
|
||||
|
||||
|
||||
class UserAggregatorBuffer(LLMResponseAggregator):
|
||||
class LLMAggregatorBuffer(LLMAssistantResponseAggregator):
|
||||
"""Buffers the output of the transcription LLM. Used by the bot output gate."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(
|
||||
messages=None,
|
||||
role=None,
|
||||
start_frame=LLMFullResponseStartFrame,
|
||||
end_frame=LLMFullResponseEndFrame,
|
||||
accumulator_frame=TextFrame,
|
||||
handle_interruptions=True,
|
||||
expect_stripped_words=False,
|
||||
)
|
||||
super().__init__(expect_stripped_words=False)
|
||||
self._transcription = ""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -544,7 +539,7 @@ class OutputGate(FrameProcessor):
|
||||
self,
|
||||
notifier: BaseNotifier,
|
||||
context: OpenAILLMContext,
|
||||
user_transcription_buffer: "UserAggregatorBuffer",
|
||||
llm_transcription_buffer: LLMAggregatorBuffer,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@@ -552,7 +547,7 @@ class OutputGate(FrameProcessor):
|
||||
self._frames_buffer = []
|
||||
self._notifier = notifier
|
||||
self._context = context
|
||||
self._transcription_buffer = user_transcription_buffer
|
||||
self._transcription_buffer = llm_transcription_buffer
|
||||
self._gate_task = None
|
||||
|
||||
def close_gate(self):
|
||||
@@ -699,10 +694,10 @@ async def main():
|
||||
|
||||
conversation_audio_context_assembler = ConversationAudioContextAssembler(context=context)
|
||||
|
||||
user_aggregator_buffer = UserAggregatorBuffer()
|
||||
llm_aggregator_buffer = LLMAggregatorBuffer()
|
||||
|
||||
bot_output_gate = OutputGate(
|
||||
notifier=notifier, context=context, user_transcription_buffer=user_aggregator_buffer
|
||||
notifier=notifier, context=context, llm_transcription_buffer=llm_aggregator_buffer
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
@@ -723,7 +718,7 @@ async def main():
|
||||
],
|
||||
[
|
||||
tx_llm,
|
||||
user_aggregator_buffer,
|
||||
llm_aggregator_buffer,
|
||||
],
|
||||
)
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user