Merge pull request #1209 from pipecat-ai/aleix/reimplement-llm-response-aggregators

reimplement LLM response aggregators
This commit is contained in:
Aleix Conchillo Flaqué
2025-02-13 15:30:40 -08:00
committed by GitHub
23 changed files with 1235 additions and 508 deletions

View File

@@ -9,14 +9,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added new frames `EmulateUserStartedSpeakingFrame` and
`EmulateUserStoppedSpeakingFrame` which can be used to emulated VAD behavior
without VAD being present or not being triggered.
- Added a new `audio_in_stream_on_start` field to `TransportParams`.
- Added a new method `start_audio_in_streaming` in the `BaseInputTransport`.
- This method should be used to start receiving the input audio in case the field `audio_in_stream_on_start` is set to `false`.
- Added support for the `RTVIProcessor` to handle buffered audio in `base64` format, converting it into InputAudioRawFrame for transport.
- Added a new method `start_audio_in_streaming` in the `BaseInputTransport`.
- Added support for the `RTVIProcessor` to trigger `start_audio_in_streaming` only after the `client-ready` message.
- This method should be used to start receiving the input audio in case the
field `audio_in_stream_on_start` is set to `false`.
- Added support for the `RTVIProcessor` to handle buffered audio in `base64`
format, converting it into InputAudioRawFrame for transport.
- Added support for the `RTVIProcessor` to trigger `start_audio_in_streaming`
only after the `client-ready` message.
- Added new `MUTE_UNTIL_FIRST_BOT_COMPLETE` strategy to `STTMuteStrategy`. This
strategy starts muted and remains muted until the first bot speech completes,
@@ -38,18 +46,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
OpenAI-compatible interface. Also, added foundational example
`14n-function-calling-perplexity.py`.
- Added `DailyTransport.update_remote_participants()`. This allows you to
update remote participant's settings, like their permissions or which of
their devices are enabled. Requires that the local participant have
participant admin permission.
- Added `DailyTransport.update_remote_participants()`. This allows you to update
remote participant's settings, like their permissions or which of their
devices are enabled. Requires that the local participant have participant
admin permission.
### Changed
- Updated `DailyTransport` to respect the `audio_in_stream_on_start` field, ensuring it only starts receiving the audio input if it is enabled.
- Updated `DailyTransport` to respect the `audio_in_stream_on_start` field,
ensuring it only starts receiving the audio input if it is enabled.
- Updated `FastAPIWebsocketOutputTransport` to send `TransportMessageFrame` and `TransportMessageUrgentFrame` to the serializer.
- Updated `FastAPIWebsocketOutputTransport` to send `TransportMessageFrame` and
`TransportMessageUrgentFrame` to the serializer.
- Updated `WebsocketServerOutputTransport` to send `TransportMessageFrame` and `TransportMessageUrgentFrame` to the serializer.
- Updated `WebsocketServerOutputTransport` to send `TransportMessageFrame` and
`TransportMessageUrgentFrame` to the serializer.
- Enhanced `STTMuteConfig` to validate strategy combinations, preventing
`MUTE_UNTIL_FIRST_BOT_COMPLETE` and `FIRST_SPEECH` from being used together
@@ -91,6 +102,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Fixed a `BaseOutputTransport` issue that was causing upstream frames to no be
pushed upstream.
- Fixed multiple issue where user transcriptions where not being handled
properly. It was possible for short utterances to not trigger VAD which would
cause user transcriptions to be ignored. It was also possible for one or more
transcriptions to be generated after VAD in which case they would also be
ignored.
- Fixed an issue that was causing `BotStoppedSpeakingFrame` to be generated too
late. This could then cause issues unblocking `STTMuteFilter` later than
desired.
@@ -283,7 +303,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `enable_recording` and `geo` parameters to `DailyRoomProperties`.
- Added `RecordingsBucketConfig` to `DailyRoomProperties` to upload recordings to a custom AWS bucket.
- Added `RecordingsBucketConfig` to `DailyRoomProperties` to upload recordings
to a custom AWS bucket.
### Changed

View File

@@ -497,7 +497,7 @@ class UserAggregatorBuffer(LLMResponseAggregator):
if isinstance(frame, UserStartedSpeakingFrame):
self._transcription = ""
async def _push_aggregation(self):
async def push_aggregation(self):
if self._aggregation:
self._transcription = self._aggregation
self._aggregation = ""

View File

@@ -55,7 +55,7 @@ elevenlabs = [ "websockets~=13.1" ]
fal = [ "fal-client~=0.5.6" ]
fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
gladia = [ "websockets~=13.1" ]
google = [ "google-generativeai~=0.8.3", "google-cloud-texttospeech~=2.24.0", "google-genai~=1.0.0", "google-cloud-speech~=2.30.0" ]
google = [ "google-cloud-speech~=2.31.0", "google-cloud-texttospeech~=2.25.0", "google-genai~=1.2.0", "google-generativeai~=0.8.4" ]
grok = [ "openai~=1.59.6" ]
groq = [ "openai~=1.59.6" ]
gstreamer = [ "pygobject~=3.50.0" ]

View File

