aggregators: clear accumulated responses if interruption happens

This commit is contained in:
Aleix Conchillo Flaqué
2024-05-19 10:20:17 -07:00
parent c0d5054798
commit c3bfcbd562
2 changed files with 30 additions and 22 deletions

View File

@@ -13,6 +13,7 @@ from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMMessagesFrame,
LLMResponseStartFrame,
StartInterruptionFrame,
TextFrame,
LLMResponseEndFrame,
TranscriptionFrame,
@@ -40,12 +41,9 @@ class LLMResponseAggregator(FrameProcessor):
self._end_frame = end_frame
self._accumulator_frame = accumulator_frame
self._interim_accumulator_frame = interim_accumulator_frame
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False
self._aggregation = ""
self._aggregating = False
# Reset our accumulator state.
self._reset()
#
# Frame processor
@@ -96,6 +94,9 @@ class LLMResponseAggregator(FrameProcessor):
self._seen_interim_results = False
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
self._seen_interim_results = True
elif isinstance(frame, StartInterruptionFrame):
self._reset()
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)
@@ -108,12 +109,15 @@ class LLMResponseAggregator(FrameProcessor):
frame = LLMMessagesFrame(self._messages)
await self.push_frame(frame)
# Reset
self._aggregation = ""
self._aggregating = False
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False
# Reset our accumulator state.
self._reset()
def _reset(self):
self._aggregation = ""
self._aggregating = False
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False
class LLMAssistantResponseAggregator(LLMResponseAggregator):

View File

@@ -8,6 +8,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.frames.frames import (
Frame,
InterimTranscriptionFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
@@ -56,12 +57,9 @@ class ResponseAggregator(FrameProcessor):
self._end_frame = end_frame
self._accumulator_frame = accumulator_frame
self._interim_accumulator_frame = interim_accumulator_frame
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False
self._aggregation = ""
self._aggregating = False
# Reset our accumulator state.
self._reset()
#
# Frame processor
@@ -112,6 +110,9 @@ class ResponseAggregator(FrameProcessor):
self._seen_interim_results = False
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
self._seen_interim_results = True
elif isinstance(frame, StartInterruptionFrame):
self._reset()
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)
@@ -122,12 +123,15 @@ class ResponseAggregator(FrameProcessor):
if len(self._aggregation) > 0:
await self.push_frame(TextFrame(self._aggregation.strip()))
# Reset
self._aggregation = ""
self._aggregating = False
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False
# Reset our accumulator state.
self._reset()
def _reset(self):
self._aggregation = ""
self._aggregating = False
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False
class UserResponseAggregator(ResponseAggregator):