aggregators: clear accumulated responses if interruption happens
This commit is contained in:
@@ -13,6 +13,7 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMResponseStartFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
LLMResponseEndFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -40,12 +41,9 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
self._end_frame = end_frame
|
||||
self._accumulator_frame = accumulator_frame
|
||||
self._interim_accumulator_frame = interim_accumulator_frame
|
||||
self._seen_start_frame = False
|
||||
self._seen_end_frame = False
|
||||
self._seen_interim_results = False
|
||||
|
||||
self._aggregation = ""
|
||||
self._aggregating = False
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
|
||||
#
|
||||
# Frame processor
|
||||
@@ -96,6 +94,9 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
self._seen_interim_results = False
|
||||
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
|
||||
self._seen_interim_results = True
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
self._reset()
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -108,12 +109,15 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
frame = LLMMessagesFrame(self._messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset
|
||||
self._aggregation = ""
|
||||
self._aggregating = False
|
||||
self._seen_start_frame = False
|
||||
self._seen_end_frame = False
|
||||
self._seen_interim_results = False
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
|
||||
def _reset(self):
|
||||
self._aggregation = ""
|
||||
self._aggregating = False
|
||||
self._seen_start_frame = False
|
||||
self._seen_end_frame = False
|
||||
self._seen_interim_results = False
|
||||
|
||||
|
||||
class LLMAssistantResponseAggregator(LLMResponseAggregator):
|
||||
|
||||
@@ -8,6 +8,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
@@ -56,12 +57,9 @@ class ResponseAggregator(FrameProcessor):
|
||||
self._end_frame = end_frame
|
||||
self._accumulator_frame = accumulator_frame
|
||||
self._interim_accumulator_frame = interim_accumulator_frame
|
||||
self._seen_start_frame = False
|
||||
self._seen_end_frame = False
|
||||
self._seen_interim_results = False
|
||||
|
||||
self._aggregation = ""
|
||||
self._aggregating = False
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
|
||||
#
|
||||
# Frame processor
|
||||
@@ -112,6 +110,9 @@ class ResponseAggregator(FrameProcessor):
|
||||
self._seen_interim_results = False
|
||||
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
|
||||
self._seen_interim_results = True
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
self._reset()
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -122,12 +123,15 @@ class ResponseAggregator(FrameProcessor):
|
||||
if len(self._aggregation) > 0:
|
||||
await self.push_frame(TextFrame(self._aggregation.strip()))
|
||||
|
||||
# Reset
|
||||
self._aggregation = ""
|
||||
self._aggregating = False
|
||||
self._seen_start_frame = False
|
||||
self._seen_end_frame = False
|
||||
self._seen_interim_results = False
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
|
||||
def _reset(self):
|
||||
self._aggregation = ""
|
||||
self._aggregating = False
|
||||
self._seen_start_frame = False
|
||||
self._seen_end_frame = False
|
||||
self._seen_interim_results = False
|
||||
|
||||
|
||||
class UserResponseAggregator(ResponseAggregator):
|
||||
|
||||
Reference in New Issue
Block a user