Compare commits

...

1 Commits

Author SHA1 Message Date
Aleix Conchillo Flaqué
bff5b3e562 LLMAssistantContextAggregator: always aggregate TTSTextFrame
This is the assistant aggregator, so it should add everything that is being
spoken. For that reason we should use TTSTextFrame, since those are the frames
that are actually spoken. We should send the aggregation as soon as the bot
stops speaking.

So, we don't need to handle `LLMFullResponseStartFrame` and
`LLMFullResponseEndFrame` anymore.
2025-02-20 17:41:31 -08:00
2 changed files with 53 additions and 71 deletions

View File

@@ -10,14 +10,13 @@ from abc import abstractmethod
from typing import List
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EmulateUserStartedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
EndFrame,
Frame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
LLMMessagesFrame,
LLMMessagesUpdateFrame,
@@ -26,6 +25,7 @@ from pipecat.frames.frames import (
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
TTSTextFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
@@ -352,8 +352,8 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
class LLMAssistantContextAggregator(LLMContextResponseAggregator):
"""This is an assistant LLM aggregator that uses an LLM context to store the
conversation. It aggregates text frames received between
`LLMFullResponseStartFrame` and `LLMFullResponseEndFrame`.
conversation. It aggregates text frames spoken by the TTS service and pushes
the context when the bot stops speaking.
"""
@@ -361,8 +361,6 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
super().__init__(context=context, role="assistant", **kwargs)
self._expect_stripped_words = expect_stripped_words
self._started = False
self.reset()
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -373,11 +371,10 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
# Reset anyways
self.reset()
await self.push_frame(frame, direction)
elif isinstance(frame, LLMFullResponseStartFrame):
await self._handle_llm_start(frame)
elif isinstance(frame, LLMFullResponseEndFrame):
await self._handle_llm_end(frame)
elif isinstance(frame, TextFrame):
elif isinstance(frame, BotStoppedSpeakingFrame):
await self._handle_bot_stopped_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, TTSTextFrame):
await self._handle_text(frame)
elif isinstance(frame, LLMMessagesAppendFrame):
self.add_messages(frame.messages)
@@ -388,17 +385,10 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
else:
await self.push_frame(frame, direction)
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
self._started = True
async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
self._started = False
async def _handle_bot_stopped_speaking(self, _: BotStoppedSpeakingFrame):
await self.push_aggregation()
async def _handle_text(self, frame: TextFrame):
if not self._started:
return
if self._expect_stripped_words:
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
else:

View File

@@ -9,15 +9,14 @@ import unittest
import google.ai.generativelanguage as glm
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
EmulateUserStartedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
OpenAILLMContextAssistantTimestampFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
TTSTextFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
@@ -428,20 +427,6 @@ class BaseTestAssistantContextAggreagator:
):
assert context.messages[index]["content"] == content
async def test_empty(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()]
expected_down_frames = []
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
async def test_single_text(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
@@ -449,11 +434,11 @@ class BaseTestAssistantContextAggreagator:
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
LLMFullResponseStartFrame(),
TextFrame(text="Hello Pipecat!"),
LLMFullResponseEndFrame(),
TTSTextFrame(text="Hello Pipecat!"),
SleepFrame(),
BotStoppedSpeakingFrame(),
]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
expected_down_frames = [BotStoppedSpeakingFrame, *self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
@@ -468,14 +453,14 @@ class BaseTestAssistantContextAggreagator:
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
frames_to_send = [
LLMFullResponseStartFrame(),
TextFrame(text="Hello "),
TextFrame(text="Pipecat. "),
TextFrame(text="How are "),
TextFrame(text="you?"),
LLMFullResponseEndFrame(),
TTSTextFrame(text="Hello "),
TTSTextFrame(text="Pipecat. "),
TTSTextFrame(text="How are "),
TTSTextFrame(text="you?"),
SleepFrame(),
BotStoppedSpeakingFrame(),
]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
expected_down_frames = [BotStoppedSpeakingFrame, *self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
@@ -490,14 +475,14 @@ class BaseTestAssistantContextAggreagator:
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
LLMFullResponseStartFrame(),
TextFrame(text="Hello"),
TextFrame(text="Pipecat."),
TextFrame(text="How are"),
TextFrame(text="you?"),
LLMFullResponseEndFrame(),
TTSTextFrame(text="Hello"),
TTSTextFrame(text="Pipecat."),
TTSTextFrame(text="How are"),
TTSTextFrame(text="you?"),
SleepFrame(),
BotStoppedSpeakingFrame(),
]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
expected_down_frames = [BotStoppedSpeakingFrame, *self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
@@ -512,16 +497,21 @@ class BaseTestAssistantContextAggreagator:
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
frames_to_send = [
LLMFullResponseStartFrame(),
TextFrame(text="Hello "),
TextFrame(text="Pipecat."),
LLMFullResponseEndFrame(),
LLMFullResponseStartFrame(),
TextFrame(text="How are "),
TextFrame(text="you?"),
LLMFullResponseEndFrame(),
TTSTextFrame(text="Hello "),
TTSTextFrame(text="Pipecat."),
SleepFrame(),
BotStoppedSpeakingFrame(),
TTSTextFrame(text="How are "),
TTSTextFrame(text="you?"),
SleepFrame(),
BotStoppedSpeakingFrame(),
]
expected_down_frames = [
BotStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
BotStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES, *self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
@@ -537,20 +527,22 @@ class BaseTestAssistantContextAggreagator:
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
frames_to_send = [
LLMFullResponseStartFrame(),
TextFrame(text="Hello "),
TextFrame(text="Pipecat."),
LLMFullResponseEndFrame(),
TTSTextFrame(text="Hello "),
TTSTextFrame(text="Pipecat."),
SleepFrame(),
BotStoppedSpeakingFrame(),
SleepFrame(AGGREGATION_SLEEP),
StartInterruptionFrame(),
LLMFullResponseStartFrame(),
TextFrame(text="How are "),
TextFrame(text="you?"),
LLMFullResponseEndFrame(),
TTSTextFrame(text="How are "),
TTSTextFrame(text="you?"),
SleepFrame(),
BotStoppedSpeakingFrame(),
]
expected_down_frames = [
BotStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
StartInterruptionFrame,
BotStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(