Merge pull request #1209 from pipecat-ai/aleix/reimplement-llm-response-aggregators
reimplement LLM response aggregators
This commit is contained in:
47
CHANGELOG.md
47
CHANGELOG.md
@@ -9,14 +9,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added new frames `EmulateUserStartedSpeakingFrame` and
|
||||
`EmulateUserStoppedSpeakingFrame` which can be used to emulated VAD behavior
|
||||
without VAD being present or not being triggered.
|
||||
|
||||
- Added a new `audio_in_stream_on_start` field to `TransportParams`.
|
||||
|
||||
- Added a new method `start_audio_in_streaming` in the `BaseInputTransport`.
|
||||
- This method should be used to start receiving the input audio in case the field `audio_in_stream_on_start` is set to `false`.
|
||||
|
||||
- Added support for the `RTVIProcessor` to handle buffered audio in `base64` format, converting it into InputAudioRawFrame for transport.
|
||||
- Added a new method `start_audio_in_streaming` in the `BaseInputTransport`.
|
||||
|
||||
- Added support for the `RTVIProcessor` to trigger `start_audio_in_streaming` only after the `client-ready` message.
|
||||
- This method should be used to start receiving the input audio in case the
|
||||
field `audio_in_stream_on_start` is set to `false`.
|
||||
|
||||
- Added support for the `RTVIProcessor` to handle buffered audio in `base64`
|
||||
format, converting it into InputAudioRawFrame for transport.
|
||||
|
||||
- Added support for the `RTVIProcessor` to trigger `start_audio_in_streaming`
|
||||
only after the `client-ready` message.
|
||||
|
||||
- Added new `MUTE_UNTIL_FIRST_BOT_COMPLETE` strategy to `STTMuteStrategy`. This
|
||||
strategy starts muted and remains muted until the first bot speech completes,
|
||||
@@ -38,18 +46,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
OpenAI-compatible interface. Also, added foundational example
|
||||
`14n-function-calling-perplexity.py`.
|
||||
|
||||
- Added `DailyTransport.update_remote_participants()`. This allows you to
|
||||
update remote participant's settings, like their permissions or which of
|
||||
their devices are enabled. Requires that the local participant have
|
||||
participant admin permission.
|
||||
- Added `DailyTransport.update_remote_participants()`. This allows you to update
|
||||
remote participant's settings, like their permissions or which of their
|
||||
devices are enabled. Requires that the local participant have participant
|
||||
admin permission.
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated `DailyTransport` to respect the `audio_in_stream_on_start` field, ensuring it only starts receiving the audio input if it is enabled.
|
||||
- Updated `DailyTransport` to respect the `audio_in_stream_on_start` field,
|
||||
ensuring it only starts receiving the audio input if it is enabled.
|
||||
|
||||
- Updated `FastAPIWebsocketOutputTransport` to send `TransportMessageFrame` and `TransportMessageUrgentFrame` to the serializer.
|
||||
- Updated `FastAPIWebsocketOutputTransport` to send `TransportMessageFrame` and
|
||||
`TransportMessageUrgentFrame` to the serializer.
|
||||
|
||||
- Updated `WebsocketServerOutputTransport` to send `TransportMessageFrame` and `TransportMessageUrgentFrame` to the serializer.
|
||||
- Updated `WebsocketServerOutputTransport` to send `TransportMessageFrame` and
|
||||
`TransportMessageUrgentFrame` to the serializer.
|
||||
|
||||
- Enhanced `STTMuteConfig` to validate strategy combinations, preventing
|
||||
`MUTE_UNTIL_FIRST_BOT_COMPLETE` and `FIRST_SPEECH` from being used together
|
||||
@@ -91,6 +102,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a `BaseOutputTransport` issue that was causing upstream frames to no be
|
||||
pushed upstream.
|
||||
|
||||
- Fixed multiple issue where user transcriptions where not being handled
|
||||
properly. It was possible for short utterances to not trigger VAD which would
|
||||
cause user transcriptions to be ignored. It was also possible for one or more
|
||||
transcriptions to be generated after VAD in which case they would also be
|
||||
ignored.
|
||||
|
||||
- Fixed an issue that was causing `BotStoppedSpeakingFrame` to be generated too
|
||||
late. This could then cause issues unblocking `STTMuteFilter` later than
|
||||
desired.
|
||||
@@ -283,7 +303,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- Added `enable_recording` and `geo` parameters to `DailyRoomProperties`.
|
||||
|
||||
- Added `RecordingsBucketConfig` to `DailyRoomProperties` to upload recordings to a custom AWS bucket.
|
||||
- Added `RecordingsBucketConfig` to `DailyRoomProperties` to upload recordings
|
||||
to a custom AWS bucket.
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
@@ -497,7 +497,7 @@ class UserAggregatorBuffer(LLMResponseAggregator):
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
self._transcription = ""
|
||||
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if self._aggregation:
|
||||
self._transcription = self._aggregation
|
||||
self._aggregation = ""
|
||||
|
||||
@@ -55,7 +55,7 @@ elevenlabs = [ "websockets~=13.1" ]
|
||||
fal = [ "fal-client~=0.5.6" ]
|
||||
fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
|
||||
gladia = [ "websockets~=13.1" ]
|
||||
google = [ "google-generativeai~=0.8.3", "google-cloud-texttospeech~=2.24.0", "google-genai~=1.0.0", "google-cloud-speech~=2.30.0" ]
|
||||
google = [ "google-cloud-speech~=2.31.0", "google-cloud-texttospeech~=2.25.0", "google-genai~=1.2.0", "google-generativeai~=0.8.4" ]
|
||||
grok = [ "openai~=1.59.6" ]
|
||||
groq = [ "openai~=1.59.6" ]
|
||||
gstreamer = [ "pygobject~=3.50.0" ]
|
||||
|
||||
@@ -565,6 +565,22 @@ class UserStoppedSpeakingFrame(SystemFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmulateUserStartedSpeakingFrame(SystemFrame):
|
||||
"""Emitted by internal processors upstream to emulate VAD behavior when a
|
||||
user starts speaking."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmulateUserStoppedSpeakingFrame(SystemFrame):
|
||||
"""Emitted by internal processors upstream to emulate VAD behavior when a
|
||||
user stops speaking."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotInterruptionFrame(SystemFrame):
|
||||
"""Emitted by when the bot should be interrupted. This will mainly cause the
|
||||
|
||||
@@ -4,9 +4,17 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import List, Optional, Type
|
||||
import asyncio
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import List
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
CancelFrame,
|
||||
EmulateUserStartedSpeakingFrame,
|
||||
EmulateUserStoppedSpeakingFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -15,6 +23,7 @@ from pipecat.frames.frames import (
|
||||
LLMMessagesFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -28,121 +37,105 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class LLMResponseAggregator(FrameProcessor):
|
||||
class BaseLLMResponseAggregator(FrameProcessor):
|
||||
"""This is the base class for all LLM response aggregators. These
|
||||
aggregators process incoming frames and aggregate content until they are
|
||||
ready to push the aggregation. In the case of a user, an aggregation might
|
||||
be a full transcription received from the STT service.
|
||||
|
||||
The LLM response aggregators also keep a store (e.g. a message list or an
|
||||
LLM context) of the current conversation, that is, it stores the messages
|
||||
said by the user or by the bot.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def messages(self) -> List[dict]:
|
||||
"""Returns the messages from the current conversation."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def role(self) -> str:
|
||||
"""Returns the role (e.g. user, assistant...) for this aggregator."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_messages(self, messages):
|
||||
"""Add the given messages to the conversation."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_messages(self, messages):
|
||||
"""Reset the conversation with the given messages."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_tools(self, tools):
|
||||
"""Set LLM tools to be used in the current conversation."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
"""Reset the internals of this aggregator. This should not modify the
|
||||
internal messages."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def push_aggregation(self):
|
||||
pass
|
||||
|
||||
|
||||
class LLMResponseAggregator(BaseLLMResponseAggregator):
|
||||
"""This is a base LLM aggregator that uses a simple list of messages to
|
||||
store the conversation. It pushes `LLMMessagesFrame` as an aggregation
|
||||
frame.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
messages: List[dict],
|
||||
role: str,
|
||||
start_frame,
|
||||
end_frame,
|
||||
accumulator_frame: Type[TextFrame],
|
||||
interim_accumulator_frame: Optional[Type[TextFrame]] = None,
|
||||
handle_interruptions: bool = False,
|
||||
expect_stripped_words: bool = True, # if True, need to add spaces between words
|
||||
role: str = "user",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._messages = messages
|
||||
self._role = role
|
||||
self._start_frame = start_frame
|
||||
self._end_frame = end_frame
|
||||
self._accumulator_frame = accumulator_frame
|
||||
self._interim_accumulator_frame = interim_accumulator_frame
|
||||
self._handle_interruptions = handle_interruptions
|
||||
self._expect_stripped_words = expect_stripped_words
|
||||
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
self._aggregation = ""
|
||||
|
||||
self.reset()
|
||||
|
||||
@property
|
||||
def messages(self):
|
||||
def messages(self) -> List[dict]:
|
||||
return self._messages
|
||||
|
||||
@property
|
||||
def role(self):
|
||||
def role(self) -> str:
|
||||
return self._role
|
||||
|
||||
#
|
||||
# Frame processor
|
||||
#
|
||||
def add_messages(self, messages):
|
||||
self._messages.extend(messages)
|
||||
|
||||
# Use cases implemented:
|
||||
#
|
||||
# S: Start, E: End, T: Transcription, I: Interim, X: Text
|
||||
#
|
||||
# S E -> None
|
||||
# S T E -> X
|
||||
# S I T E -> X
|
||||
# S I E T -> X
|
||||
# S I E I T -> X
|
||||
# S E T -> X
|
||||
# S E I T -> X
|
||||
#
|
||||
# The following case would not be supported:
|
||||
#
|
||||
# S I E T1 I T2 -> X
|
||||
#
|
||||
# and T2 would be dropped.
|
||||
def set_messages(self, messages):
|
||||
self.reset()
|
||||
self._messages.clear()
|
||||
self._messages.extend(messages)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
def set_tools(self, tools):
|
||||
pass
|
||||
|
||||
send_aggregation = False
|
||||
def reset(self):
|
||||
self._aggregation = ""
|
||||
|
||||
if isinstance(frame, self._start_frame):
|
||||
self._aggregation = ""
|
||||
self._aggregating = True
|
||||
self._seen_start_frame = True
|
||||
self._seen_end_frame = False
|
||||
self._seen_interim_results = False
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, self._end_frame):
|
||||
self._seen_end_frame = True
|
||||
self._seen_start_frame = False
|
||||
|
||||
# We might have received the end frame but we might still be
|
||||
# aggregating (i.e. we have seen interim results but not the final
|
||||
# text).
|
||||
self._aggregating = self._seen_interim_results or len(self._aggregation) == 0
|
||||
|
||||
# Send the aggregation if we are not aggregating anymore (i.e. no
|
||||
# more interim results received).
|
||||
send_aggregation = not self._aggregating
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, self._accumulator_frame):
|
||||
if self._aggregating:
|
||||
if self._expect_stripped_words:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
else:
|
||||
self._aggregation += frame.text
|
||||
# We have recevied a complete sentence, so if we have seen the
|
||||
# end frame and we were still aggregating, it means we should
|
||||
# send the aggregation.
|
||||
send_aggregation = self._seen_end_frame
|
||||
|
||||
# We just got our final result, so let's reset interim results.
|
||||
self._seen_interim_results = False
|
||||
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
|
||||
self._seen_interim_results = True
|
||||
elif self._handle_interruptions and isinstance(frame, StartInterruptionFrame):
|
||||
await self._push_aggregation()
|
||||
# Reset anyways
|
||||
self._reset()
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
self._add_messages(frame.messages)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
self._set_messages(frame.messages)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self._set_tools(frame.tools)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if send_aggregation:
|
||||
await self._push_aggregation()
|
||||
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
self._messages.append({"role": self._role, "content": self._aggregation})
|
||||
|
||||
@@ -153,109 +146,27 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
frame = LLMMessagesFrame(self._messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
# TODO-CB: Types
|
||||
def _add_messages(self, messages):
|
||||
self._messages.extend(messages)
|
||||
|
||||
def _set_messages(self, messages):
|
||||
self._reset()
|
||||
self._messages.clear()
|
||||
self._messages.extend(messages)
|
||||
class LLMContextResponseAggregator(BaseLLMResponseAggregator):
|
||||
"""This is a base LLM aggregator that uses an LLM context to store the
|
||||
conversation. It pushes `OpenAILLMContextFrame` as an aggregation frame.
|
||||
|
||||
def _set_tools(self, tools):
|
||||
# noop in the base class
|
||||
pass
|
||||
|
||||
def _reset(self):
|
||||
self._aggregation = ""
|
||||
self._aggregating = False
|
||||
self._seen_start_frame = False
|
||||
self._seen_end_frame = False
|
||||
self._seen_interim_results = False
|
||||
|
||||
|
||||
class LLMAssistantResponseAggregator(LLMResponseAggregator):
|
||||
def __init__(self, messages: List[dict] = []):
|
||||
super().__init__(
|
||||
messages=messages,
|
||||
role="assistant",
|
||||
start_frame=LLMFullResponseStartFrame,
|
||||
end_frame=LLMFullResponseEndFrame,
|
||||
accumulator_frame=TextFrame,
|
||||
handle_interruptions=True,
|
||||
)
|
||||
|
||||
|
||||
class LLMUserResponseAggregator(LLMResponseAggregator):
|
||||
def __init__(self, messages: List[dict] = []):
|
||||
super().__init__(
|
||||
messages=messages,
|
||||
role="user",
|
||||
start_frame=UserStartedSpeakingFrame,
|
||||
end_frame=UserStoppedSpeakingFrame,
|
||||
accumulator_frame=TranscriptionFrame,
|
||||
interim_accumulator_frame=InterimTranscriptionFrame,
|
||||
)
|
||||
|
||||
|
||||
class LLMFullResponseAggregator(FrameProcessor):
|
||||
"""This class aggregates Text frames until it receives a
|
||||
LLMFullResponseEndFrame, then emits the concatenated text as
|
||||
a single text frame.
|
||||
|
||||
given the following frames:
|
||||
|
||||
TextFrame("Hello,")
|
||||
TextFrame(" world.")
|
||||
TextFrame(" I am")
|
||||
TextFrame(" an LLM.")
|
||||
LLMFullResponseEndFrame()]
|
||||
|
||||
this processor will yield nothing for the first 4 frames, then
|
||||
|
||||
TextFrame("Hello, world. I am an LLM.")
|
||||
LLMFullResponseEndFrame()
|
||||
|
||||
when passed the last frame.
|
||||
|
||||
>>> async def print_frames(aggregator, frame):
|
||||
... async for frame in aggregator.process_frame(frame):
|
||||
... if isinstance(frame, TextFrame):
|
||||
... print(frame.text)
|
||||
... else:
|
||||
... print(frame.__class__.__name__)
|
||||
|
||||
>>> aggregator = LLMFullResponseAggregator()
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame("Hello,")))
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame(" world.")))
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame(" I am")))
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM.")))
|
||||
>>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame()))
|
||||
Hello, world. I am an LLM.
|
||||
LLMFullResponseEndFrame
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._aggregation = ""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
self._aggregation += frame.text
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self.push_frame(TextFrame(self._aggregation))
|
||||
await self.push_frame(frame)
|
||||
self._aggregation = ""
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class LLMContextAggregator(LLMResponseAggregator):
|
||||
def __init__(self, *, context: OpenAILLMContext, **kwargs):
|
||||
def __init__(self, *, context: OpenAILLMContext, role: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._context = context
|
||||
self._role = role
|
||||
|
||||
self._aggregation = ""
|
||||
|
||||
@property
|
||||
def messages(self) -> List[dict]:
|
||||
return self._context.get_messages()
|
||||
|
||||
@property
|
||||
def role(self) -> str:
|
||||
return self._role
|
||||
|
||||
@property
|
||||
def context(self):
|
||||
@@ -264,23 +175,25 @@ class LLMContextAggregator(LLMResponseAggregator):
|
||||
def get_context_frame(self) -> OpenAILLMContextFrame:
|
||||
return OpenAILLMContextFrame(context=self._context)
|
||||
|
||||
async def push_context_frame(self):
|
||||
async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
frame = self.get_context_frame()
|
||||
await self.push_frame(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
# TODO-CB: Types
|
||||
def _add_messages(self, messages):
|
||||
def add_messages(self, messages):
|
||||
self._context.add_messages(messages)
|
||||
|
||||
def _set_messages(self, messages):
|
||||
def set_messages(self, messages):
|
||||
self._context.set_messages(messages)
|
||||
|
||||
def _set_tools(self, tools: List):
|
||||
def set_tools(self, tools: List):
|
||||
self._context.set_tools(tools)
|
||||
|
||||
async def _push_aggregation(self):
|
||||
def reset(self):
|
||||
self._aggregation = ""
|
||||
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
self._context.add_message({"role": self._role, "content": self._aggregation})
|
||||
self._context.add_message({"role": self.role, "content": self._aggregation})
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the tasks gets cancelled we won't be able to clear things up.
|
||||
@@ -290,31 +203,236 @@ class LLMContextAggregator(LLMResponseAggregator):
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
self.reset()
|
||||
|
||||
|
||||
class LLMAssistantContextAggregator(LLMContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True):
|
||||
super().__init__(
|
||||
messages=[],
|
||||
context=context,
|
||||
role="assistant",
|
||||
start_frame=LLMFullResponseStartFrame,
|
||||
end_frame=LLMFullResponseEndFrame,
|
||||
accumulator_frame=TextFrame,
|
||||
handle_interruptions=True,
|
||||
expect_stripped_words=expect_stripped_words,
|
||||
)
|
||||
class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
"""This is a user LLM aggregator that uses an LLM context to store the
|
||||
conversation. It aggregates transcriptions from the STT service and it has
|
||||
logic to handle multiple scenarios where transcriptions are received between
|
||||
VAD events (`UserStartedSpeakingFrame` and `UserStoppedSpeakingFrame`) or
|
||||
even outside or no VAD events at all.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
aggregation_timeout: float = 1.0,
|
||||
bot_interruption_timeout: float = 2.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(context=context, role="user", **kwargs)
|
||||
self._aggregation_timeout = aggregation_timeout
|
||||
self._bot_interruption_timeout = bot_interruption_timeout
|
||||
|
||||
self._seen_interim_results = False
|
||||
self._user_speaking = False
|
||||
self._last_user_speaking_time = 0
|
||||
self._emulating_vad = False
|
||||
|
||||
self._aggregation_event = asyncio.Event()
|
||||
self._aggregation_task = None
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
self._seen_interim_results = False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
await self._start(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self._stop(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self._cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._handle_user_started_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
await self._handle_user_stopped_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
await self._handle_transcription(frame)
|
||||
elif isinstance(frame, InterimTranscriptionFrame):
|
||||
await self._handle_interim_transcription(frame)
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
self.add_messages(frame.messages)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
self.set_messages(frame.messages)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self.set_tools(frame.tools)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
self._create_aggregation_task()
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
await self._cancel_aggregation_task()
|
||||
|
||||
async def _cancel(self, frame: CancelFrame):
|
||||
await self._cancel_aggregation_task()
|
||||
|
||||
async def _handle_user_started_speaking(self, _: UserStartedSpeakingFrame):
|
||||
self._user_speaking = True
|
||||
|
||||
async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
|
||||
self._last_user_speaking_time = time.time()
|
||||
self._user_speaking = False
|
||||
if not self._seen_interim_results:
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
# We just got a final result, so let's reset interim results.
|
||||
self._seen_interim_results = False
|
||||
# Reset aggregation timer.
|
||||
self._aggregation_event.set()
|
||||
|
||||
async def _handle_interim_transcription(self, _: InterimTranscriptionFrame):
|
||||
self._seen_interim_results = True
|
||||
# Reset aggregation timer.
|
||||
self._aggregation_event.set()
|
||||
|
||||
def _create_aggregation_task(self):
|
||||
self._aggregation_task = self.create_task(self._aggregation_task_handler())
|
||||
|
||||
async def _cancel_aggregation_task(self):
|
||||
if self._aggregation_task:
|
||||
await self.cancel_task(self._aggregation_task)
|
||||
self._aggregation_task = None
|
||||
|
||||
async def _aggregation_task_handler(self):
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(self._aggregation_event.wait(), self._aggregation_timeout)
|
||||
await self._maybe_push_bot_interruption()
|
||||
except asyncio.TimeoutError:
|
||||
if not self._user_speaking:
|
||||
await self.push_aggregation()
|
||||
|
||||
# If we are emulating VAD we still need to send the user stopped
|
||||
# speaking frame.
|
||||
if self._emulating_vad:
|
||||
await self.push_frame(
|
||||
EmulateUserStoppedSpeakingFrame(), FrameDirection.UPSTREAM
|
||||
)
|
||||
self._emulating_vad = False
|
||||
finally:
|
||||
self._aggregation_event.clear()
|
||||
|
||||
async def _maybe_push_bot_interruption(self):
|
||||
"""If the user stopped speaking a while back and we got a transcription
|
||||
frame we might want to interrupt the bot.
|
||||
|
||||
"""
|
||||
if not self._user_speaking:
|
||||
diff_time = time.time() - self._last_user_speaking_time
|
||||
if diff_time > self._bot_interruption_timeout:
|
||||
# If we reach this case we received a transcription but VAD was
|
||||
# not able to detect voice (e.g. when you whisper a short
|
||||
# utterance). So, we need to emulate VAD (i.e. user
|
||||
# start/stopped speaking).
|
||||
await self.push_frame(EmulateUserStartedSpeakingFrame(), FrameDirection.UPSTREAM)
|
||||
self._emulating_vad = True
|
||||
|
||||
# Reset time so we don't interrupt again right away.
|
||||
self._last_user_speaking_time = time.time()
|
||||
|
||||
|
||||
class LLMUserContextAggregator(LLMContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext):
|
||||
super().__init__(
|
||||
messages=[],
|
||||
context=context,
|
||||
role="user",
|
||||
start_frame=UserStartedSpeakingFrame,
|
||||
end_frame=UserStoppedSpeakingFrame,
|
||||
accumulator_frame=TranscriptionFrame,
|
||||
interim_accumulator_frame=InterimTranscriptionFrame,
|
||||
)
|
||||
class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
"""This is an assistant LLM aggregator that uses an LLM context to store the
|
||||
conversation. It aggregates text frames received between
|
||||
`LLMFullResponseStartFrame` and `LLMFullResponseEndFrame`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True, **kwargs):
|
||||
super().__init__(context=context, role="assistant", **kwargs)
|
||||
self._expect_stripped_words = expect_stripped_words
|
||||
|
||||
self.reset()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
await self.push_aggregation()
|
||||
# Reset anyways
|
||||
self.reset()
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
await self._handle_llm_start(frame)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self._handle_llm_end(frame)
|
||||
elif isinstance(frame, TextFrame):
|
||||
await self._handle_text(frame)
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
self.add_messages(frame.messages)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
self.set_messages(frame.messages)
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
self.set_tools(frame.tools)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
|
||||
self._started = True
|
||||
|
||||
async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
|
||||
self._started = False
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
if self._expect_stripped_words:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
else:
|
||||
self._aggregation += frame.text
|
||||
|
||||
|
||||
class LLMUserResponseAggregator(LLMUserContextAggregator):
|
||||
def __init__(self, messages: List[dict] = [], **kwargs):
|
||||
super().__init__(context=OpenAILLMContext(messages), **kwargs)
|
||||
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
self._context.add_message({"role": self.role, "content": self._aggregation})
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the tasks gets cancelled we won't be able to clear things up.
|
||||
self._aggregation = ""
|
||||
|
||||
frame = LLMMessagesFrame(self._context.messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset our accumulator state.
|
||||
self.reset()
|
||||
|
||||
|
||||
class LLMAssistantResponseAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, messages: List[dict], **kwargs):
|
||||
super().__init__(context=OpenAILLMContext(messages), **kwargs)
|
||||
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
self._context.add_message({"role": self.role, "content": self._aggregation})
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the tasks gets cancelled we won't be able to clear things up.
|
||||
self._aggregation = ""
|
||||
|
||||
frame = LLMMessagesFrame(self._context.messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset our accumulator state.
|
||||
self.reset()
|
||||
|
||||
@@ -4,131 +4,15 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.frames.frames import TextFrame
|
||||
from pipecat.processors.aggregators.llm_response import LLMUserResponseAggregator
|
||||
|
||||
|
||||
class ResponseAggregator(FrameProcessor):
|
||||
"""This frame processor aggregates frames between a start and an end frame
|
||||
into complete text frame sentences.
|
||||
class UserResponseAggregator(LLMUserResponseAggregator):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
For example, frame input/output:
|
||||
UserStartedSpeakingFrame() -> None
|
||||
TranscriptionFrame("Hello,") -> None
|
||||
TranscriptionFrame(" world.") -> None
|
||||
UserStoppedSpeakingFrame() -> TextFrame("Hello world.")
|
||||
|
||||
Doctest: FIXME to work with asyncio
|
||||
>>> async def print_frames(aggregator, frame):
|
||||
... async for frame in aggregator.process_frame(frame):
|
||||
... if isinstance(frame, TextFrame):
|
||||
... print(frame.text)
|
||||
|
||||
>>> aggregator = ResponseAggregator(start_frame = UserStartedSpeakingFrame,
|
||||
... end_frame=UserStoppedSpeakingFrame,
|
||||
... accumulator_frame=TranscriptionFrame,
|
||||
... pass_through=False)
|
||||
>>> asyncio.run(print_frames(aggregator, UserStartedSpeakingFrame()))
|
||||
>>> asyncio.run(print_frames(aggregator, TranscriptionFrame("Hello,", 1, 1)))
|
||||
>>> asyncio.run(print_frames(aggregator, TranscriptionFrame("world.", 1, 2)))
|
||||
>>> asyncio.run(print_frames(aggregator, UserStoppedSpeakingFrame()))
|
||||
Hello, world.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
start_frame,
|
||||
end_frame,
|
||||
accumulator_frame: TextFrame,
|
||||
interim_accumulator_frame: Optional[TextFrame] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._start_frame = start_frame
|
||||
self._end_frame = end_frame
|
||||
self._accumulator_frame = accumulator_frame
|
||||
self._interim_accumulator_frame = interim_accumulator_frame
|
||||
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
|
||||
#
|
||||
# Frame processor
|
||||
#
|
||||
|
||||
# Use cases implemented:
|
||||
#
|
||||
# S: Start, E: End, T: Transcription, I: Interim, X: Text
|
||||
#
|
||||
# S E -> None
|
||||
# S T E -> X
|
||||
# S I T E -> X
|
||||
# S I E T -> X
|
||||
# S I E I T -> X
|
||||
# S E T -> X
|
||||
# S E I T -> X
|
||||
#
|
||||
# The following case would not be supported:
|
||||
#
|
||||
# S I E T1 I T2 -> X
|
||||
#
|
||||
# and T2 would be dropped.
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
send_aggregation = False
|
||||
|
||||
if isinstance(frame, self._start_frame):
|
||||
self._aggregating = True
|
||||
self._seen_start_frame = True
|
||||
self._seen_end_frame = False
|
||||
self._seen_interim_results = False
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, self._end_frame):
|
||||
self._seen_end_frame = True
|
||||
self._seen_start_frame = False
|
||||
|
||||
# We might have received the end frame but we might still be
|
||||
# aggregating (i.e. we have seen interim results but not the final
|
||||
# text).
|
||||
self._aggregating = self._seen_interim_results or len(self._aggregation) == 0
|
||||
|
||||
# Send the aggregation if we are not aggregating anymore (i.e. no
|
||||
# more interim results received).
|
||||
send_aggregation = not self._aggregating
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, self._accumulator_frame):
|
||||
if self._aggregating:
|
||||
self._aggregation += f" {frame.text}"
|
||||
# We have recevied a complete sentence, so if we have seen the
|
||||
# end frame and we were still aggregating, it means we should
|
||||
# send the aggregation.
|
||||
send_aggregation = self._seen_end_frame
|
||||
|
||||
# We just got our final result, so let's reset interim results.
|
||||
self._seen_interim_results = False
|
||||
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
|
||||
self._seen_interim_results = True
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if send_aggregation:
|
||||
await self._push_aggregation()
|
||||
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
frame = TextFrame(self._aggregation.strip())
|
||||
|
||||
@@ -139,21 +23,4 @@ class ResponseAggregator(FrameProcessor):
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
|
||||
def _reset(self):
|
||||
self._aggregation = ""
|
||||
self._aggregating = False
|
||||
self._seen_start_frame = False
|
||||
self._seen_end_frame = False
|
||||
self._seen_interim_results = False
|
||||
|
||||
|
||||
class UserResponseAggregator(ResponseAggregator):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
start_frame=UserStartedSpeakingFrame,
|
||||
end_frame=UserStoppedSpeakingFrame,
|
||||
accumulator_frame=TranscriptionFrame,
|
||||
interim_accumulator_frame=InterimTranscriptionFrame,
|
||||
)
|
||||
self.reset()
|
||||
|
||||
@@ -577,7 +577,7 @@ class SegmentedSTTService(STTService):
|
||||
self._smoothing_factor = 0.2
|
||||
self._prev_volume = 0
|
||||
|
||||
async def process_audio_frame(self, frame: AudioRawFrame):
|
||||
async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
|
||||
# Try to filter out empty background noise
|
||||
volume = self._get_smoothed_volume(frame)
|
||||
if volume >= self._min_volume:
|
||||
|
||||
@@ -126,9 +126,11 @@ class AnthropicLLMService(LLMService):
|
||||
def create_context_aggregator(
|
||||
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
|
||||
) -> AnthropicContextAggregatorPair:
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
context = AnthropicLLMContext.from_openai_context(context)
|
||||
user = AnthropicUserContextAggregator(context)
|
||||
assistant = AnthropicAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
context, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -651,11 +653,8 @@ class AnthropicLLMContext(OpenAILLMContext):
|
||||
|
||||
|
||||
class AnthropicUserContextAggregator(LLMUserContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext | AnthropicLLMContext):
|
||||
super().__init__(context=context)
|
||||
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
self._context = AnthropicLLMContext.from_openai_context(context)
|
||||
def __init__(self, context: OpenAILLMContext | AnthropicLLMContext, **kwargs):
|
||||
super().__init__(context=context, **kwargs)
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -703,9 +702,8 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
|
||||
|
||||
|
||||
class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs):
|
||||
super().__init__(context=user_context_aggregator._context, **kwargs)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
def __init__(self, context: OpenAILLMContext | AnthropicLLMContext, **kwargs):
|
||||
super().__init__(context=context, **kwargs)
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
self._pending_image_frame_message = None
|
||||
@@ -725,7 +723,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
):
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = frame
|
||||
await self._push_aggregation()
|
||||
await self.push_aggregation()
|
||||
else:
|
||||
logger.warning(
|
||||
"FunctionCallResultFrame tool_call_id != InProgressFrame tool_call_id"
|
||||
@@ -734,9 +732,9 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
self._function_call_result = None
|
||||
elif isinstance(frame, AnthropicImageMessageFrame):
|
||||
self._pending_image_frame_message = frame
|
||||
await self._push_aggregation()
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if not (
|
||||
self._aggregation or self._function_call_result or self._pending_image_frame_message
|
||||
):
|
||||
@@ -746,7 +744,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._reset()
|
||||
self.reset()
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
@@ -799,15 +797,14 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
run_llm = True
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
# Emit the on_context_updated callback once the function call result is added to the context
|
||||
if properties and properties.on_context_updated is not None:
|
||||
await properties.on_context_updated()
|
||||
|
||||
# Push context frame
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
await self.push_context_frame()
|
||||
|
||||
# Push timestamp frame with current time
|
||||
timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
|
||||
|
||||
@@ -115,10 +115,10 @@ class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator):
|
||||
|
||||
|
||||
class GeminiMultimodalLiveAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
# We don't want to store any images in the context. Revisit this later when the API evolves.
|
||||
self._pending_image_frame_message = None
|
||||
await super()._push_aggregation()
|
||||
await super().push_aggregation()
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -706,6 +706,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
GeminiMultimodalLiveContext.upgrade(context)
|
||||
user = GeminiMultimodalLiveUserContextAggregator(context)
|
||||
assistant = GeminiMultimodalLiveAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
context, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -537,7 +537,7 @@ def language_to_google_stt_language(language: Language) -> Optional[str]:
|
||||
|
||||
|
||||
class GoogleUserContextAggregator(OpenAIUserContextAggregator):
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
self._context.add_message(
|
||||
glm.Content(role="user", parts=[glm.Part(text=self._aggregation)])
|
||||
@@ -552,11 +552,11 @@ class GoogleUserContextAggregator(OpenAIUserContextAggregator):
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
self.reset()
|
||||
|
||||
|
||||
class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if not (
|
||||
self._aggregation or self._function_call_result or self._pending_image_frame_message
|
||||
):
|
||||
@@ -566,7 +566,7 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._reset()
|
||||
self.reset()
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
@@ -626,15 +626,14 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
run_llm = True
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
# Emit the on_context_updated callback once the function call result is added to the context
|
||||
if properties and properties.on_context_updated is not None:
|
||||
await properties.on_context_updated()
|
||||
|
||||
# Push context frame
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
await self.push_context_frame()
|
||||
|
||||
# Push timestamp frame with current time
|
||||
timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
|
||||
@@ -1175,7 +1174,7 @@ class GoogleLLMService(LLMService):
|
||||
) -> GoogleContextAggregatorPair:
|
||||
user = GoogleUserContextAggregator(context)
|
||||
assistant = GoogleAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
context, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return GoogleContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.openai import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAILLMService,
|
||||
@@ -27,7 +28,7 @@ from pipecat.services.openai import (
|
||||
class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
"""Custom assistant context aggregator for Grok that handles empty content requirement."""
|
||||
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if not (
|
||||
self._aggregation or self._function_call_result or self._pending_image_frame_message
|
||||
):
|
||||
@@ -37,7 +38,7 @@ class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._reset()
|
||||
self.reset()
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
@@ -91,14 +92,13 @@ class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
run_llm = True
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
# Emit the on_context_updated callback once the function call result is added to the context
|
||||
if properties and properties.on_context_updated is not None:
|
||||
await properties.on_context_updated()
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
await self.push_context_frame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
@@ -212,6 +212,6 @@ class GrokLLMService(OpenAILLMService):
|
||||
) -> GrokContextAggregatorPair:
|
||||
user = OpenAIUserContextAggregator(context)
|
||||
assistant = GrokAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
context, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return GrokContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -355,7 +355,7 @@ class OpenAILLMService(BaseOpenAILLMService):
|
||||
) -> OpenAIContextAggregatorPair:
|
||||
user = OpenAIUserContextAggregator(context)
|
||||
assistant = OpenAIAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
context, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -555,8 +555,8 @@ class OpenAIImageMessageFrame(Frame):
|
||||
|
||||
|
||||
class OpenAIUserContextAggregator(LLMUserContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext):
|
||||
super().__init__(context=context)
|
||||
def __init__(self, context: OpenAILLMContext, **kwargs):
|
||||
super().__init__(context=context, **kwargs)
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -592,9 +592,8 @@ class OpenAIUserContextAggregator(LLMUserContextAggregator):
|
||||
|
||||
|
||||
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs):
|
||||
super().__init__(context=user_context_aggregator._context, **kwargs)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
def __init__(self, context: OpenAILLMContext, **kwargs):
|
||||
super().__init__(context=context, **kwargs)
|
||||
self._function_calls_in_progress = {}
|
||||
self._function_call_result = None
|
||||
self._pending_image_frame_message = None
|
||||
@@ -614,7 +613,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
del self._function_calls_in_progress[frame.tool_call_id]
|
||||
self._function_call_result = frame
|
||||
# TODO-CB: Kwin wants us to refactor this out of here but I REFUSE
|
||||
await self._push_aggregation()
|
||||
await self.push_aggregation()
|
||||
else:
|
||||
logger.warning(
|
||||
"FunctionCallResultFrame tool_call_id does not match any function call in progress"
|
||||
@@ -622,9 +621,9 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
self._function_call_result = None
|
||||
elif isinstance(frame, OpenAIImageMessageFrame):
|
||||
self._pending_image_frame_message = frame
|
||||
await self._push_aggregation()
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
if not (
|
||||
self._aggregation or self._function_call_result or self._pending_image_frame_message
|
||||
):
|
||||
@@ -634,7 +633,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._reset()
|
||||
self.reset()
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
@@ -686,15 +685,14 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
run_llm = True
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
# Emit the on_context_updated callback once the function call result is added to the context
|
||||
if properties and properties.on_context_updated is not None:
|
||||
await properties.on_context_updated()
|
||||
|
||||
# Push context frame
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
await self.push_context_frame()
|
||||
|
||||
# Push timestamp frame with current time
|
||||
timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
|
||||
|
||||
@@ -166,7 +166,7 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
|
||||
if isinstance(frame, LLMSetToolsFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
# for the moment, ignore all user input coming into the pipeline.
|
||||
# todo: think about whether/how to fix this to allow for text input from
|
||||
# upstream (transport/transcription, or other sources)
|
||||
@@ -174,7 +174,7 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
|
||||
|
||||
|
||||
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
async def _push_aggregation(self):
|
||||
async def push_aggregation(self):
|
||||
# the only thing we implement here is function calling. in all other cases, messages
|
||||
# are added to the context when we receive openai realtime api events
|
||||
if not self._function_call_result:
|
||||
@@ -182,7 +182,7 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator)
|
||||
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
self._reset()
|
||||
self.reset()
|
||||
try:
|
||||
run_llm = True
|
||||
frame = self._function_call_result
|
||||
@@ -217,8 +217,8 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator)
|
||||
# The standard function callback code path pushes the FunctionCallResultFrame from the llm itself,
|
||||
# so we didn't have a chance to add the result to the openai realtime api context. Let's push a
|
||||
# special frame to do that.
|
||||
await self._user_context_aggregator.push_frame(
|
||||
RealtimeFunctionCallResultFrame(result_frame=frame)
|
||||
await self.push_frame(
|
||||
RealtimeFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM
|
||||
)
|
||||
if properties and properties.run_llm is not None:
|
||||
# If the tool call result has a run_llm property, use it
|
||||
@@ -228,14 +228,13 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator)
|
||||
run_llm = not bool(self._function_calls_in_progress)
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
# Emit the on_context_updated callback once the function call result is added to the context
|
||||
if properties and properties.on_context_updated is not None:
|
||||
await properties.on_context_updated()
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
await self.push_context_frame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
@@ -568,6 +568,6 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
OpenAIRealtimeLLMContext.upgrade_to_realtime(context)
|
||||
user = OpenAIRealtimeUserContextAggregator(context)
|
||||
assistant = OpenAIRealtimeAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
context, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, Sequence, Tuple
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -12,6 +13,7 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
HeartbeatFrame,
|
||||
StartFrame,
|
||||
SystemFrame,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
@@ -20,6 +22,15 @@ from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
@dataclass
|
||||
class SleepFrame(SystemFrame):
|
||||
"""This frame is used by test framework to introduce some sleep time before
|
||||
the next frame is pushed. This is useful to control system frames vs data or
|
||||
control frames."""
|
||||
|
||||
sleep: float = 0.1
|
||||
|
||||
|
||||
class HeartbeatsObserver(BaseObserver):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -44,7 +55,11 @@ class HeartbeatsObserver(BaseObserver):
|
||||
|
||||
class QueuedFrameProcessor(FrameProcessor):
|
||||
def __init__(
|
||||
self, queue: asyncio.Queue, queue_direction: FrameDirection, ignore_start: bool = True
|
||||
self,
|
||||
*,
|
||||
queue: asyncio.Queue,
|
||||
queue_direction: FrameDirection,
|
||||
ignore_start: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self._queue = queue
|
||||
@@ -72,21 +87,35 @@ async def run_test(
|
||||
) -> Tuple[Sequence[Frame], Sequence[Frame]]:
|
||||
received_up = asyncio.Queue()
|
||||
received_down = asyncio.Queue()
|
||||
source = QueuedFrameProcessor(received_up, FrameDirection.UPSTREAM, ignore_start)
|
||||
sink = QueuedFrameProcessor(received_down, FrameDirection.DOWNSTREAM, ignore_start)
|
||||
source = QueuedFrameProcessor(
|
||||
queue=received_up,
|
||||
queue_direction=FrameDirection.UPSTREAM,
|
||||
ignore_start=ignore_start,
|
||||
)
|
||||
sink = QueuedFrameProcessor(
|
||||
queue=received_down,
|
||||
queue_direction=FrameDirection.DOWNSTREAM,
|
||||
ignore_start=ignore_start,
|
||||
)
|
||||
|
||||
pipeline = Pipeline([source, processor, sink])
|
||||
|
||||
task = PipelineTask(pipeline, params=PipelineParams(start_metadata=start_metadata))
|
||||
|
||||
for frame in frames_to_send:
|
||||
await task.queue_frame(frame)
|
||||
async def push_frames():
|
||||
# Just give a little head start to the runner.
|
||||
await asyncio.sleep(0.01)
|
||||
for frame in frames_to_send:
|
||||
if isinstance(frame, SleepFrame):
|
||||
await asyncio.sleep(frame.sleep)
|
||||
else:
|
||||
await task.queue_frame(frame)
|
||||
|
||||
if send_end_frame:
|
||||
await task.queue_frame(EndFrame())
|
||||
if send_end_frame:
|
||||
await task.queue_frame(EndFrame())
|
||||
|
||||
runner = PipelineRunner()
|
||||
await runner.run(task)
|
||||
await asyncio.gather(runner.run(task), push_frames())
|
||||
|
||||
#
|
||||
# Down frames
|
||||
@@ -98,6 +127,7 @@ async def run_test(
|
||||
received_down_frames.append(frame)
|
||||
|
||||
print("received DOWN frames =", received_down_frames)
|
||||
print("expected DOWN frames =", expected_down_frames)
|
||||
|
||||
assert len(received_down_frames) == len(expected_down_frames)
|
||||
|
||||
@@ -113,6 +143,7 @@ async def run_test(
|
||||
received_up_frames.append(frame)
|
||||
|
||||
print("received UP frames =", received_up_frames)
|
||||
print("expected UP frames =", expected_up_frames)
|
||||
|
||||
assert len(received_up_frames) == len(expected_up_frames)
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
CancelFrame,
|
||||
EmulateUserStartedSpeakingFrame,
|
||||
EmulateUserStoppedSpeakingFrame,
|
||||
EndFrame,
|
||||
FilterUpdateSettingsFrame,
|
||||
Frame,
|
||||
@@ -112,9 +114,13 @@ class BaseInputTransport(FrameProcessor):
|
||||
await self.cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, BotInterruptionFrame):
|
||||
logger.debug("Bot interruption")
|
||||
await self._start_interruption()
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
await self._handle_bot_interruption(frame)
|
||||
elif isinstance(frame, EmulateUserStartedSpeakingFrame):
|
||||
logger.debug("Emulating user started speaking")
|
||||
await self._handle_user_interruption(UserStartedSpeakingFrame())
|
||||
elif isinstance(frame, EmulateUserStoppedSpeakingFrame):
|
||||
logger.debug("Emulating user stopped speaking")
|
||||
await self._handle_user_interruption(UserStoppedSpeakingFrame())
|
||||
# All other system frames
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -137,7 +143,13 @@ class BaseInputTransport(FrameProcessor):
|
||||
# Handle interruptions
|
||||
#
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
async def _handle_bot_interruption(self, frame: BotInterruptionFrame):
|
||||
logger.debug("Bot interruption")
|
||||
if self.interruptions_allowed:
|
||||
await self._start_interruption()
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
|
||||
async def _handle_user_interruption(self, frame: Frame):
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
logger.debug("User started speaking")
|
||||
# Make sure we notify about interruptions quickly out-of-band.
|
||||
@@ -183,7 +195,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
frame = UserStoppedSpeakingFrame()
|
||||
|
||||
if frame:
|
||||
await self._handle_interruptions(frame)
|
||||
await self._handle_user_interruption(frame)
|
||||
|
||||
vad_state = new_vad_state
|
||||
return vad_state
|
||||
|
||||
@@ -170,6 +170,8 @@ class BaseOutputTransport(FrameProcessor):
|
||||
# TODO(aleix): Images and audio should support presentation timestamps.
|
||||
elif frame.pts:
|
||||
await self._sink_clock_queue.put((frame.pts, frame.id, frame))
|
||||
elif direction == FrameDirection.UPSTREAM:
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self._sink_queue.put(frame)
|
||||
|
||||
|
||||
@@ -7,8 +7,10 @@ deepgram-sdk~=3.5.0
|
||||
fal-client~=0.4.1
|
||||
fastapi~=0.115.0
|
||||
faster-whisper~=1.0.3
|
||||
google-cloud-texttospeech~=2.21.1
|
||||
google-generativeai~=0.8.3
|
||||
google-cloud-speech~=2.31.0
|
||||
google-cloud-texttospeech~=2.25.0
|
||||
google-genai~=1.2.0
|
||||
google-generativeai~=0.8.4
|
||||
langchain~=0.2.14
|
||||
livekit~=0.13.1
|
||||
lmnt~=1.1.4
|
||||
|
||||
@@ -29,12 +29,12 @@ class TestSentenceAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
for word in sentence.split(" "):
|
||||
frames_to_send.append(TextFrame(text=word + " "))
|
||||
|
||||
expected_returned_frames = [TextFrame, TextFrame, TextFrame]
|
||||
expected_down_frames = [TextFrame, TextFrame, TextFrame]
|
||||
|
||||
(received_down, _) = await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
assert received_down[-3].text == "Hello, world. "
|
||||
assert received_down[-2].text == "How are you? "
|
||||
@@ -59,7 +59,7 @@ class TestGatedAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
expected_down_frames = [
|
||||
OutputImageRawFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TextFrame,
|
||||
@@ -72,5 +72,5 @@ class TestGatedAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
(received_down, _) = await run_test(
|
||||
gated_aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
671
tests/test_context_aggregators.py
Normal file
671
tests/test_context_aggregators.py
Normal file
@@ -0,0 +1,671 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
import google.ai.generativelanguage as glm
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
EmulateUserStartedSpeakingFrame,
|
||||
EmulateUserStoppedSpeakingFrame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.services.anthropic import (
|
||||
AnthropicAssistantContextAggregator,
|
||||
AnthropicLLMContext,
|
||||
AnthropicUserContextAggregator,
|
||||
)
|
||||
from pipecat.services.google.google import (
|
||||
GoogleAssistantContextAggregator,
|
||||
GoogleLLMContext,
|
||||
GoogleUserContextAggregator,
|
||||
)
|
||||
from pipecat.services.openai import OpenAIAssistantContextAggregator, OpenAIUserContextAggregator
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
|
||||
AGGREGATION_TIMEOUT = 0.1
|
||||
AGGREGATION_SLEEP = 0.15
|
||||
BOT_INTERRUPTION_TIMEOUT = 0.2
|
||||
BOT_INTERRUPTION_SLEEP = 0.25
|
||||
|
||||
|
||||
class BaseTestUserContextAggregator:
|
||||
CONTEXT_CLASS = None # To be set in subclasses
|
||||
AGGREGATOR_CLASS = None # To be set in subclasses
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
|
||||
|
||||
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
|
||||
assert context.messages[index]["content"] == content
|
||||
|
||||
def check_message_multi_content(
|
||||
self, context: OpenAILLMContext, content_index: int, index: int, content: str
|
||||
):
|
||||
assert context.messages[index]["content"] == content
|
||||
|
||||
async def test_se(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [UserStartedSpeakingFrame(), UserStoppedSpeakingFrame()]
|
||||
expected_down_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
async def test_ste(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""),
|
||||
SleepFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_content(context, 0, "Hello!")
|
||||
|
||||
async def test_site(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
InterimTranscriptionFrame(text="Hello", user_id="cat", timestamp=""),
|
||||
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
|
||||
SleepFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_content(context, 0, "Hello Pipecat!")
|
||||
|
||||
async def test_st1iest2e(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
|
||||
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
|
||||
SleepFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
|
||||
SleepFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_content(context, 0, "Hello Pipecat! How are you?")
|
||||
|
||||
async def test_siet(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
|
||||
SleepFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_content(context, 0, "How are you?")
|
||||
|
||||
async def test_sieit(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
|
||||
SleepFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
InterimTranscriptionFrame(text="are you?", user_id="cat", timestamp=""),
|
||||
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_content(context, 0, "How are you?")
|
||||
|
||||
async def test_set(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_content(context, 0, "How are you?")
|
||||
|
||||
async def test_seit(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
|
||||
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_content(context, 0, "How are you?")
|
||||
|
||||
async def test_st1et2(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
|
||||
SleepFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_multi_content(context, 0, 0, "Hello Pipecat!")
|
||||
self.check_message_multi_content(context, 0, 1, "How are you?")
|
||||
|
||||
async def test_set1t2(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
|
||||
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_content(context, 0, "Hello Pipecat! How are you?")
|
||||
|
||||
async def test_siet1it2(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""),
|
||||
SleepFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
|
||||
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
|
||||
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_content(context, 0, "Hello Pipecat! How are you?")
|
||||
|
||||
async def test_t(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
frames_to_send = [
|
||||
TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
self.check_message_content(context, 0, "Hello!")
|
||||
|
||||
async def test_it(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
frames_to_send = [
|
||||
InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""),
|
||||
SleepFrame(),
|
||||
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
self.check_message_content(context, 0, "Hello Pipecat!")
|
||||
|
||||
async def test_sie_delay_it(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context,
|
||||
aggregation_timeout=AGGREGATION_TIMEOUT,
|
||||
bot_interruption_timeout=BOT_INTERRUPTION_TIMEOUT,
|
||||
)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
|
||||
SleepFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
SleepFrame(BOT_INTERRUPTION_SLEEP),
|
||||
InterimTranscriptionFrame(text="are you?", user_id="cat", timestamp=""),
|
||||
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
expected_up_frames=expected_up_frames,
|
||||
)
|
||||
self.check_message_content(context, 0, "How are you?")
|
||||
|
||||
|
||||
class BaseTestAssistantContextAggreagator:
|
||||
CONTEXT_CLASS = None # To be set in subclasses
|
||||
AGGREGATOR_CLASS = None # To be set in subclasses
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
|
||||
|
||||
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
|
||||
assert context.messages[index]["content"] == content
|
||||
|
||||
def check_message_multi_content(
|
||||
self, context: OpenAILLMContext, content_index: int, index: int, content: str
|
||||
):
|
||||
assert context.messages[index]["content"] == content
|
||||
|
||||
async def test_empty(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()]
|
||||
expected_down_frames = []
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
async def test_single_text(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello Pipecat!"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_content(context, 0, "Hello Pipecat!")
|
||||
|
||||
async def test_multiple_text(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello "),
|
||||
TextFrame(text="Pipecat. "),
|
||||
TextFrame(text="How are "),
|
||||
TextFrame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_content(context, 0, "Hello Pipecat. How are you?")
|
||||
|
||||
async def test_multiple_text_stripped(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello"),
|
||||
TextFrame(text="Pipecat."),
|
||||
TextFrame(text="How are"),
|
||||
TextFrame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_content(context, 0, "Hello Pipecat. How are you?")
|
||||
|
||||
async def test_multiple_llm_responses(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello "),
|
||||
TextFrame(text="Pipecat."),
|
||||
LLMFullResponseEndFrame(),
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="How are "),
|
||||
TextFrame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES, *self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_multi_content(context, 0, 0, "Hello Pipecat.")
|
||||
self.check_message_multi_content(context, 0, 1, "How are you?")
|
||||
|
||||
async def test_multiple_llm_responses_interruption(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello "),
|
||||
TextFrame(text="Pipecat."),
|
||||
LLMFullResponseEndFrame(),
|
||||
SleepFrame(AGGREGATION_SLEEP),
|
||||
StartInterruptionFrame(),
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="How are "),
|
||||
TextFrame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
StartInterruptionFrame,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.check_message_multi_content(context, 0, 0, "Hello Pipecat.")
|
||||
self.check_message_multi_content(context, 0, 1, "How are you?")
|
||||
|
||||
|
||||
#
|
||||
# LLMUserContextAggregator, LLMAssistantContextAggregator
|
||||
#
|
||||
|
||||
|
||||
class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase):
|
||||
CONTEXT_CLASS = OpenAILLMContext
|
||||
AGGREGATOR_CLASS = LLMUserContextAggregator
|
||||
|
||||
|
||||
class TestLLMAssistantContextAggregator(
|
||||
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = OpenAILLMContext
|
||||
AGGREGATOR_CLASS = LLMAssistantContextAggregator
|
||||
|
||||
|
||||
#
|
||||
# OpenAI
|
||||
#
|
||||
|
||||
|
||||
class TestOpenAIUserContextAggregator(
|
||||
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = OpenAILLMContext
|
||||
AGGREGATOR_CLASS = OpenAIUserContextAggregator
|
||||
|
||||
|
||||
class TestOpenAIAssistantContextAggregator(
|
||||
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = OpenAILLMContext
|
||||
AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
|
||||
|
||||
#
|
||||
# Anthropic
|
||||
#
|
||||
|
||||
|
||||
class TestAnthropicUserContextAggregator(
|
||||
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = AnthropicLLMContext
|
||||
AGGREGATOR_CLASS = AnthropicUserContextAggregator
|
||||
|
||||
def check_message_multi_content(
|
||||
self, context: OpenAILLMContext, content_index: int, index: int, content: str
|
||||
):
|
||||
messages = context.messages[content_index]
|
||||
assert messages["content"][index]["text"] == content
|
||||
|
||||
|
||||
class TestAnthropicAssistantContextAggregator(
|
||||
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = AnthropicLLMContext
|
||||
AGGREGATOR_CLASS = AnthropicAssistantContextAggregator
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
|
||||
def check_message_multi_content(
|
||||
self, context: OpenAILLMContext, content_index: int, index: int, content: str
|
||||
):
|
||||
messages = context.messages[content_index]
|
||||
assert messages["content"][index]["text"] == content
|
||||
|
||||
|
||||
#
|
||||
# Google
|
||||
#
|
||||
|
||||
|
||||
class TestGoogleUserContextAggregator(
|
||||
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = GoogleLLMContext
|
||||
AGGREGATOR_CLASS = GoogleUserContextAggregator
|
||||
|
||||
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
|
||||
obj = glm.Content.to_dict(context.messages[index])
|
||||
assert obj["parts"][0]["text"] == content
|
||||
|
||||
def check_message_multi_content(
|
||||
self, context: OpenAILLMContext, content_index: int, index: int, content: str
|
||||
):
|
||||
obj = glm.Content.to_dict(context.messages[index])
|
||||
assert obj["parts"][0]["text"] == content
|
||||
|
||||
|
||||
class TestGoogleAssistantContextAggregator(
|
||||
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = GoogleLLMContext
|
||||
AGGREGATOR_CLASS = GoogleAssistantContextAggregator
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
|
||||
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
|
||||
obj = glm.Content.to_dict(context.messages[index])
|
||||
assert obj["parts"][0]["text"] == content
|
||||
|
||||
def check_message_multi_content(
|
||||
self, context: OpenAILLMContext, content_index: int, index: int, content: str
|
||||
):
|
||||
obj = glm.Content.to_dict(context.messages[index])
|
||||
assert obj["parts"][0]["text"] == content
|
||||
@@ -25,11 +25,11 @@ class TestIdentifyFilter(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_identity(self):
|
||||
filter = IdentityFilter()
|
||||
frames_to_send = [UserStartedSpeakingFrame(), UserStoppedSpeakingFrame()]
|
||||
expected_returned_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame]
|
||||
expected_down_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame]
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
|
||||
@@ -37,32 +37,32 @@ class TestFrameFilter(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_text_frame(self):
|
||||
filter = FrameFilter(types=(TextFrame,))
|
||||
frames_to_send = [TextFrame(text="Hello Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
expected_down_frames = [TextFrame]
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
async def test_end_frame(self):
|
||||
filter = FrameFilter(types=(EndFrame,))
|
||||
frames_to_send = [EndFrame()]
|
||||
expected_returned_frames = [EndFrame]
|
||||
expected_down_frames = [EndFrame]
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
send_end_frame=False,
|
||||
)
|
||||
|
||||
async def test_system_frame(self):
|
||||
filter = FrameFilter(types=())
|
||||
frames_to_send = [UserStartedSpeakingFrame()]
|
||||
expected_returned_frames = [UserStartedSpeakingFrame]
|
||||
expected_down_frames = [UserStartedSpeakingFrame]
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
|
||||
@@ -73,11 +73,11 @@ class TestFunctionFilter(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
filter = FunctionFilter(filter=passthrough)
|
||||
frames_to_send = [TextFrame(text="Hello Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
expected_down_frames = [TextFrame]
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
async def test_no_passthrough(self):
|
||||
@@ -86,11 +86,11 @@ class TestFunctionFilter(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
filter = FunctionFilter(filter=no_passthrough)
|
||||
frames_to_send = [TextFrame(text="Hello Pipecat!")]
|
||||
expected_returned_frames = []
|
||||
expected_down_frames = []
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
|
||||
@@ -98,11 +98,11 @@ class TestWakeCheckFilter(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_no_wake_word(self):
|
||||
filter = WakeCheckFilter(wake_phrases=["Hey, Pipecat"])
|
||||
frames_to_send = [TranscriptionFrame(user_id="test", text="Phrase 1", timestamp="")]
|
||||
expected_returned_frames = []
|
||||
expected_down_frames = []
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
async def test_wake_word(self):
|
||||
@@ -111,10 +111,10 @@ class TestWakeCheckFilter(unittest.IsolatedAsyncioTestCase):
|
||||
TranscriptionFrame(user_id="test", text="Hey, Pipecat", timestamp=""),
|
||||
TranscriptionFrame(user_id="test", text="Phrase 1", timestamp=""),
|
||||
]
|
||||
expected_returned_frames = [TranscriptionFrame, TranscriptionFrame]
|
||||
expected_down_frames = [TranscriptionFrame, TranscriptionFrame]
|
||||
(received_down, _) = await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
assert received_down[-1].text == "Phrase 1"
|
||||
|
||||
@@ -10,23 +10,22 @@ from langchain.prompts import ChatPromptTemplate
|
||||
from langchain_core.language_models import FakeStreamingListLLM
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.processors.frameworks.langchain import LangchainProcessor
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
|
||||
|
||||
class TestLangchain(unittest.IsolatedAsyncioTestCase):
|
||||
@@ -64,31 +63,26 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase):
|
||||
self.mock_proc = self.MockProcessor("token_collector")
|
||||
|
||||
tma_in = LLMUserResponseAggregator(messages)
|
||||
tma_out = LLMAssistantResponseAggregator(messages)
|
||||
tma_out = LLMAssistantResponseAggregator(messages, expect_stripped_words=False)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
tma_in,
|
||||
proc,
|
||||
self.mock_proc,
|
||||
tma_out,
|
||||
]
|
||||
pipeline = Pipeline([tma_in, proc, self.mock_proc, tma_out])
|
||||
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"),
|
||||
SleepFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
LLMMessagesFrame,
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=False))
|
||||
await task.queue_frames(
|
||||
[
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"),
|
||||
UserStoppedSpeakingFrame(),
|
||||
EndFrame(),
|
||||
]
|
||||
)
|
||||
|
||||
runner = PipelineRunner()
|
||||
await runner.run(task)
|
||||
self.assertEqual("".join(self.mock_proc.token), self.expected_response)
|
||||
# TODO: Address this issue
|
||||
# This next one would fail with:
|
||||
# AssertionError: ' H e l l o d e a r h u m a n' != 'Hello dear human'
|
||||
# self.assertEqual(tma_out.messages[-1]["content"], self.expected_response)
|
||||
self.assertEqual(tma_out.messages[-1]["content"], self.expected_response)
|
||||
|
||||
@@ -21,11 +21,11 @@ class TestPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
pipeline = Pipeline([IdentityFilter()])
|
||||
|
||||
frames_to_send = [TextFrame(text="Hello from Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
expected_down_frames = [TextFrame]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
async def test_pipeline_multiple(self):
|
||||
@@ -36,22 +36,22 @@ class TestPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
pipeline = Pipeline([identity1, identity2, identity3])
|
||||
|
||||
frames_to_send = [TextFrame(text="Hello from Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
expected_down_frames = [TextFrame]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
async def test_pipeline_start_metadata(self):
|
||||
pipeline = Pipeline([IdentityFilter()])
|
||||
|
||||
frames_to_send = []
|
||||
expected_returned_frames = [StartFrame]
|
||||
expected_down_frames = [StartFrame]
|
||||
(received_down, _) = await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
ignore_start=False,
|
||||
start_metadata={"foo": "bar"},
|
||||
)
|
||||
@@ -63,11 +63,11 @@ class TestParallelPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
pipeline = ParallelPipeline([IdentityFilter()])
|
||||
|
||||
frames_to_send = [TextFrame(text="Hello from Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
expected_down_frames = [TextFrame]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
async def test_parallel_multiple(self):
|
||||
@@ -75,11 +75,11 @@ class TestParallelPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
pipeline = ParallelPipeline([IdentityFilter()], [IdentityFilter()])
|
||||
|
||||
frames_to_send = [TextFrame(text="Hello from Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
expected_down_frames = [TextFrame]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user