@@ -565,6 +565,22 @@ class UserStoppedSpeakingFrame(SystemFrame):
pass
@dataclass
class EmulateUserStartedSpeakingFrame(SystemFrame):
"""Emitted by internal processors upstream to emulate VAD behavior when a
user starts speaking."""
pass
@dataclass
class EmulateUserStoppedSpeakingFrame(SystemFrame):
"""Emitted by internal processors upstream to emulate VAD behavior when a
user stops speaking."""
pass
@dataclass
class BotInterruptionFrame(SystemFrame):
"""Emitted by when the bot should be interrupted. This will mainly cause the

View File

@@ -4,9 +4,17 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import List, Optional, Type
import asyncio
import time
from abc import abstractmethod
from typing import List
from pipecat.frames.frames import (
BotInterruptionFrame,
CancelFrame,
EmulateUserStartedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
EndFrame,
Frame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
@@ -15,6 +23,7 @@ from pipecat.frames.frames import (
LLMMessagesFrame,
LLMMessagesUpdateFrame,
LLMSetToolsFrame,
StartFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
@@ -28,121 +37,105 @@ from pipecat.processors.aggregators.openai_llm_context import (
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class LLMResponseAggregator(FrameProcessor):
class BaseLLMResponseAggregator(FrameProcessor):
"""This is the base class for all LLM response aggregators. These
aggregators process incoming frames and aggregate content until they are
ready to push the aggregation. In the case of a user, an aggregation might
be a full transcription received from the STT service.
The LLM response aggregators also keep a store (e.g. a message list or an
LLM context) of the current conversation, that is, it stores the messages
said by the user or by the bot.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
@property
@abstractmethod
def messages(self) -> List[dict]:
"""Returns the messages from the current conversation."""
pass
@property
@abstractmethod
def role(self) -> str:
"""Returns the role (e.g. user, assistant...) for this aggregator."""
pass
@abstractmethod
def add_messages(self, messages):
"""Add the given messages to the conversation."""
pass
@abstractmethod
def set_messages(self, messages):
"""Reset the conversation with the given messages."""
pass
@abstractmethod
def set_tools(self, tools):
"""Set LLM tools to be used in the current conversation."""
pass
@abstractmethod
def reset(self):
"""Reset the internals of this aggregator. This should not modify the
internal messages."""
pass
@abstractmethod
async def push_aggregation(self):
pass
class LLMResponseAggregator(BaseLLMResponseAggregator):
"""This is a base LLM aggregator that uses a simple list of messages to
store the conversation. It pushes `LLMMessagesFrame` as an aggregation
frame.
"""
def __init__(
self,
*,
messages: List[dict],
role: str,
start_frame,
end_frame,
accumulator_frame: Type[TextFrame],
interim_accumulator_frame: Optional[Type[TextFrame]] = None,
handle_interruptions: bool = False,
expect_stripped_words: bool = True, # if True, need to add spaces between words
role: str = "user",
**kwargs,
):
super().__init__()
super().__init__(**kwargs)
self._messages = messages
self._role = role
self._start_frame = start_frame
self._end_frame = end_frame
self._accumulator_frame = accumulator_frame
self._interim_accumulator_frame = interim_accumulator_frame
self._handle_interruptions = handle_interruptions
self._expect_stripped_words = expect_stripped_words
# Reset our accumulator state.
self._reset()
self._aggregation = ""
self.reset()
@property
def messages(self):
def messages(self) -> List[dict]:
return self._messages
@property
def role(self):
def role(self) -> str:
return self._role
#
# Frame processor
#
def add_messages(self, messages):
self._messages.extend(messages)
# Use cases implemented:
#
# S: Start, E: End, T: Transcription, I: Interim, X: Text
#
# S E -> None
# S T E -> X
# S I T E -> X
# S I E T -> X
# S I E I T -> X
# S E T -> X
# S E I T -> X
#
# The following case would not be supported:
#
# S I E T1 I T2 -> X
#
# and T2 would be dropped.
def set_messages(self, messages):
self.reset()
self._messages.clear()
self._messages.extend(messages)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
def set_tools(self, tools):
pass
send_aggregation = False
def reset(self):
self._aggregation = ""
if isinstance(frame, self._start_frame):
self._aggregation = ""
self._aggregating = True
self._seen_start_frame = True
self._seen_end_frame = False
self._seen_interim_results = False
await self.push_frame(frame, direction)
elif isinstance(frame, self._end_frame):
self._seen_end_frame = True
self._seen_start_frame = False
# We might have received the end frame but we might still be
# aggregating (i.e. we have seen interim results but not the final
# text).
self._aggregating = self._seen_interim_results or len(self._aggregation) == 0
# Send the aggregation if we are not aggregating anymore (i.e. no
# more interim results received).
send_aggregation = not self._aggregating
await self.push_frame(frame, direction)
elif isinstance(frame, self._accumulator_frame):
if self._aggregating:
if self._expect_stripped_words:
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
else:
self._aggregation += frame.text
# We have recevied a complete sentence, so if we have seen the
# end frame and we were still aggregating, it means we should
# send the aggregation.
send_aggregation = self._seen_end_frame
# We just got our final result, so let's reset interim results.
self._seen_interim_results = False
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
self._seen_interim_results = True
elif self._handle_interruptions and isinstance(frame, StartInterruptionFrame):
await self._push_aggregation()
# Reset anyways
self._reset()
await self.push_frame(frame, direction)
elif isinstance(frame, LLMMessagesAppendFrame):
self._add_messages(frame.messages)
elif isinstance(frame, LLMMessagesUpdateFrame):
self._set_messages(frame.messages)
elif isinstance(frame, LLMSetToolsFrame):
self._set_tools(frame.tools)
else:
await self.push_frame(frame, direction)
if send_aggregation:
await self._push_aggregation()
async def _push_aggregation(self):
async def push_aggregation(self):
if len(self._aggregation) > 0:
self._messages.append({"role": self._role, "content": self._aggregation})
@@ -153,109 +146,27 @@ class LLMResponseAggregator(FrameProcessor):
frame = LLMMessagesFrame(self._messages)
await self.push_frame(frame)
# TODO-CB: Types
def _add_messages(self, messages):
self._messages.extend(messages)
def _set_messages(self, messages):
self._reset()
self._messages.clear()
self._messages.extend(messages)
class LLMContextResponseAggregator(BaseLLMResponseAggregator):
"""This is a base LLM aggregator that uses an LLM context to store the
conversation. It pushes `OpenAILLMContextFrame` as an aggregation frame.
def _set_tools(self, tools):
# noop in the base class
pass
def _reset(self):
self._aggregation = ""
self._aggregating = False
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False
class LLMAssistantResponseAggregator(LLMResponseAggregator):
def __init__(self, messages: List[dict] = []):
super().__init__(
messages=messages,
role="assistant",
start_frame=LLMFullResponseStartFrame,
end_frame=LLMFullResponseEndFrame,
accumulator_frame=TextFrame,
handle_interruptions=True,
)
class LLMUserResponseAggregator(LLMResponseAggregator):
def __init__(self, messages: List[dict] = []):
super().__init__(
messages=messages,
role="user",
start_frame=UserStartedSpeakingFrame,
end_frame=UserStoppedSpeakingFrame,
accumulator_frame=TranscriptionFrame,
interim_accumulator_frame=InterimTranscriptionFrame,
)
class LLMFullResponseAggregator(FrameProcessor):
"""This class aggregates Text frames until it receives a
LLMFullResponseEndFrame, then emits the concatenated text as
a single text frame.
given the following frames:
TextFrame("Hello,")
TextFrame(" world.")
TextFrame(" I am")
TextFrame(" an LLM.")
LLMFullResponseEndFrame()]
this processor will yield nothing for the first 4 frames, then
TextFrame("Hello, world. I am an LLM.")
LLMFullResponseEndFrame()
when passed the last frame.
>>> async def print_frames(aggregator, frame):
... async for frame in aggregator.process_frame(frame):
... if isinstance(frame, TextFrame):
... print(frame.text)
... else:
... print(frame.__class__.__name__)
>>> aggregator = LLMFullResponseAggregator()
>>> asyncio.run(print_frames(aggregator, TextFrame("Hello,")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" world.")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" I am")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM.")))
>>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame()))
Hello, world. I am an LLM.
LLMFullResponseEndFrame
"""
def __init__(self):
super().__init__()
self._aggregation = ""
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TextFrame):
self._aggregation += frame.text
elif isinstance(frame, LLMFullResponseEndFrame):
await self.push_frame(TextFrame(self._aggregation))
await self.push_frame(frame)
self._aggregation = ""
else:
await self.push_frame(frame, direction)
class LLMContextAggregator(LLMResponseAggregator):
def __init__(self, *, context: OpenAILLMContext, **kwargs):
def __init__(self, *, context: OpenAILLMContext, role: str, **kwargs):
super().__init__(**kwargs)
self._context = context
self._role = role
self._aggregation = ""
@property
def messages(self) -> List[dict]:
return self._context.get_messages()
@property
def role(self) -> str:
return self._role
@property
def context(self):
@@ -264,23 +175,25 @@ class LLMContextAggregator(LLMResponseAggregator):
def get_context_frame(self) -> OpenAILLMContextFrame:
return OpenAILLMContextFrame(context=self._context)
async def push_context_frame(self):
async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM):
frame = self.get_context_frame()
await self.push_frame(frame)
await self.push_frame(frame, direction)
# TODO-CB: Types
def _add_messages(self, messages):
def add_messages(self, messages):
self._context.add_messages(messages)
def _set_messages(self, messages):
def set_messages(self, messages):
self._context.set_messages(messages)
def _set_tools(self, tools: List):
def set_tools(self, tools: List):
self._context.set_tools(tools)
async def _push_aggregation(self):
def reset(self):
self._aggregation = ""
async def push_aggregation(self):
if len(self._aggregation) > 0:
self._context.add_message({"role": self._role, "content": self._aggregation})
self._context.add_message({"role": self.role, "content": self._aggregation})
# Reset the aggregation. Reset it before pushing it down, otherwise
# if the tasks gets cancelled we won't be able to clear things up.
@@ -290,31 +203,236 @@ class LLMContextAggregator(LLMResponseAggregator):
await self.push_frame(frame)
# Reset our accumulator state.
self._reset()
self.reset()
class LLMAssistantContextAggregator(LLMContextAggregator):
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True):
super().__init__(
messages=[],
context=context,
role="assistant",
start_frame=LLMFullResponseStartFrame,
end_frame=LLMFullResponseEndFrame,
accumulator_frame=TextFrame,
handle_interruptions=True,
expect_stripped_words=expect_stripped_words,
)
class LLMUserContextAggregator(LLMContextResponseAggregator):
"""This is a user LLM aggregator that uses an LLM context to store the
conversation. It aggregates transcriptions from the STT service and it has
logic to handle multiple scenarios where transcriptions are received between
VAD events (`UserStartedSpeakingFrame` and `UserStoppedSpeakingFrame`) or
even outside or no VAD events at all.
"""
def __init__(
self,
context: OpenAILLMContext,
aggregation_timeout: float = 1.0,
bot_interruption_timeout: float = 2.0,
**kwargs,
):
super().__init__(context=context, role="user", **kwargs)
self._aggregation_timeout = aggregation_timeout
self._bot_interruption_timeout = bot_interruption_timeout
self._seen_interim_results = False
self._user_speaking = False
self._last_user_speaking_time = 0
self._emulating_vad = False
self._aggregation_event = asyncio.Event()
self._aggregation_task = None
self.reset()
def reset(self):
super().reset()
self._seen_interim_results = False
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, StartFrame):
await self._start(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, EndFrame):
await self._stop(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, CancelFrame):
await self._cancel(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, UserStartedSpeakingFrame):
await self._handle_user_started_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, UserStoppedSpeakingFrame):
await self._handle_user_stopped_speaking(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, TranscriptionFrame):
await self._handle_transcription(frame)
elif isinstance(frame, InterimTranscriptionFrame):
await self._handle_interim_transcription(frame)
elif isinstance(frame, LLMMessagesAppendFrame):
self.add_messages(frame.messages)
elif isinstance(frame, LLMMessagesUpdateFrame):
self.set_messages(frame.messages)
elif isinstance(frame, LLMSetToolsFrame):
self.set_tools(frame.tools)
else:
await self.push_frame(frame, direction)
async def _start(self, frame: StartFrame):
self._create_aggregation_task()
async def _stop(self, frame: EndFrame):
await self._cancel_aggregation_task()
async def _cancel(self, frame: CancelFrame):
await self._cancel_aggregation_task()
async def _handle_user_started_speaking(self, _: UserStartedSpeakingFrame):
self._user_speaking = True
async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
self._last_user_speaking_time = time.time()
self._user_speaking = False
if not self._seen_interim_results:
await self.push_aggregation()
async def _handle_transcription(self, frame: TranscriptionFrame):
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
# We just got a final result, so let's reset interim results.
self._seen_interim_results = False
# Reset aggregation timer.
self._aggregation_event.set()
async def _handle_interim_transcription(self, _: InterimTranscriptionFrame):
self._seen_interim_results = True
# Reset aggregation timer.
self._aggregation_event.set()
def _create_aggregation_task(self):
self._aggregation_task = self.create_task(self._aggregation_task_handler())
async def _cancel_aggregation_task(self):
if self._aggregation_task:
await self.cancel_task(self._aggregation_task)
self._aggregation_task = None
async def _aggregation_task_handler(self):
while True:
try:
await asyncio.wait_for(self._aggregation_event.wait(), self._aggregation_timeout)
await self._maybe_push_bot_interruption()
except asyncio.TimeoutError:
if not self._user_speaking:
await self.push_aggregation()
# If we are emulating VAD we still need to send the user stopped
# speaking frame.
if self._emulating_vad:
await self.push_frame(
EmulateUserStoppedSpeakingFrame(), FrameDirection.UPSTREAM
)
self._emulating_vad = False
finally:
self._aggregation_event.clear()
async def _maybe_push_bot_interruption(self):
"""If the user stopped speaking a while back and we got a transcription
frame we might want to interrupt the bot.
"""
if not self._user_speaking:
diff_time = time.time() - self._last_user_speaking_time
if diff_time > self._bot_interruption_timeout:
# If we reach this case we received a transcription but VAD was
# not able to detect voice (e.g. when you whisper a short
# utterance). So, we need to emulate VAD (i.e. user
# start/stopped speaking).
await self.push_frame(EmulateUserStartedSpeakingFrame(), FrameDirection.UPSTREAM)
self._emulating_vad = True
# Reset time so we don't interrupt again right away.
self._last_user_speaking_time = time.time()
class LLMUserContextAggregator(LLMContextAggregator):
def __init__(self, context: OpenAILLMContext):
super().__init__(
messages=[],
context=context,
role="user",
start_frame=UserStartedSpeakingFrame,
end_frame=UserStoppedSpeakingFrame,
accumulator_frame=TranscriptionFrame,
interim_accumulator_frame=InterimTranscriptionFrame,
)
class LLMAssistantContextAggregator(LLMContextResponseAggregator):
"""This is an assistant LLM aggregator that uses an LLM context to store the
conversation. It aggregates text frames received between
`LLMFullResponseStartFrame` and `LLMFullResponseEndFrame`.
"""
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True, **kwargs):
super().__init__(context=context, role="assistant", **kwargs)
self._expect_stripped_words = expect_stripped_words
self.reset()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, StartInterruptionFrame):
await self.push_aggregation()
# Reset anyways
self.reset()
await self.push_frame(frame, direction)
elif isinstance(frame, LLMFullResponseStartFrame):
await self._handle_llm_start(frame)
elif isinstance(frame, LLMFullResponseEndFrame):
await self._handle_llm_end(frame)
elif isinstance(frame, TextFrame):
await self._handle_text(frame)
elif isinstance(frame, LLMMessagesAppendFrame):
self.add_messages(frame.messages)
elif isinstance(frame, LLMMessagesUpdateFrame):
self.set_messages(frame.messages)
elif isinstance(frame, LLMSetToolsFrame):
self.set_tools(frame.tools)
else:
await self.push_frame(frame, direction)
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
self._started = True
async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
self._started = False
await self.push_aggregation()
async def _handle_text(self, frame: TextFrame):
if not self._started:
return
if self._expect_stripped_words:
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
else:
self._aggregation += frame.text
class LLMUserResponseAggregator(LLMUserContextAggregator):
def __init__(self, messages: List[dict] = [], **kwargs):
super().__init__(context=OpenAILLMContext(messages), **kwargs)
async def push_aggregation(self):
if len(self._aggregation) > 0:
self._context.add_message({"role": self.role, "content": self._aggregation})
# Reset the aggregation. Reset it before pushing it down, otherwise
# if the tasks gets cancelled we won't be able to clear things up.
self._aggregation = ""
frame = LLMMessagesFrame(self._context.messages)
await self.push_frame(frame)
# Reset our accumulator state.
self.reset()
class LLMAssistantResponseAggregator(LLMAssistantContextAggregator):
def __init__(self, messages: List[dict], **kwargs):
super().__init__(context=OpenAILLMContext(messages), **kwargs)
async def push_aggregation(self):
if len(self._aggregation) > 0:
self._context.add_message({"role": self.role, "content": self._aggregation})
# Reset the aggregation. Reset it before pushing it down, otherwise
# if the tasks gets cancelled we won't be able to clear things up.
self._aggregation = ""
frame = LLMMessagesFrame(self._context.messages)
await self.push_frame(frame)
# Reset our accumulator state.
self.reset()

View File

@@ -4,131 +4,15 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import Optional
from pipecat.frames.frames import (
Frame,
InterimTranscriptionFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.frames.frames import TextFrame
from pipecat.processors.aggregators.llm_response import LLMUserResponseAggregator
class ResponseAggregator(FrameProcessor):
"""This frame processor aggregates frames between a start and an end frame
into complete text frame sentences.
class UserResponseAggregator(LLMUserResponseAggregator):
def __init__(self, **kwargs):
super().__init__(**kwargs)
For example, frame input/output:
UserStartedSpeakingFrame() -> None
TranscriptionFrame("Hello,") -> None
TranscriptionFrame(" world.") -> None
UserStoppedSpeakingFrame() -> TextFrame("Hello world.")
Doctest: FIXME to work with asyncio
>>> async def print_frames(aggregator, frame):
... async for frame in aggregator.process_frame(frame):
... if isinstance(frame, TextFrame):
... print(frame.text)
>>> aggregator = ResponseAggregator(start_frame = UserStartedSpeakingFrame,
... end_frame=UserStoppedSpeakingFrame,
... accumulator_frame=TranscriptionFrame,
... pass_through=False)
>>> asyncio.run(print_frames(aggregator, UserStartedSpeakingFrame()))
>>> asyncio.run(print_frames(aggregator, TranscriptionFrame("Hello,", 1, 1)))
>>> asyncio.run(print_frames(aggregator, TranscriptionFrame("world.", 1, 2)))
>>> asyncio.run(print_frames(aggregator, UserStoppedSpeakingFrame()))
Hello, world.
"""
def __init__(
self,
*,
start_frame,
end_frame,
accumulator_frame: TextFrame,
interim_accumulator_frame: Optional[TextFrame] = None,
):
super().__init__()
self._start_frame = start_frame
self._end_frame = end_frame
self._accumulator_frame = accumulator_frame
self._interim_accumulator_frame = interim_accumulator_frame
# Reset our accumulator state.
self._reset()
#
# Frame processor
#
# Use cases implemented:
#
# S: Start, E: End, T: Transcription, I: Interim, X: Text
#
# S E -> None
# S T E -> X
# S I T E -> X
# S I E T -> X
# S I E I T -> X
# S E T -> X
# S E I T -> X
#
# The following case would not be supported:
#
# S I E T1 I T2 -> X
#
# and T2 would be dropped.
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
send_aggregation = False
if isinstance(frame, self._start_frame):
self._aggregating = True
self._seen_start_frame = True
self._seen_end_frame = False
self._seen_interim_results = False
await self.push_frame(frame, direction)
elif isinstance(frame, self._end_frame):
self._seen_end_frame = True
self._seen_start_frame = False
# We might have received the end frame but we might still be
# aggregating (i.e. we have seen interim results but not the final
# text).
self._aggregating = self._seen_interim_results or len(self._aggregation) == 0
# Send the aggregation if we are not aggregating anymore (i.e. no
# more interim results received).
send_aggregation = not self._aggregating
await self.push_frame(frame, direction)
elif isinstance(frame, self._accumulator_frame):
if self._aggregating:
self._aggregation += f" {frame.text}"
# We have recevied a complete sentence, so if we have seen the
# end frame and we were still aggregating, it means we should
# send the aggregation.
send_aggregation = self._seen_end_frame
# We just got our final result, so let's reset interim results.
self._seen_interim_results = False
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
self._seen_interim_results = True
else:
await self.push_frame(frame, direction)
if send_aggregation:
await self._push_aggregation()
async def _push_aggregation(self):
async def push_aggregation(self):
if len(self._aggregation) > 0:
frame = TextFrame(self._aggregation.strip())
@@ -139,21 +23,4 @@ class ResponseAggregator(FrameProcessor):
await self.push_frame(frame)
# Reset our accumulator state.
self._reset()
def _reset(self):
self._aggregation = ""
self._aggregating = False
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False
class UserResponseAggregator(ResponseAggregator):
def __init__(self):
super().__init__(
start_frame=UserStartedSpeakingFrame,
end_frame=UserStoppedSpeakingFrame,
accumulator_frame=TranscriptionFrame,
interim_accumulator_frame=InterimTranscriptionFrame,
)
self.reset()

View File

@@ -577,7 +577,7 @@ class SegmentedSTTService(STTService):
self._smoothing_factor = 0.2
self._prev_volume = 0
async def process_audio_frame(self, frame: AudioRawFrame):
async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
# Try to filter out empty background noise
volume = self._get_smoothed_volume(frame)
if volume >= self._min_volume:

View File

@@ -126,9 +126,11 @@ class AnthropicLLMService(LLMService):
def create_context_aggregator(
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
) -> AnthropicContextAggregatorPair:
if isinstance(context, OpenAILLMContext):
context = AnthropicLLMContext.from_openai_context(context)
user = AnthropicUserContextAggregator(context)
assistant = AnthropicAssistantContextAggregator(
user, expect_stripped_words=assistant_expect_stripped_words
context, expect_stripped_words=assistant_expect_stripped_words
)
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
@@ -651,11 +653,8 @@ class AnthropicLLMContext(OpenAILLMContext):
class AnthropicUserContextAggregator(LLMUserContextAggregator):
def __init__(self, context: OpenAILLMContext | AnthropicLLMContext):
super().__init__(context=context)
if isinstance(context, OpenAILLMContext):
self._context = AnthropicLLMContext.from_openai_context(context)
def __init__(self, context: OpenAILLMContext | AnthropicLLMContext, **kwargs):
super().__init__(context=context, **kwargs)
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
@@ -703,9 +702,8 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs):
super().__init__(context=user_context_aggregator._context, **kwargs)
self._user_context_aggregator = user_context_aggregator
def __init__(self, context: OpenAILLMContext | AnthropicLLMContext, **kwargs):
super().__init__(context=context, **kwargs)
self._function_call_in_progress = None
self._function_call_result = None
self._pending_image_frame_message = None
@@ -725,7 +723,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
):
self._function_call_in_progress = None
self._function_call_result = frame
await self._push_aggregation()
await self.push_aggregation()
else:
logger.warning(
"FunctionCallResultFrame tool_call_id != InProgressFrame tool_call_id"
@@ -734,9 +732,9 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
self._function_call_result = None
elif isinstance(frame, AnthropicImageMessageFrame):
self._pending_image_frame_message = frame
await self._push_aggregation()
await self.push_aggregation()
async def _push_aggregation(self):
async def push_aggregation(self):
if not (
self._aggregation or self._function_call_result or self._pending_image_frame_message
):
@@ -746,7 +744,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
properties: Optional[FunctionCallResultProperties] = None
aggregation = self._aggregation
self._reset()
self.reset()
try:
if self._function_call_result:
@@ -799,15 +797,14 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
run_llm = True
if run_llm:
await self._user_context_aggregator.push_context_frame()
await self.push_context_frame(FrameDirection.UPSTREAM)
# Emit the on_context_updated callback once the function call result is added to the context
if properties and properties.on_context_updated is not None:
await properties.on_context_updated()
# Push context frame
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
await self.push_context_frame()
# Push timestamp frame with current time
timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())

View File

@@ -115,10 +115,10 @@ class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator):
class GeminiMultimodalLiveAssistantContextAggregator(OpenAIAssistantContextAggregator):
async def _push_aggregation(self):
async def push_aggregation(self):
# We don't want to store any images in the context. Revisit this later when the API evolves.
self._pending_image_frame_message = None
await super()._push_aggregation()
await super().push_aggregation()
@dataclass
@@ -706,6 +706,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
GeminiMultimodalLiveContext.upgrade(context)
user = GeminiMultimodalLiveUserContextAggregator(context)
assistant = GeminiMultimodalLiveAssistantContextAggregator(
user, expect_stripped_words=assistant_expect_stripped_words
context, expect_stripped_words=assistant_expect_stripped_words
)
return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant)

View File

@@ -537,7 +537,7 @@ def language_to_google_stt_language(language: Language) -> Optional[str]:
class GoogleUserContextAggregator(OpenAIUserContextAggregator):
async def _push_aggregation(self):
async def push_aggregation(self):
if len(self._aggregation) > 0:
self._context.add_message(
glm.Content(role="user", parts=[glm.Part(text=self._aggregation)])
@@ -552,11 +552,11 @@ class GoogleUserContextAggregator(OpenAIUserContextAggregator):
await self.push_frame(frame)
# Reset our accumulator state.
self._reset()
self.reset()
class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
async def _push_aggregation(self):
async def push_aggregation(self):
if not (
self._aggregation or self._function_call_result or self._pending_image_frame_message
):
@@ -566,7 +566,7 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
properties: Optional[FunctionCallResultProperties] = None
aggregation = self._aggregation
self._reset()
self.reset()
try:
if self._function_call_result:
@@ -626,15 +626,14 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
run_llm = True
if run_llm:
await self._user_context_aggregator.push_context_frame()
await self.push_context_frame(FrameDirection.UPSTREAM)
# Emit the on_context_updated callback once the function call result is added to the context
if properties and properties.on_context_updated is not None:
await properties.on_context_updated()
# Push context frame
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
await self.push_context_frame()
# Push timestamp frame with current time
timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())
@@ -1175,7 +1174,7 @@ class GoogleLLMService(LLMService):
) -> GoogleContextAggregatorPair:
user = GoogleUserContextAggregator(context)
assistant = GoogleAssistantContextAggregator(
user, expect_stripped_words=assistant_expect_stripped_words
context, expect_stripped_words=assistant_expect_stripped_words
)
return GoogleContextAggregatorPair(_user=user, _assistant=assistant)

View File

@@ -17,6 +17,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.openai import (
OpenAIAssistantContextAggregator,
OpenAILLMService,
@@ -27,7 +28,7 @@ from pipecat.services.openai import (
class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
"""Custom assistant context aggregator for Grok that handles empty content requirement."""
async def _push_aggregation(self):
async def push_aggregation(self):
if not (
self._aggregation or self._function_call_result or self._pending_image_frame_message
):
@@ -37,7 +38,7 @@ class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
properties: Optional[FunctionCallResultProperties] = None
aggregation = self._aggregation
self._reset()
self.reset()
try:
if self._function_call_result:
@@ -91,14 +92,13 @@ class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
run_llm = True
if run_llm:
await self._user_context_aggregator.push_context_frame()
await self.push_context_frame(FrameDirection.UPSTREAM)
# Emit the on_context_updated callback once the function call result is added to the context
if properties and properties.on_context_updated is not None:
await properties.on_context_updated()
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
await self.push_context_frame()
except Exception as e:
logger.error(f"Error processing frame: {e}")
@@ -212,6 +212,6 @@ class GrokLLMService(OpenAILLMService):
) -> GrokContextAggregatorPair:
user = OpenAIUserContextAggregator(context)
assistant = GrokAssistantContextAggregator(
user, expect_stripped_words=assistant_expect_stripped_words
context, expect_stripped_words=assistant_expect_stripped_words
)
return GrokContextAggregatorPair(_user=user, _assistant=assistant)

View File

@@ -355,7 +355,7 @@ class OpenAILLMService(BaseOpenAILLMService):
) -> OpenAIContextAggregatorPair:
user = OpenAIUserContextAggregator(context)
assistant = OpenAIAssistantContextAggregator(
user, expect_stripped_words=assistant_expect_stripped_words
context, expect_stripped_words=assistant_expect_stripped_words
)
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
@@ -555,8 +555,8 @@ class OpenAIImageMessageFrame(Frame):
class OpenAIUserContextAggregator(LLMUserContextAggregator):
def __init__(self, context: OpenAILLMContext):
super().__init__(context=context)
def __init__(self, context: OpenAILLMContext, **kwargs):
super().__init__(context=context, **kwargs)
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
@@ -592,9 +592,8 @@ class OpenAIUserContextAggregator(LLMUserContextAggregator):
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs):
super().__init__(context=user_context_aggregator._context, **kwargs)
self._user_context_aggregator = user_context_aggregator
def __init__(self, context: OpenAILLMContext, **kwargs):
super().__init__(context=context, **kwargs)
self._function_calls_in_progress = {}
self._function_call_result = None
self._pending_image_frame_message = None
@@ -614,7 +613,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
del self._function_calls_in_progress[frame.tool_call_id]
self._function_call_result = frame
# TODO-CB: Kwin wants us to refactor this out of here but I REFUSE
await self._push_aggregation()
await self.push_aggregation()
else:
logger.warning(
"FunctionCallResultFrame tool_call_id does not match any function call in progress"
@@ -622,9 +621,9 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
self._function_call_result = None
elif isinstance(frame, OpenAIImageMessageFrame):
self._pending_image_frame_message = frame
await self._push_aggregation()
await self.push_aggregation()
async def _push_aggregation(self):
async def push_aggregation(self):
if not (
self._aggregation or self._function_call_result or self._pending_image_frame_message
):
@@ -634,7 +633,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
properties: Optional[FunctionCallResultProperties] = None
aggregation = self._aggregation
self._reset()
self.reset()
try:
if self._function_call_result:
@@ -686,15 +685,14 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
run_llm = True
if run_llm:
await self._user_context_aggregator.push_context_frame()
await self.push_context_frame(FrameDirection.UPSTREAM)
# Emit the on_context_updated callback once the function call result is added to the context
if properties and properties.on_context_updated is not None:
await properties.on_context_updated()
# Push context frame
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
await self.push_context_frame()
# Push timestamp frame with current time
timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601())

View File

@@ -166,7 +166,7 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
if isinstance(frame, LLMSetToolsFrame):
await self.push_frame(frame, direction)
async def _push_aggregation(self):
async def push_aggregation(self):
# for the moment, ignore all user input coming into the pipeline.
# todo: think about whether/how to fix this to allow for text input from
# upstream (transport/transcription, or other sources)
@@ -174,7 +174,7 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
async def _push_aggregation(self):
async def push_aggregation(self):
# the only thing we implement here is function calling. in all other cases, messages
# are added to the context when we receive openai realtime api events
if not self._function_call_result:
@@ -182,7 +182,7 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator)
properties: Optional[FunctionCallResultProperties] = None
self._reset()
self.reset()
try:
run_llm = True
frame = self._function_call_result
@@ -217,8 +217,8 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator)
# The standard function callback code path pushes the FunctionCallResultFrame from the llm itself,
# so we didn't have a chance to add the result to the openai realtime api context. Let's push a
# special frame to do that.
await self._user_context_aggregator.push_frame(
RealtimeFunctionCallResultFrame(result_frame=frame)
await self.push_frame(
RealtimeFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM
)
if properties and properties.run_llm is not None:
# If the tool call result has a run_llm property, use it
@@ -228,14 +228,13 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator)
run_llm = not bool(self._function_calls_in_progress)
if run_llm:
await self._user_context_aggregator.push_context_frame()
await self.push_context_frame(FrameDirection.UPSTREAM)
# Emit the on_context_updated callback once the function call result is added to the context
if properties and properties.on_context_updated is not None:
await properties.on_context_updated()
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
await self.push_context_frame()
except Exception as e:
logger.error(f"Error processing frame: {e}")

View File

@@ -568,6 +568,6 @@ class OpenAIRealtimeBetaLLMService(LLMService):
OpenAIRealtimeLLMContext.upgrade_to_realtime(context)
user = OpenAIRealtimeUserContextAggregator(context)
assistant = OpenAIRealtimeAssistantContextAggregator(
user, expect_stripped_words=assistant_expect_stripped_words
context, expect_stripped_words=assistant_expect_stripped_words
)
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)

View File

@@ -5,6 +5,7 @@
#
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Sequence, Tuple
from pipecat.frames.frames import (
@@ -12,6 +13,7 @@ from pipecat.frames.frames import (
Frame,
HeartbeatFrame,
StartFrame,
SystemFrame,
)
from pipecat.observers.base_observer import BaseObserver
from pipecat.pipeline.pipeline import Pipeline
@@ -20,6 +22,15 @@ from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
@dataclass
class SleepFrame(SystemFrame):
"""This frame is used by test framework to introduce some sleep time before
the next frame is pushed. This is useful to control system frames vs data or
control frames."""
sleep: float = 0.1
class HeartbeatsObserver(BaseObserver):
def __init__(
self,
@@ -44,7 +55,11 @@ class HeartbeatsObserver(BaseObserver):
class QueuedFrameProcessor(FrameProcessor):
def __init__(
self, queue: asyncio.Queue, queue_direction: FrameDirection, ignore_start: bool = True
self,
*,
queue: asyncio.Queue,
queue_direction: FrameDirection,
ignore_start: bool = True,
):
super().__init__()
self._queue = queue
@@ -72,21 +87,35 @@ async def run_test(
) -> Tuple[Sequence[Frame], Sequence[Frame]]:
received_up = asyncio.Queue()
received_down = asyncio.Queue()
source = QueuedFrameProcessor(received_up, FrameDirection.UPSTREAM, ignore_start)
sink = QueuedFrameProcessor(received_down, FrameDirection.DOWNSTREAM, ignore_start)
source = QueuedFrameProcessor(
queue=received_up,
queue_direction=FrameDirection.UPSTREAM,
ignore_start=ignore_start,
)
sink = QueuedFrameProcessor(
queue=received_down,
queue_direction=FrameDirection.DOWNSTREAM,
ignore_start=ignore_start,
)
pipeline = Pipeline([source, processor, sink])
task = PipelineTask(pipeline, params=PipelineParams(start_metadata=start_metadata))
for frame in frames_to_send:
await task.queue_frame(frame)
async def push_frames():
# Just give a little head start to the runner.
await asyncio.sleep(0.01)
for frame in frames_to_send:
if isinstance(frame, SleepFrame):
await asyncio.sleep(frame.sleep)
else:
await task.queue_frame(frame)
if send_end_frame:
await task.queue_frame(EndFrame())
if send_end_frame:
await task.queue_frame(EndFrame())
runner = PipelineRunner()
await runner.run(task)
await asyncio.gather(runner.run(task), push_frames())
#
# Down frames
@@ -98,6 +127,7 @@ async def run_test(
received_down_frames.append(frame)
print("received DOWN frames =", received_down_frames)
print("expected DOWN frames =", expected_down_frames)
assert len(received_down_frames) == len(expected_down_frames)
@@ -113,6 +143,7 @@ async def run_test(
received_up_frames.append(frame)
print("received UP frames =", received_up_frames)
print("expected UP frames =", expected_up_frames)
assert len(received_up_frames) == len(expected_up_frames)

View File

@@ -14,6 +14,8 @@ from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState
from pipecat.frames.frames import (
BotInterruptionFrame,
CancelFrame,
EmulateUserStartedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
EndFrame,
FilterUpdateSettingsFrame,
Frame,
@@ -112,9 +114,13 @@ class BaseInputTransport(FrameProcessor):
await self.cancel(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, BotInterruptionFrame):
logger.debug("Bot interruption")
await self._start_interruption()
await self.push_frame(StartInterruptionFrame())
await self._handle_bot_interruption(frame)
elif isinstance(frame, EmulateUserStartedSpeakingFrame):
logger.debug("Emulating user started speaking")
await self._handle_user_interruption(UserStartedSpeakingFrame())
elif isinstance(frame, EmulateUserStoppedSpeakingFrame):
logger.debug("Emulating user stopped speaking")
await self._handle_user_interruption(UserStoppedSpeakingFrame())
# All other system frames
elif isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
@@ -137,7 +143,13 @@ class BaseInputTransport(FrameProcessor):
# Handle interruptions
#
async def _handle_interruptions(self, frame: Frame):
async def _handle_bot_interruption(self, frame: BotInterruptionFrame):
logger.debug("Bot interruption")
if self.interruptions_allowed:
await self._start_interruption()
await self.push_frame(StartInterruptionFrame())
async def _handle_user_interruption(self, frame: Frame):
if isinstance(frame, UserStartedSpeakingFrame):
logger.debug("User started speaking")
# Make sure we notify about interruptions quickly out-of-band.
@@ -183,7 +195,7 @@ class BaseInputTransport(FrameProcessor):
frame = UserStoppedSpeakingFrame()
if frame:
await self._handle_interruptions(frame)
await self._handle_user_interruption(frame)
vad_state = new_vad_state
return vad_state

View File

@@ -170,6 +170,8 @@ class BaseOutputTransport(FrameProcessor):
# TODO(aleix): Images and audio should support presentation timestamps.
elif frame.pts:
await self._sink_clock_queue.put((frame.pts, frame.id, frame))
elif direction == FrameDirection.UPSTREAM:
await self.push_frame(frame, direction)
else:
await self._sink_queue.put(frame)

View File

@@ -7,8 +7,10 @@ deepgram-sdk~=3.5.0
fal-client~=0.4.1
fastapi~=0.115.0
faster-whisper~=1.0.3
google-cloud-texttospeech~=2.21.1
google-generativeai~=0.8.3
google-cloud-speech~=2.31.0
google-cloud-texttospeech~=2.25.0
google-genai~=1.2.0
google-generativeai~=0.8.4
langchain~=0.2.14
livekit~=0.13.1
lmnt~=1.1.4

View File

@@ -29,12 +29,12 @@ class TestSentenceAggregator(unittest.IsolatedAsyncioTestCase):
for word in sentence.split(" "):
frames_to_send.append(TextFrame(text=word + " "))
expected_returned_frames = [TextFrame, TextFrame, TextFrame]
expected_down_frames = [TextFrame, TextFrame, TextFrame]
(received_down, _) = await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
expected_down_frames=expected_down_frames,
)
assert received_down[-3].text == "Hello, world. "
assert received_down[-2].text == "How are you? "
@@ -59,7 +59,7 @@ class TestGatedAggregator(unittest.IsolatedAsyncioTestCase):
LLMFullResponseEndFrame(),
]
expected_returned_frames = [
expected_down_frames = [
OutputImageRawFrame,
LLMFullResponseStartFrame,
TextFrame,
@@ -72,5 +72,5 @@ class TestGatedAggregator(unittest.IsolatedAsyncioTestCase):
(received_down, _) = await run_test(
gated_aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
expected_down_frames=expected_down_frames,
)

View File

@@ -0,0 +1,671 @@
#
# Copyright (c) 2024-2025 Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
import google.ai.generativelanguage as glm
from pipecat.frames.frames import (
EmulateUserStartedSpeakingFrame,
EmulateUserStoppedSpeakingFrame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
OpenAILLMContextAssistantTimestampFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.aggregators.llm_response import (
LLMAssistantContextAggregator,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.services.anthropic import (
AnthropicAssistantContextAggregator,
AnthropicLLMContext,
AnthropicUserContextAggregator,
)
from pipecat.services.google.google import (
GoogleAssistantContextAggregator,
GoogleLLMContext,
GoogleUserContextAggregator,
)
from pipecat.services.openai import OpenAIAssistantContextAggregator, OpenAIUserContextAggregator
from pipecat.tests.utils import SleepFrame, run_test
AGGREGATION_TIMEOUT = 0.1
AGGREGATION_SLEEP = 0.15
BOT_INTERRUPTION_TIMEOUT = 0.2
BOT_INTERRUPTION_SLEEP = 0.25
class BaseTestUserContextAggregator:
CONTEXT_CLASS = None # To be set in subclasses
AGGREGATOR_CLASS = None # To be set in subclasses
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
assert context.messages[index]["content"] == content
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
assert context.messages[index]["content"] == content
async def test_se(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [UserStartedSpeakingFrame(), UserStoppedSpeakingFrame()]
expected_down_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
async def test_ste(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""),
SleepFrame(),
UserStoppedSpeakingFrame(),
]
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_content(context, 0, "Hello!")
async def test_site(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
UserStartedSpeakingFrame(),
InterimTranscriptionFrame(text="Hello", user_id="cat", timestamp=""),
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
SleepFrame(),
UserStoppedSpeakingFrame(),
]
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_content(context, 0, "Hello Pipecat!")
async def test_st1iest2e(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
SleepFrame(),
UserStoppedSpeakingFrame(),
UserStartedSpeakingFrame(),
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
SleepFrame(),
UserStoppedSpeakingFrame(),
]
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_content(context, 0, "Hello Pipecat! How are you?")
async def test_siet(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
frames_to_send = [
UserStartedSpeakingFrame(),
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
SleepFrame(),
UserStoppedSpeakingFrame(),
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
SleepFrame(sleep=AGGREGATION_SLEEP),
]
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_content(context, 0, "How are you?")
async def test_sieit(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
frames_to_send = [
UserStartedSpeakingFrame(),
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
SleepFrame(),
UserStoppedSpeakingFrame(),
InterimTranscriptionFrame(text="are you?", user_id="cat", timestamp=""),
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
SleepFrame(sleep=AGGREGATION_SLEEP),
]
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_content(context, 0, "How are you?")
async def test_set(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
frames_to_send = [
UserStartedSpeakingFrame(),
UserStoppedSpeakingFrame(),
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
SleepFrame(sleep=AGGREGATION_SLEEP),
]
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_content(context, 0, "How are you?")
async def test_seit(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
frames_to_send = [
UserStartedSpeakingFrame(),
UserStoppedSpeakingFrame(),
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
SleepFrame(sleep=AGGREGATION_SLEEP),
]
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_content(context, 0, "How are you?")
async def test_st1et2(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
frames_to_send = [
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
SleepFrame(),
UserStoppedSpeakingFrame(),
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
SleepFrame(sleep=AGGREGATION_SLEEP),
]
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_multi_content(context, 0, 0, "Hello Pipecat!")
self.check_message_multi_content(context, 0, 1, "How are you?")
async def test_set1t2(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
frames_to_send = [
UserStartedSpeakingFrame(),
UserStoppedSpeakingFrame(),
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
SleepFrame(sleep=AGGREGATION_SLEEP),
]
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_content(context, 0, "Hello Pipecat! How are you?")
async def test_siet1it2(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
frames_to_send = [
UserStartedSpeakingFrame(),
InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""),
SleepFrame(),
UserStoppedSpeakingFrame(),
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
SleepFrame(sleep=AGGREGATION_SLEEP),
]
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_content(context, 0, "Hello Pipecat! How are you?")
async def test_t(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
frames_to_send = [
TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""),
SleepFrame(sleep=AGGREGATION_SLEEP),
]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
expected_up_frames=expected_up_frames,
)
self.check_message_content(context, 0, "Hello!")
async def test_it(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT)
frames_to_send = [
InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""),
SleepFrame(),
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
SleepFrame(sleep=AGGREGATION_SLEEP),
]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
expected_up_frames=expected_up_frames,
)
self.check_message_content(context, 0, "Hello Pipecat!")
async def test_sie_delay_it(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(
context,
aggregation_timeout=AGGREGATION_TIMEOUT,
bot_interruption_timeout=BOT_INTERRUPTION_TIMEOUT,
)
frames_to_send = [
UserStartedSpeakingFrame(),
InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""),
SleepFrame(),
UserStoppedSpeakingFrame(),
SleepFrame(BOT_INTERRUPTION_SLEEP),
InterimTranscriptionFrame(text="are you?", user_id="cat", timestamp=""),
TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
SleepFrame(sleep=AGGREGATION_SLEEP),
]
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
expected_up_frames=expected_up_frames,
)
self.check_message_content(context, 0, "How are you?")
class BaseTestAssistantContextAggreagator:
CONTEXT_CLASS = None # To be set in subclasses
AGGREGATOR_CLASS = None # To be set in subclasses
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
assert context.messages[index]["content"] == content
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
assert context.messages[index]["content"] == content
async def test_empty(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()]
expected_down_frames = []
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
async def test_single_text(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
LLMFullResponseStartFrame(),
TextFrame(text="Hello Pipecat!"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_content(context, 0, "Hello Pipecat!")
async def test_multiple_text(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
frames_to_send = [
LLMFullResponseStartFrame(),
TextFrame(text="Hello "),
TextFrame(text="Pipecat. "),
TextFrame(text="How are "),
TextFrame(text="you?"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_content(context, 0, "Hello Pipecat. How are you?")
async def test_multiple_text_stripped(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
LLMFullResponseStartFrame(),
TextFrame(text="Hello"),
TextFrame(text="Pipecat."),
TextFrame(text="How are"),
TextFrame(text="you?"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_content(context, 0, "Hello Pipecat. How are you?")
async def test_multiple_llm_responses(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
frames_to_send = [
LLMFullResponseStartFrame(),
TextFrame(text="Hello "),
TextFrame(text="Pipecat."),
LLMFullResponseEndFrame(),
LLMFullResponseStartFrame(),
TextFrame(text="How are "),
TextFrame(text="you?"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES, *self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_multi_content(context, 0, 0, "Hello Pipecat.")
self.check_message_multi_content(context, 0, 1, "How are you?")
async def test_multiple_llm_responses_interruption(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False)
frames_to_send = [
LLMFullResponseStartFrame(),
TextFrame(text="Hello "),
TextFrame(text="Pipecat."),
LLMFullResponseEndFrame(),
SleepFrame(AGGREGATION_SLEEP),
StartInterruptionFrame(),
LLMFullResponseStartFrame(),
TextFrame(text="How are "),
TextFrame(text="you?"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [
*self.EXPECTED_CONTEXT_FRAMES,
StartInterruptionFrame,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_multi_content(context, 0, 0, "Hello Pipecat.")
self.check_message_multi_content(context, 0, 1, "How are you?")
#
# LLMUserContextAggregator, LLMAssistantContextAggregator
#
class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase):
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = LLMUserContextAggregator
class TestLLMAssistantContextAggregator(
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = LLMAssistantContextAggregator
#
# OpenAI
#
class TestOpenAIUserContextAggregator(
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = OpenAIUserContextAggregator
class TestOpenAIAssistantContextAggregator(
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
#
# Anthropic
#
class TestAnthropicUserContextAggregator(
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = AnthropicLLMContext
AGGREGATOR_CLASS = AnthropicUserContextAggregator
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
messages = context.messages[content_index]
assert messages["content"][index]["text"] == content
class TestAnthropicAssistantContextAggregator(
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = AnthropicLLMContext
AGGREGATOR_CLASS = AnthropicAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
messages = context.messages[content_index]
assert messages["content"][index]["text"] == content
#
# Google
#
class TestGoogleUserContextAggregator(
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = GoogleLLMContext
AGGREGATOR_CLASS = GoogleUserContextAggregator
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
obj = glm.Content.to_dict(context.messages[index])
assert obj["parts"][0]["text"] == content
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
obj = glm.Content.to_dict(context.messages[index])
assert obj["parts"][0]["text"] == content
class TestGoogleAssistantContextAggregator(
BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = GoogleLLMContext
AGGREGATOR_CLASS = GoogleAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
obj = glm.Content.to_dict(context.messages[index])
assert obj["parts"][0]["text"] == content
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
):
obj = glm.Content.to_dict(context.messages[index])
assert obj["parts"][0]["text"] == content

View File

@@ -25,11 +25,11 @@ class TestIdentifyFilter(unittest.IsolatedAsyncioTestCase):
async def test_identity(self):
filter = IdentityFilter()
frames_to_send = [UserStartedSpeakingFrame(), UserStoppedSpeakingFrame()]
expected_returned_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame]
expected_down_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame]
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
expected_down_frames=expected_down_frames,
)
@@ -37,32 +37,32 @@ class TestFrameFilter(unittest.IsolatedAsyncioTestCase):
async def test_text_frame(self):
filter = FrameFilter(types=(TextFrame,))
frames_to_send = [TextFrame(text="Hello Pipecat!")]
expected_returned_frames = [TextFrame]
expected_down_frames = [TextFrame]
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
expected_down_frames=expected_down_frames,
)
async def test_end_frame(self):
filter = FrameFilter(types=(EndFrame,))
frames_to_send = [EndFrame()]
expected_returned_frames = [EndFrame]
expected_down_frames = [EndFrame]
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
expected_down_frames=expected_down_frames,
send_end_frame=False,
)
async def test_system_frame(self):
filter = FrameFilter(types=())
frames_to_send = [UserStartedSpeakingFrame()]
expected_returned_frames = [UserStartedSpeakingFrame]
expected_down_frames = [UserStartedSpeakingFrame]
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
expected_down_frames=expected_down_frames,
)
@@ -73,11 +73,11 @@ class TestFunctionFilter(unittest.IsolatedAsyncioTestCase):
filter = FunctionFilter(filter=passthrough)
frames_to_send = [TextFrame(text="Hello Pipecat!")]
expected_returned_frames = [TextFrame]
expected_down_frames = [TextFrame]
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
expected_down_frames=expected_down_frames,
)
async def test_no_passthrough(self):
@@ -86,11 +86,11 @@ class TestFunctionFilter(unittest.IsolatedAsyncioTestCase):
filter = FunctionFilter(filter=no_passthrough)
frames_to_send = [TextFrame(text="Hello Pipecat!")]
expected_returned_frames = []
expected_down_frames = []
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
expected_down_frames=expected_down_frames,
)
@@ -98,11 +98,11 @@ class TestWakeCheckFilter(unittest.IsolatedAsyncioTestCase):
async def test_no_wake_word(self):
filter = WakeCheckFilter(wake_phrases=["Hey, Pipecat"])
frames_to_send = [TranscriptionFrame(user_id="test", text="Phrase 1", timestamp="")]
expected_returned_frames = []
expected_down_frames = []
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
expected_down_frames=expected_down_frames,
)
async def test_wake_word(self):
@@ -111,10 +111,10 @@ class TestWakeCheckFilter(unittest.IsolatedAsyncioTestCase):
TranscriptionFrame(user_id="test", text="Hey, Pipecat", timestamp=""),
TranscriptionFrame(user_id="test", text="Phrase 1", timestamp=""),
]
expected_returned_frames = [TranscriptionFrame, TranscriptionFrame]
expected_down_frames = [TranscriptionFrame, TranscriptionFrame]
(received_down, _) = await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
expected_down_frames=expected_down_frames,
)
assert received_down[-1].text == "Phrase 1"

View File

@@ -10,23 +10,22 @@ from langchain.prompts import ChatPromptTemplate
from langchain_core.language_models import FakeStreamingListLLM
from pipecat.frames.frames import (
EndFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
TextFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_response import (
LLMAssistantResponseAggregator,
LLMUserResponseAggregator,
)
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.processors.frameworks.langchain import LangchainProcessor
from pipecat.tests.utils import SleepFrame, run_test
class TestLangchain(unittest.IsolatedAsyncioTestCase):
@@ -64,31 +63,26 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase):
self.mock_proc = self.MockProcessor("token_collector")
tma_in = LLMUserResponseAggregator(messages)
tma_out = LLMAssistantResponseAggregator(messages)
tma_out = LLMAssistantResponseAggregator(messages, expect_stripped_words=False)
pipeline = Pipeline(
[
tma_in,
proc,
self.mock_proc,
tma_out,
]
pipeline = Pipeline([tma_in, proc, self.mock_proc, tma_out])
frames_to_send = [
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"),
SleepFrame(),
UserStoppedSpeakingFrame(),
]
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
LLMMessagesFrame,
]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=False))
await task.queue_frames(
[
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"),
UserStoppedSpeakingFrame(),
EndFrame(),
]
)
runner = PipelineRunner()
await runner.run(task)
self.assertEqual("".join(self.mock_proc.token), self.expected_response)
# TODO: Address this issue
# This next one would fail with:
# AssertionError: ' H e l l o d e a r h u m a n' != 'Hello dear human'
# self.assertEqual(tma_out.messages[-1]["content"], self.expected_response)
self.assertEqual(tma_out.messages[-1]["content"], self.expected_response)

View File

@@ -21,11 +21,11 @@ class TestPipeline(unittest.IsolatedAsyncioTestCase):
pipeline = Pipeline([IdentityFilter()])
frames_to_send = [TextFrame(text="Hello from Pipecat!")]
expected_returned_frames = [TextFrame]
expected_down_frames = [TextFrame]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
expected_down_frames=expected_down_frames,
)
async def test_pipeline_multiple(self):
@@ -36,22 +36,22 @@ class TestPipeline(unittest.IsolatedAsyncioTestCase):
pipeline = Pipeline([identity1, identity2, identity3])
frames_to_send = [TextFrame(text="Hello from Pipecat!")]
expected_returned_frames = [TextFrame]
expected_down_frames = [TextFrame]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
expected_down_frames=expected_down_frames,
)
async def test_pipeline_start_metadata(self):
pipeline = Pipeline([IdentityFilter()])
frames_to_send = []
expected_returned_frames = [StartFrame]
expected_down_frames = [StartFrame]
(received_down, _) = await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
expected_down_frames=expected_down_frames,
ignore_start=False,
start_metadata={"foo": "bar"},
)
@@ -63,11 +63,11 @@ class TestParallelPipeline(unittest.IsolatedAsyncioTestCase):
pipeline = ParallelPipeline([IdentityFilter()])
frames_to_send = [TextFrame(text="Hello from Pipecat!")]
expected_returned_frames = [TextFrame]
expected_down_frames = [TextFrame]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
expected_down_frames=expected_down_frames,
)
async def test_parallel_multiple(self):
@@ -75,11 +75,11 @@ class TestParallelPipeline(unittest.IsolatedAsyncioTestCase):
pipeline = ParallelPipeline([IdentityFilter()], [IdentityFilter()])
frames_to_send = [TextFrame(text="Hello from Pipecat!")]
expected_returned_frames = [TextFrame]
expected_down_frames = [TextFrame]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
expected_down_frames=expected_down_frames,
)