Compare commits
1 Commits
hush/daily
...
aleix/llm-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bff5b3e562 |
@@ -10,14 +10,13 @@ from abc import abstractmethod
|
||||
from typing import List
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EmulateUserStartedSpeakingFrame,
|
||||
EmulateUserStoppedSpeakingFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
@@ -26,6 +25,7 @@ from pipecat.frames.frames import (
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
TTSTextFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -352,8 +352,8 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
|
||||
class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
"""This is an assistant LLM aggregator that uses an LLM context to store the
|
||||
conversation. It aggregates text frames received between
|
||||
`LLMFullResponseStartFrame` and `LLMFullResponseEndFrame`.
|
||||
conversation. It aggregates text frames spoken by the TTS service and pushes
|
||||
the context when the bot stops speaking..
|
||||
|
||||
"""
|
||||
|
||||
@@ -361,8 +361,6 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
super().__init__(context=context, role="assistant", **kwargs)
|
||||
self._expect_stripped_words = expect_stripped_words
|
||||
|
||||
self._started = False
|
||||
|
||||
self.reset()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -373,11 +371,10 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
# Reset anyways
|
||||
self.reset()
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
await self._handle_llm_start(frame)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self._handle_llm_end(frame)
|
||||
elif isinstance(frame, TextFrame):
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self._handle_bot_stopped_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TTSTextFrame):
|
||||
await self._handle_text(frame)
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
self.add_messages(frame.messages)
|
||||
@@ -388,17 +385,10 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
|
||||
self._started = True
|
||||
|
||||
async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
|
||||
self._started = False
|
||||
async def _handle_bot_stopped_speaking(self, _: BotStoppedSpeakingFrame):
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
if self._expect_stripped_words:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
else:
|
||||
|
||||
@@ -9,15 +9,14 @@ import unittest
|
||||
import google.ai.generativelanguage as glm
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
EmulateUserStartedSpeakingFrame,
|
||||
EmulateUserStoppedSpeakingFrame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
TTSTextFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -428,20 +427,6 @@ class BaseTestAssistantContextAggreagator:
|
||||
):
|
||||
assert context.messages[index]["content"] == content
|
||||
|
||||
async def test_empty(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()]
|
||||
expected_down_frames = []
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
async def test_single_text(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
@@ -449,11 +434,11 @@ class BaseTestAssistantContextAggreagator:
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello Pipecat!"),
|
||||
LLMFullResponseEndFrame(),
|
||||
TTSTextFrame(text="Hello Pipecat!"),
|
||||
SleepFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
expected_down_frames = [BotStoppedSpeakingFrame, *self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -468,14 +453,14 @@ class BaseTestAssistantContextAggreagator:
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello "),
|
||||
TextFrame(text="Pipecat. "),
|
||||
TextFrame(text="How are "),
|
||||
TextFrame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
TTSTextFrame(text="Hello "),
|
||||
TTSTextFrame(text="Pipecat. "),
|
||||
TTSTextFrame(text="How are "),
|
||||
TTSTextFrame(text="you?"),
|
||||
SleepFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
expected_down_frames = [BotStoppedSpeakingFrame, *self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -490,14 +475,14 @@ class BaseTestAssistantContextAggreagator:
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello"),
|
||||
TextFrame(text="Pipecat."),
|
||||
TextFrame(text="How are"),
|
||||
TextFrame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
TTSTextFrame(text="Hello"),
|
||||
TTSTextFrame(text="Pipecat."),
|
||||
TTSTextFrame(text="How are"),
|
||||
TTSTextFrame(text="you?"),
|
||||
SleepFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
expected_down_frames = [BotStoppedSpeakingFrame, *self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -512,16 +497,21 @@ class BaseTestAssistantContextAggreagator:
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello "),
|
||||
TextFrame(text="Pipecat."),
|
||||
LLMFullResponseEndFrame(),
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="How are "),
|
||||
TextFrame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
TTSTextFrame(text="Hello "),
|
||||
TTSTextFrame(text="Pipecat."),
|
||||
SleepFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
TTSTextFrame(text="How are "),
|
||||
TTSTextFrame(text="you?"),
|
||||
SleepFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
BotStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
BotStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES, *self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -537,20 +527,22 @@ class BaseTestAssistantContextAggreagator:
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello "),
|
||||
TextFrame(text="Pipecat."),
|
||||
LLMFullResponseEndFrame(),
|
||||
TTSTextFrame(text="Hello "),
|
||||
TTSTextFrame(text="Pipecat."),
|
||||
SleepFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(AGGREGATION_SLEEP),
|
||||
StartInterruptionFrame(),
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="How are "),
|
||||
TextFrame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
TTSTextFrame(text="How are "),
|
||||
TTSTextFrame(text="you?"),
|
||||
SleepFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
BotStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
StartInterruptionFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
|
||||
Reference in New Issue
Block a user