diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f9cd7e0c..ca7d08e3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,14 +9,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added new frames `EmulateUserStartedSpeakingFrame` and + `EmulateUserStoppedSpeakingFrame` which can be used to emulated VAD behavior + without VAD being present or not being triggered. + - Added a new `audio_in_stream_on_start` field to `TransportParams`. - -- Added a new method `start_audio_in_streaming` in the `BaseInputTransport`. - - This method should be used to start receiving the input audio in case the field `audio_in_stream_on_start` is set to `false`. -- Added support for the `RTVIProcessor` to handle buffered audio in `base64` format, converting it into InputAudioRawFrame for transport. +- Added a new method `start_audio_in_streaming` in the `BaseInputTransport`. -- Added support for the `RTVIProcessor` to trigger `start_audio_in_streaming` only after the `client-ready` message. + - This method should be used to start receiving the input audio in case the + field `audio_in_stream_on_start` is set to `false`. + +- Added support for the `RTVIProcessor` to handle buffered audio in `base64` + format, converting it into InputAudioRawFrame for transport. + +- Added support for the `RTVIProcessor` to trigger `start_audio_in_streaming` + only after the `client-ready` message. - Added new `MUTE_UNTIL_FIRST_BOT_COMPLETE` strategy to `STTMuteStrategy`. This strategy starts muted and remains muted until the first bot speech completes, @@ -38,18 +46,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 OpenAI-compatible interface. Also, added foundational example `14n-function-calling-perplexity.py`. -- Added `DailyTransport.update_remote_participants()`. This allows you to - update remote participant's settings, like their permissions or which of - their devices are enabled. Requires that the local participant have - participant admin permission. +- Added `DailyTransport.update_remote_participants()`. This allows you to update + remote participant's settings, like their permissions or which of their + devices are enabled. Requires that the local participant have participant + admin permission. ### Changed -- Updated `DailyTransport` to respect the `audio_in_stream_on_start` field, ensuring it only starts receiving the audio input if it is enabled. +- Updated `DailyTransport` to respect the `audio_in_stream_on_start` field, + ensuring it only starts receiving the audio input if it is enabled. -- Updated `FastAPIWebsocketOutputTransport` to send `TransportMessageFrame` and `TransportMessageUrgentFrame` to the serializer. +- Updated `FastAPIWebsocketOutputTransport` to send `TransportMessageFrame` and + `TransportMessageUrgentFrame` to the serializer. -- Updated `WebsocketServerOutputTransport` to send `TransportMessageFrame` and `TransportMessageUrgentFrame` to the serializer. +- Updated `WebsocketServerOutputTransport` to send `TransportMessageFrame` and + `TransportMessageUrgentFrame` to the serializer. - Enhanced `STTMuteConfig` to validate strategy combinations, preventing `MUTE_UNTIL_FIRST_BOT_COMPLETE` and `FIRST_SPEECH` from being used together @@ -91,6 +102,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed a `BaseOutputTransport` issue that was causing upstream frames to no be + pushed upstream. + +- Fixed multiple issue where user transcriptions where not being handled + properly. It was possible for short utterances to not trigger VAD which would + cause user transcriptions to be ignored. It was also possible for one or more + transcriptions to be generated after VAD in which case they would also be + ignored. + - Fixed an issue that was causing `BotStoppedSpeakingFrame` to be generated too late. This could then cause issues unblocking `STTMuteFilter` later than desired. @@ -283,7 +303,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `enable_recording` and `geo` parameters to `DailyRoomProperties`. -- Added `RecordingsBucketConfig` to `DailyRoomProperties` to upload recordings to a custom AWS bucket. +- Added `RecordingsBucketConfig` to `DailyRoomProperties` to upload recordings + to a custom AWS bucket. ### Changed diff --git a/examples/foundational/22d-natural-conversation-gemini-audio.py b/examples/foundational/22d-natural-conversation-gemini-audio.py index 860f1f574..15ee835e2 100644 --- a/examples/foundational/22d-natural-conversation-gemini-audio.py +++ b/examples/foundational/22d-natural-conversation-gemini-audio.py @@ -497,7 +497,7 @@ class UserAggregatorBuffer(LLMResponseAggregator): if isinstance(frame, UserStartedSpeakingFrame): self._transcription = "" - async def _push_aggregation(self): + async def push_aggregation(self): if self._aggregation: self._transcription = self._aggregation self._aggregation = "" diff --git a/pyproject.toml b/pyproject.toml index 0578c9622..3c5d6262b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ elevenlabs = [ "websockets~=13.1" ] fal = [ "fal-client~=0.5.6" ] fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ] gladia = [ "websockets~=13.1" ] -google = [ "google-generativeai~=0.8.3", "google-cloud-texttospeech~=2.24.0", "google-genai~=1.0.0", "google-cloud-speech~=2.30.0" ] +google = [ "google-cloud-speech~=2.31.0", "google-cloud-texttospeech~=2.25.0", "google-genai~=1.2.0", "google-generativeai~=0.8.4" ] grok = [ "openai~=1.59.6" ] groq = [ "openai~=1.59.6" ] gstreamer = [ "pygobject~=3.50.0" ] diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 09d0a93b0..c9b812a1c 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -565,6 +565,22 @@ class UserStoppedSpeakingFrame(SystemFrame): pass +@dataclass +class EmulateUserStartedSpeakingFrame(SystemFrame): + """Emitted by internal processors upstream to emulate VAD behavior when a + user starts speaking.""" + + pass + + +@dataclass +class EmulateUserStoppedSpeakingFrame(SystemFrame): + """Emitted by internal processors upstream to emulate VAD behavior when a + user stops speaking.""" + + pass + + @dataclass class BotInterruptionFrame(SystemFrame): """Emitted by when the bot should be interrupted. This will mainly cause the diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 78db72dfb..950c155e6 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -4,9 +4,17 @@ # SPDX-License-Identifier: BSD 2-Clause License # -from typing import List, Optional, Type +import asyncio +import time +from abc import abstractmethod +from typing import List from pipecat.frames.frames import ( + BotInterruptionFrame, + CancelFrame, + EmulateUserStartedSpeakingFrame, + EmulateUserStoppedSpeakingFrame, + EndFrame, Frame, InterimTranscriptionFrame, LLMFullResponseEndFrame, @@ -15,6 +23,7 @@ from pipecat.frames.frames import ( LLMMessagesFrame, LLMMessagesUpdateFrame, LLMSetToolsFrame, + StartFrame, StartInterruptionFrame, TextFrame, TranscriptionFrame, @@ -28,121 +37,105 @@ from pipecat.processors.aggregators.openai_llm_context import ( from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -class LLMResponseAggregator(FrameProcessor): +class BaseLLMResponseAggregator(FrameProcessor): + """This is the base class for all LLM response aggregators. These + aggregators process incoming frames and aggregate content until they are + ready to push the aggregation. In the case of a user, an aggregation might + be a full transcription received from the STT service. + + The LLM response aggregators also keep a store (e.g. a message list or an + LLM context) of the current conversation, that is, it stores the messages + said by the user or by the bot. + + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + @property + @abstractmethod + def messages(self) -> List[dict]: + """Returns the messages from the current conversation.""" + pass + + @property + @abstractmethod + def role(self) -> str: + """Returns the role (e.g. user, assistant...) for this aggregator.""" + pass + + @abstractmethod + def add_messages(self, messages): + """Add the given messages to the conversation.""" + pass + + @abstractmethod + def set_messages(self, messages): + """Reset the conversation with the given messages.""" + pass + + @abstractmethod + def set_tools(self, tools): + """Set LLM tools to be used in the current conversation.""" + pass + + @abstractmethod + def reset(self): + """Reset the internals of this aggregator. This should not modify the + internal messages.""" + pass + + @abstractmethod + async def push_aggregation(self): + pass + + +class LLMResponseAggregator(BaseLLMResponseAggregator): + """This is a base LLM aggregator that uses a simple list of messages to + store the conversation. It pushes `LLMMessagesFrame` as an aggregation + frame. + + """ + def __init__( self, *, messages: List[dict], - role: str, - start_frame, - end_frame, - accumulator_frame: Type[TextFrame], - interim_accumulator_frame: Optional[Type[TextFrame]] = None, - handle_interruptions: bool = False, - expect_stripped_words: bool = True, # if True, need to add spaces between words + role: str = "user", + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self._messages = messages self._role = role - self._start_frame = start_frame - self._end_frame = end_frame - self._accumulator_frame = accumulator_frame - self._interim_accumulator_frame = interim_accumulator_frame - self._handle_interruptions = handle_interruptions - self._expect_stripped_words = expect_stripped_words - # Reset our accumulator state. - self._reset() + self._aggregation = "" + + self.reset() @property - def messages(self): + def messages(self) -> List[dict]: return self._messages @property - def role(self): + def role(self) -> str: return self._role - # - # Frame processor - # + def add_messages(self, messages): + self._messages.extend(messages) - # Use cases implemented: - # - # S: Start, E: End, T: Transcription, I: Interim, X: Text - # - # S E -> None - # S T E -> X - # S I T E -> X - # S I E T -> X - # S I E I T -> X - # S E T -> X - # S E I T -> X - # - # The following case would not be supported: - # - # S I E T1 I T2 -> X - # - # and T2 would be dropped. + def set_messages(self, messages): + self.reset() + self._messages.clear() + self._messages.extend(messages) - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) + def set_tools(self, tools): + pass - send_aggregation = False + def reset(self): + self._aggregation = "" - if isinstance(frame, self._start_frame): - self._aggregation = "" - self._aggregating = True - self._seen_start_frame = True - self._seen_end_frame = False - self._seen_interim_results = False - await self.push_frame(frame, direction) - elif isinstance(frame, self._end_frame): - self._seen_end_frame = True - self._seen_start_frame = False - - # We might have received the end frame but we might still be - # aggregating (i.e. we have seen interim results but not the final - # text). - self._aggregating = self._seen_interim_results or len(self._aggregation) == 0 - - # Send the aggregation if we are not aggregating anymore (i.e. no - # more interim results received). - send_aggregation = not self._aggregating - await self.push_frame(frame, direction) - elif isinstance(frame, self._accumulator_frame): - if self._aggregating: - if self._expect_stripped_words: - self._aggregation += f" {frame.text}" if self._aggregation else frame.text - else: - self._aggregation += frame.text - # We have recevied a complete sentence, so if we have seen the - # end frame and we were still aggregating, it means we should - # send the aggregation. - send_aggregation = self._seen_end_frame - - # We just got our final result, so let's reset interim results. - self._seen_interim_results = False - elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): - self._seen_interim_results = True - elif self._handle_interruptions and isinstance(frame, StartInterruptionFrame): - await self._push_aggregation() - # Reset anyways - self._reset() - await self.push_frame(frame, direction) - elif isinstance(frame, LLMMessagesAppendFrame): - self._add_messages(frame.messages) - elif isinstance(frame, LLMMessagesUpdateFrame): - self._set_messages(frame.messages) - elif isinstance(frame, LLMSetToolsFrame): - self._set_tools(frame.tools) - else: - await self.push_frame(frame, direction) - - if send_aggregation: - await self._push_aggregation() - - async def _push_aggregation(self): + async def push_aggregation(self): if len(self._aggregation) > 0: self._messages.append({"role": self._role, "content": self._aggregation}) @@ -153,109 +146,27 @@ class LLMResponseAggregator(FrameProcessor): frame = LLMMessagesFrame(self._messages) await self.push_frame(frame) - # TODO-CB: Types - def _add_messages(self, messages): - self._messages.extend(messages) - def _set_messages(self, messages): - self._reset() - self._messages.clear() - self._messages.extend(messages) +class LLMContextResponseAggregator(BaseLLMResponseAggregator): + """This is a base LLM aggregator that uses an LLM context to store the + conversation. It pushes `OpenAILLMContextFrame` as an aggregation frame. - def _set_tools(self, tools): - # noop in the base class - pass - - def _reset(self): - self._aggregation = "" - self._aggregating = False - self._seen_start_frame = False - self._seen_end_frame = False - self._seen_interim_results = False - - -class LLMAssistantResponseAggregator(LLMResponseAggregator): - def __init__(self, messages: List[dict] = []): - super().__init__( - messages=messages, - role="assistant", - start_frame=LLMFullResponseStartFrame, - end_frame=LLMFullResponseEndFrame, - accumulator_frame=TextFrame, - handle_interruptions=True, - ) - - -class LLMUserResponseAggregator(LLMResponseAggregator): - def __init__(self, messages: List[dict] = []): - super().__init__( - messages=messages, - role="user", - start_frame=UserStartedSpeakingFrame, - end_frame=UserStoppedSpeakingFrame, - accumulator_frame=TranscriptionFrame, - interim_accumulator_frame=InterimTranscriptionFrame, - ) - - -class LLMFullResponseAggregator(FrameProcessor): - """This class aggregates Text frames until it receives a - LLMFullResponseEndFrame, then emits the concatenated text as - a single text frame. - - given the following frames: - - TextFrame("Hello,") - TextFrame(" world.") - TextFrame(" I am") - TextFrame(" an LLM.") - LLMFullResponseEndFrame()] - - this processor will yield nothing for the first 4 frames, then - - TextFrame("Hello, world. I am an LLM.") - LLMFullResponseEndFrame() - - when passed the last frame. - - >>> async def print_frames(aggregator, frame): - ... async for frame in aggregator.process_frame(frame): - ... if isinstance(frame, TextFrame): - ... print(frame.text) - ... else: - ... print(frame.__class__.__name__) - - >>> aggregator = LLMFullResponseAggregator() - >>> asyncio.run(print_frames(aggregator, TextFrame("Hello,"))) - >>> asyncio.run(print_frames(aggregator, TextFrame(" world."))) - >>> asyncio.run(print_frames(aggregator, TextFrame(" I am"))) - >>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM."))) - >>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame())) - Hello, world. I am an LLM. - LLMFullResponseEndFrame """ - def __init__(self): - super().__init__() - self._aggregation = "" - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - - if isinstance(frame, TextFrame): - self._aggregation += frame.text - elif isinstance(frame, LLMFullResponseEndFrame): - await self.push_frame(TextFrame(self._aggregation)) - await self.push_frame(frame) - self._aggregation = "" - else: - await self.push_frame(frame, direction) - - -class LLMContextAggregator(LLMResponseAggregator): - def __init__(self, *, context: OpenAILLMContext, **kwargs): + def __init__(self, *, context: OpenAILLMContext, role: str, **kwargs): super().__init__(**kwargs) self._context = context + self._role = role + + self._aggregation = "" + + @property + def messages(self) -> List[dict]: + return self._context.get_messages() + + @property + def role(self) -> str: + return self._role @property def context(self): @@ -264,23 +175,25 @@ class LLMContextAggregator(LLMResponseAggregator): def get_context_frame(self) -> OpenAILLMContextFrame: return OpenAILLMContextFrame(context=self._context) - async def push_context_frame(self): + async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM): frame = self.get_context_frame() - await self.push_frame(frame) + await self.push_frame(frame, direction) - # TODO-CB: Types - def _add_messages(self, messages): + def add_messages(self, messages): self._context.add_messages(messages) - def _set_messages(self, messages): + def set_messages(self, messages): self._context.set_messages(messages) - def _set_tools(self, tools: List): + def set_tools(self, tools: List): self._context.set_tools(tools) - async def _push_aggregation(self): + def reset(self): + self._aggregation = "" + + async def push_aggregation(self): if len(self._aggregation) > 0: - self._context.add_message({"role": self._role, "content": self._aggregation}) + self._context.add_message({"role": self.role, "content": self._aggregation}) # Reset the aggregation. Reset it before pushing it down, otherwise # if the tasks gets cancelled we won't be able to clear things up. @@ -290,31 +203,236 @@ class LLMContextAggregator(LLMResponseAggregator): await self.push_frame(frame) # Reset our accumulator state. - self._reset() + self.reset() -class LLMAssistantContextAggregator(LLMContextAggregator): - def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True): - super().__init__( - messages=[], - context=context, - role="assistant", - start_frame=LLMFullResponseStartFrame, - end_frame=LLMFullResponseEndFrame, - accumulator_frame=TextFrame, - handle_interruptions=True, - expect_stripped_words=expect_stripped_words, - ) +class LLMUserContextAggregator(LLMContextResponseAggregator): + """This is a user LLM aggregator that uses an LLM context to store the + conversation. It aggregates transcriptions from the STT service and it has + logic to handle multiple scenarios where transcriptions are received between + VAD events (`UserStartedSpeakingFrame` and `UserStoppedSpeakingFrame`) or + even outside or no VAD events at all. + + """ + + def __init__( + self, + context: OpenAILLMContext, + aggregation_timeout: float = 1.0, + bot_interruption_timeout: float = 2.0, + **kwargs, + ): + super().__init__(context=context, role="user", **kwargs) + self._aggregation_timeout = aggregation_timeout + self._bot_interruption_timeout = bot_interruption_timeout + + self._seen_interim_results = False + self._user_speaking = False + self._last_user_speaking_time = 0 + self._emulating_vad = False + + self._aggregation_event = asyncio.Event() + self._aggregation_task = None + + self.reset() + + def reset(self): + super().reset() + self._seen_interim_results = False + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, StartFrame): + await self._start(frame) + await self.push_frame(frame, direction) + elif isinstance(frame, EndFrame): + await self._stop(frame) + await self.push_frame(frame, direction) + elif isinstance(frame, CancelFrame): + await self._cancel(frame) + await self.push_frame(frame, direction) + elif isinstance(frame, UserStartedSpeakingFrame): + await self._handle_user_started_speaking(frame) + await self.push_frame(frame, direction) + elif isinstance(frame, UserStoppedSpeakingFrame): + await self._handle_user_stopped_speaking(frame) + await self.push_frame(frame, direction) + elif isinstance(frame, TranscriptionFrame): + await self._handle_transcription(frame) + elif isinstance(frame, InterimTranscriptionFrame): + await self._handle_interim_transcription(frame) + elif isinstance(frame, LLMMessagesAppendFrame): + self.add_messages(frame.messages) + elif isinstance(frame, LLMMessagesUpdateFrame): + self.set_messages(frame.messages) + elif isinstance(frame, LLMSetToolsFrame): + self.set_tools(frame.tools) + else: + await self.push_frame(frame, direction) + + async def _start(self, frame: StartFrame): + self._create_aggregation_task() + + async def _stop(self, frame: EndFrame): + await self._cancel_aggregation_task() + + async def _cancel(self, frame: CancelFrame): + await self._cancel_aggregation_task() + + async def _handle_user_started_speaking(self, _: UserStartedSpeakingFrame): + self._user_speaking = True + + async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame): + self._last_user_speaking_time = time.time() + self._user_speaking = False + if not self._seen_interim_results: + await self.push_aggregation() + + async def _handle_transcription(self, frame: TranscriptionFrame): + self._aggregation += f" {frame.text}" if self._aggregation else frame.text + # We just got a final result, so let's reset interim results. + self._seen_interim_results = False + # Reset aggregation timer. + self._aggregation_event.set() + + async def _handle_interim_transcription(self, _: InterimTranscriptionFrame): + self._seen_interim_results = True + # Reset aggregation timer. + self._aggregation_event.set() + + def _create_aggregation_task(self): + self._aggregation_task = self.create_task(self._aggregation_task_handler()) + + async def _cancel_aggregation_task(self): + if self._aggregation_task: + await self.cancel_task(self._aggregation_task) + self._aggregation_task = None + + async def _aggregation_task_handler(self): + while True: + try: + await asyncio.wait_for(self._aggregation_event.wait(), self._aggregation_timeout) + await self._maybe_push_bot_interruption() + except asyncio.TimeoutError: + if not self._user_speaking: + await self.push_aggregation() + + # If we are emulating VAD we still need to send the user stopped + # speaking frame. + if self._emulating_vad: + await self.push_frame( + EmulateUserStoppedSpeakingFrame(), FrameDirection.UPSTREAM + ) + self._emulating_vad = False + finally: + self._aggregation_event.clear() + + async def _maybe_push_bot_interruption(self): + """If the user stopped speaking a while back and we got a transcription + frame we might want to interrupt the bot. + + """ + if not self._user_speaking: + diff_time = time.time() - self._last_user_speaking_time + if diff_time > self._bot_interruption_timeout: + # If we reach this case we received a transcription but VAD was + # not able to detect voice (e.g. when you whisper a short + # utterance). So, we need to emulate VAD (i.e. user + # start/stopped speaking). + await self.push_frame(EmulateUserStartedSpeakingFrame(), FrameDirection.UPSTREAM) + self._emulating_vad = True + + # Reset time so we don't interrupt again right away. + self._last_user_speaking_time = time.time() -class LLMUserContextAggregator(LLMContextAggregator): - def __init__(self, context: OpenAILLMContext): - super().__init__( - messages=[], - context=context, - role="user", - start_frame=UserStartedSpeakingFrame, - end_frame=UserStoppedSpeakingFrame, - accumulator_frame=TranscriptionFrame, - interim_accumulator_frame=InterimTranscriptionFrame, - ) +class LLMAssistantContextAggregator(LLMContextResponseAggregator): + """This is an assistant LLM aggregator that uses an LLM context to store the + conversation. It aggregates text frames received between + `LLMFullResponseStartFrame` and `LLMFullResponseEndFrame`. + + """ + + def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True, **kwargs): + super().__init__(context=context, role="assistant", **kwargs) + self._expect_stripped_words = expect_stripped_words + + self.reset() + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, StartInterruptionFrame): + await self.push_aggregation() + # Reset anyways + self.reset() + await self.push_frame(frame, direction) + elif isinstance(frame, LLMFullResponseStartFrame): + await self._handle_llm_start(frame) + elif isinstance(frame, LLMFullResponseEndFrame): + await self._handle_llm_end(frame) + elif isinstance(frame, TextFrame): + await self._handle_text(frame) + elif isinstance(frame, LLMMessagesAppendFrame): + self.add_messages(frame.messages) + elif isinstance(frame, LLMMessagesUpdateFrame): + self.set_messages(frame.messages) + elif isinstance(frame, LLMSetToolsFrame): + self.set_tools(frame.tools) + else: + await self.push_frame(frame, direction) + + async def _handle_llm_start(self, _: LLMFullResponseStartFrame): + self._started = True + + async def _handle_llm_end(self, _: LLMFullResponseEndFrame): + self._started = False + await self.push_aggregation() + + async def _handle_text(self, frame: TextFrame): + if not self._started: + return + + if self._expect_stripped_words: + self._aggregation += f" {frame.text}" if self._aggregation else frame.text + else: + self._aggregation += frame.text + + +class LLMUserResponseAggregator(LLMUserContextAggregator): + def __init__(self, messages: List[dict] = [], **kwargs): + super().__init__(context=OpenAILLMContext(messages), **kwargs) + + async def push_aggregation(self): + if len(self._aggregation) > 0: + self._context.add_message({"role": self.role, "content": self._aggregation}) + + # Reset the aggregation. Reset it before pushing it down, otherwise + # if the tasks gets cancelled we won't be able to clear things up. + self._aggregation = "" + + frame = LLMMessagesFrame(self._context.messages) + await self.push_frame(frame) + + # Reset our accumulator state. + self.reset() + + +class LLMAssistantResponseAggregator(LLMAssistantContextAggregator): + def __init__(self, messages: List[dict], **kwargs): + super().__init__(context=OpenAILLMContext(messages), **kwargs) + + async def push_aggregation(self): + if len(self._aggregation) > 0: + self._context.add_message({"role": self.role, "content": self._aggregation}) + + # Reset the aggregation. Reset it before pushing it down, otherwise + # if the tasks gets cancelled we won't be able to clear things up. + self._aggregation = "" + + frame = LLMMessagesFrame(self._context.messages) + await self.push_frame(frame) + + # Reset our accumulator state. + self.reset() diff --git a/src/pipecat/processors/aggregators/user_response.py b/src/pipecat/processors/aggregators/user_response.py index 0e2c09ebc..6998fe200 100644 --- a/src/pipecat/processors/aggregators/user_response.py +++ b/src/pipecat/processors/aggregators/user_response.py @@ -4,131 +4,15 @@ # SPDX-License-Identifier: BSD 2-Clause License # -from typing import Optional - -from pipecat.frames.frames import ( - Frame, - InterimTranscriptionFrame, - StartInterruptionFrame, - TextFrame, - TranscriptionFrame, - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, -) -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.frames.frames import TextFrame +from pipecat.processors.aggregators.llm_response import LLMUserResponseAggregator -class ResponseAggregator(FrameProcessor): - """This frame processor aggregates frames between a start and an end frame - into complete text frame sentences. +class UserResponseAggregator(LLMUserResponseAggregator): + def __init__(self, **kwargs): + super().__init__(**kwargs) - For example, frame input/output: - UserStartedSpeakingFrame() -> None - TranscriptionFrame("Hello,") -> None - TranscriptionFrame(" world.") -> None - UserStoppedSpeakingFrame() -> TextFrame("Hello world.") - - Doctest: FIXME to work with asyncio - >>> async def print_frames(aggregator, frame): - ... async for frame in aggregator.process_frame(frame): - ... if isinstance(frame, TextFrame): - ... print(frame.text) - - >>> aggregator = ResponseAggregator(start_frame = UserStartedSpeakingFrame, - ... end_frame=UserStoppedSpeakingFrame, - ... accumulator_frame=TranscriptionFrame, - ... pass_through=False) - >>> asyncio.run(print_frames(aggregator, UserStartedSpeakingFrame())) - >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("Hello,", 1, 1))) - >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("world.", 1, 2))) - >>> asyncio.run(print_frames(aggregator, UserStoppedSpeakingFrame())) - Hello, world. - - """ - - def __init__( - self, - *, - start_frame, - end_frame, - accumulator_frame: TextFrame, - interim_accumulator_frame: Optional[TextFrame] = None, - ): - super().__init__() - - self._start_frame = start_frame - self._end_frame = end_frame - self._accumulator_frame = accumulator_frame - self._interim_accumulator_frame = interim_accumulator_frame - - # Reset our accumulator state. - self._reset() - - # - # Frame processor - # - - # Use cases implemented: - # - # S: Start, E: End, T: Transcription, I: Interim, X: Text - # - # S E -> None - # S T E -> X - # S I T E -> X - # S I E T -> X - # S I E I T -> X - # S E T -> X - # S E I T -> X - # - # The following case would not be supported: - # - # S I E T1 I T2 -> X - # - # and T2 would be dropped. - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - - send_aggregation = False - - if isinstance(frame, self._start_frame): - self._aggregating = True - self._seen_start_frame = True - self._seen_end_frame = False - self._seen_interim_results = False - await self.push_frame(frame, direction) - elif isinstance(frame, self._end_frame): - self._seen_end_frame = True - self._seen_start_frame = False - - # We might have received the end frame but we might still be - # aggregating (i.e. we have seen interim results but not the final - # text). - self._aggregating = self._seen_interim_results or len(self._aggregation) == 0 - - # Send the aggregation if we are not aggregating anymore (i.e. no - # more interim results received). - send_aggregation = not self._aggregating - await self.push_frame(frame, direction) - elif isinstance(frame, self._accumulator_frame): - if self._aggregating: - self._aggregation += f" {frame.text}" - # We have recevied a complete sentence, so if we have seen the - # end frame and we were still aggregating, it means we should - # send the aggregation. - send_aggregation = self._seen_end_frame - - # We just got our final result, so let's reset interim results. - self._seen_interim_results = False - elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): - self._seen_interim_results = True - else: - await self.push_frame(frame, direction) - - if send_aggregation: - await self._push_aggregation() - - async def _push_aggregation(self): + async def push_aggregation(self): if len(self._aggregation) > 0: frame = TextFrame(self._aggregation.strip()) @@ -139,21 +23,4 @@ class ResponseAggregator(FrameProcessor): await self.push_frame(frame) # Reset our accumulator state. - self._reset() - - def _reset(self): - self._aggregation = "" - self._aggregating = False - self._seen_start_frame = False - self._seen_end_frame = False - self._seen_interim_results = False - - -class UserResponseAggregator(ResponseAggregator): - def __init__(self): - super().__init__( - start_frame=UserStartedSpeakingFrame, - end_frame=UserStoppedSpeakingFrame, - accumulator_frame=TranscriptionFrame, - interim_accumulator_frame=InterimTranscriptionFrame, - ) + self.reset() diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index b8d223176..ac1c4582d 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -577,7 +577,7 @@ class SegmentedSTTService(STTService): self._smoothing_factor = 0.2 self._prev_volume = 0 - async def process_audio_frame(self, frame: AudioRawFrame): + async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection): # Try to filter out empty background noise volume = self._get_smoothed_volume(frame) if volume >= self._min_volume: diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index a593ced89..80235199b 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -126,9 +126,11 @@ class AnthropicLLMService(LLMService): def create_context_aggregator( context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True ) -> AnthropicContextAggregatorPair: + if isinstance(context, OpenAILLMContext): + context = AnthropicLLMContext.from_openai_context(context) user = AnthropicUserContextAggregator(context) assistant = AnthropicAssistantContextAggregator( - user, expect_stripped_words=assistant_expect_stripped_words + context, expect_stripped_words=assistant_expect_stripped_words ) return AnthropicContextAggregatorPair(_user=user, _assistant=assistant) @@ -651,11 +653,8 @@ class AnthropicLLMContext(OpenAILLMContext): class AnthropicUserContextAggregator(LLMUserContextAggregator): - def __init__(self, context: OpenAILLMContext | AnthropicLLMContext): - super().__init__(context=context) - - if isinstance(context, OpenAILLMContext): - self._context = AnthropicLLMContext.from_openai_context(context) + def __init__(self, context: OpenAILLMContext | AnthropicLLMContext, **kwargs): + super().__init__(context=context, **kwargs) async def process_frame(self, frame, direction): await super().process_frame(frame, direction) @@ -703,9 +702,8 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator): class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): - def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs): - super().__init__(context=user_context_aggregator._context, **kwargs) - self._user_context_aggregator = user_context_aggregator + def __init__(self, context: OpenAILLMContext | AnthropicLLMContext, **kwargs): + super().__init__(context=context, **kwargs) self._function_call_in_progress = None self._function_call_result = None self._pending_image_frame_message = None @@ -725,7 +723,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): ): self._function_call_in_progress = None self._function_call_result = frame - await self._push_aggregation() + await self.push_aggregation() else: logger.warning( "FunctionCallResultFrame tool_call_id != InProgressFrame tool_call_id" @@ -734,9 +732,9 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): self._function_call_result = None elif isinstance(frame, AnthropicImageMessageFrame): self._pending_image_frame_message = frame - await self._push_aggregation() + await self.push_aggregation() - async def _push_aggregation(self): + async def push_aggregation(self): if not ( self._aggregation or self._function_call_result or self._pending_image_frame_message ): @@ -746,7 +744,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): properties: Optional[FunctionCallResultProperties] = None aggregation = self._aggregation - self._reset() + self.reset() try: if self._function_call_result: @@ -799,15 +797,14 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): run_llm = True if run_llm: - await self._user_context_aggregator.push_context_frame() + await self.push_context_frame(FrameDirection.UPSTREAM) # Emit the on_context_updated callback once the function call result is added to the context if properties and properties.on_context_updated is not None: await properties.on_context_updated() # Push context frame - frame = OpenAILLMContextFrame(self._context) - await self.push_frame(frame) + await self.push_context_frame() # Push timestamp frame with current time timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601()) diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py index 4f26b6e9f..6e7a1c0fa 100644 --- a/src/pipecat/services/gemini_multimodal_live/gemini.py +++ b/src/pipecat/services/gemini_multimodal_live/gemini.py @@ -115,10 +115,10 @@ class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator): class GeminiMultimodalLiveAssistantContextAggregator(OpenAIAssistantContextAggregator): - async def _push_aggregation(self): + async def push_aggregation(self): # We don't want to store any images in the context. Revisit this later when the API evolves. self._pending_image_frame_message = None - await super()._push_aggregation() + await super().push_aggregation() @dataclass @@ -706,6 +706,6 @@ class GeminiMultimodalLiveLLMService(LLMService): GeminiMultimodalLiveContext.upgrade(context) user = GeminiMultimodalLiveUserContextAggregator(context) assistant = GeminiMultimodalLiveAssistantContextAggregator( - user, expect_stripped_words=assistant_expect_stripped_words + context, expect_stripped_words=assistant_expect_stripped_words ) return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/google/google.py b/src/pipecat/services/google/google.py index 28cd0d421..605c74069 100644 --- a/src/pipecat/services/google/google.py +++ b/src/pipecat/services/google/google.py @@ -537,7 +537,7 @@ def language_to_google_stt_language(language: Language) -> Optional[str]: class GoogleUserContextAggregator(OpenAIUserContextAggregator): - async def _push_aggregation(self): + async def push_aggregation(self): if len(self._aggregation) > 0: self._context.add_message( glm.Content(role="user", parts=[glm.Part(text=self._aggregation)]) @@ -552,11 +552,11 @@ class GoogleUserContextAggregator(OpenAIUserContextAggregator): await self.push_frame(frame) # Reset our accumulator state. - self._reset() + self.reset() class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): - async def _push_aggregation(self): + async def push_aggregation(self): if not ( self._aggregation or self._function_call_result or self._pending_image_frame_message ): @@ -566,7 +566,7 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): properties: Optional[FunctionCallResultProperties] = None aggregation = self._aggregation - self._reset() + self.reset() try: if self._function_call_result: @@ -626,15 +626,14 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): run_llm = True if run_llm: - await self._user_context_aggregator.push_context_frame() + await self.push_context_frame(FrameDirection.UPSTREAM) # Emit the on_context_updated callback once the function call result is added to the context if properties and properties.on_context_updated is not None: await properties.on_context_updated() # Push context frame - frame = OpenAILLMContextFrame(self._context) - await self.push_frame(frame) + await self.push_context_frame() # Push timestamp frame with current time timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601()) @@ -1175,7 +1174,7 @@ class GoogleLLMService(LLMService): ) -> GoogleContextAggregatorPair: user = GoogleUserContextAggregator(context) assistant = GoogleAssistantContextAggregator( - user, expect_stripped_words=assistant_expect_stripped_words + context, expect_stripped_words=assistant_expect_stripped_words ) return GoogleContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/grok.py b/src/pipecat/services/grok.py index 7221cc09e..064dd8829 100644 --- a/src/pipecat/services/grok.py +++ b/src/pipecat/services/grok.py @@ -17,6 +17,7 @@ from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, ) +from pipecat.processors.frame_processor import FrameDirection from pipecat.services.openai import ( OpenAIAssistantContextAggregator, OpenAILLMService, @@ -27,7 +28,7 @@ from pipecat.services.openai import ( class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator): """Custom assistant context aggregator for Grok that handles empty content requirement.""" - async def _push_aggregation(self): + async def push_aggregation(self): if not ( self._aggregation or self._function_call_result or self._pending_image_frame_message ): @@ -37,7 +38,7 @@ class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator): properties: Optional[FunctionCallResultProperties] = None aggregation = self._aggregation - self._reset() + self.reset() try: if self._function_call_result: @@ -91,14 +92,13 @@ class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator): run_llm = True if run_llm: - await self._user_context_aggregator.push_context_frame() + await self.push_context_frame(FrameDirection.UPSTREAM) # Emit the on_context_updated callback once the function call result is added to the context if properties and properties.on_context_updated is not None: await properties.on_context_updated() - frame = OpenAILLMContextFrame(self._context) - await self.push_frame(frame) + await self.push_context_frame() except Exception as e: logger.error(f"Error processing frame: {e}") @@ -212,6 +212,6 @@ class GrokLLMService(OpenAILLMService): ) -> GrokContextAggregatorPair: user = OpenAIUserContextAggregator(context) assistant = GrokAssistantContextAggregator( - user, expect_stripped_words=assistant_expect_stripped_words + context, expect_stripped_words=assistant_expect_stripped_words ) return GrokContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index bc251025e..b0a66eba0 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -355,7 +355,7 @@ class OpenAILLMService(BaseOpenAILLMService): ) -> OpenAIContextAggregatorPair: user = OpenAIUserContextAggregator(context) assistant = OpenAIAssistantContextAggregator( - user, expect_stripped_words=assistant_expect_stripped_words + context, expect_stripped_words=assistant_expect_stripped_words ) return OpenAIContextAggregatorPair(_user=user, _assistant=assistant) @@ -555,8 +555,8 @@ class OpenAIImageMessageFrame(Frame): class OpenAIUserContextAggregator(LLMUserContextAggregator): - def __init__(self, context: OpenAILLMContext): - super().__init__(context=context) + def __init__(self, context: OpenAILLMContext, **kwargs): + super().__init__(context=context, **kwargs) async def process_frame(self, frame, direction): await super().process_frame(frame, direction) @@ -592,9 +592,8 @@ class OpenAIUserContextAggregator(LLMUserContextAggregator): class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): - def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs): - super().__init__(context=user_context_aggregator._context, **kwargs) - self._user_context_aggregator = user_context_aggregator + def __init__(self, context: OpenAILLMContext, **kwargs): + super().__init__(context=context, **kwargs) self._function_calls_in_progress = {} self._function_call_result = None self._pending_image_frame_message = None @@ -614,7 +613,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): del self._function_calls_in_progress[frame.tool_call_id] self._function_call_result = frame # TODO-CB: Kwin wants us to refactor this out of here but I REFUSE - await self._push_aggregation() + await self.push_aggregation() else: logger.warning( "FunctionCallResultFrame tool_call_id does not match any function call in progress" @@ -622,9 +621,9 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): self._function_call_result = None elif isinstance(frame, OpenAIImageMessageFrame): self._pending_image_frame_message = frame - await self._push_aggregation() + await self.push_aggregation() - async def _push_aggregation(self): + async def push_aggregation(self): if not ( self._aggregation or self._function_call_result or self._pending_image_frame_message ): @@ -634,7 +633,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): properties: Optional[FunctionCallResultProperties] = None aggregation = self._aggregation - self._reset() + self.reset() try: if self._function_call_result: @@ -686,15 +685,14 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): run_llm = True if run_llm: - await self._user_context_aggregator.push_context_frame() + await self.push_context_frame(FrameDirection.UPSTREAM) # Emit the on_context_updated callback once the function call result is added to the context if properties and properties.on_context_updated is not None: await properties.on_context_updated() # Push context frame - frame = OpenAILLMContextFrame(self._context) - await self.push_frame(frame) + await self.push_context_frame() # Push timestamp frame with current time timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601()) diff --git a/src/pipecat/services/openai_realtime_beta/context.py b/src/pipecat/services/openai_realtime_beta/context.py index da287194a..31639dc6b 100644 --- a/src/pipecat/services/openai_realtime_beta/context.py +++ b/src/pipecat/services/openai_realtime_beta/context.py @@ -166,7 +166,7 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator): if isinstance(frame, LLMSetToolsFrame): await self.push_frame(frame, direction) - async def _push_aggregation(self): + async def push_aggregation(self): # for the moment, ignore all user input coming into the pipeline. # todo: think about whether/how to fix this to allow for text input from # upstream (transport/transcription, or other sources) @@ -174,7 +174,7 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator): class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator): - async def _push_aggregation(self): + async def push_aggregation(self): # the only thing we implement here is function calling. in all other cases, messages # are added to the context when we receive openai realtime api events if not self._function_call_result: @@ -182,7 +182,7 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator) properties: Optional[FunctionCallResultProperties] = None - self._reset() + self.reset() try: run_llm = True frame = self._function_call_result @@ -217,8 +217,8 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator) # The standard function callback code path pushes the FunctionCallResultFrame from the llm itself, # so we didn't have a chance to add the result to the openai realtime api context. Let's push a # special frame to do that. - await self._user_context_aggregator.push_frame( - RealtimeFunctionCallResultFrame(result_frame=frame) + await self.push_frame( + RealtimeFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM ) if properties and properties.run_llm is not None: # If the tool call result has a run_llm property, use it @@ -228,14 +228,13 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator) run_llm = not bool(self._function_calls_in_progress) if run_llm: - await self._user_context_aggregator.push_context_frame() + await self.push_context_frame(FrameDirection.UPSTREAM) # Emit the on_context_updated callback once the function call result is added to the context if properties and properties.on_context_updated is not None: await properties.on_context_updated() - frame = OpenAILLMContextFrame(self._context) - await self.push_frame(frame) + await self.push_context_frame() except Exception as e: logger.error(f"Error processing frame: {e}") diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py index a42ed7b9a..bc173d765 100644 --- a/src/pipecat/services/openai_realtime_beta/openai.py +++ b/src/pipecat/services/openai_realtime_beta/openai.py @@ -568,6 +568,6 @@ class OpenAIRealtimeBetaLLMService(LLMService): OpenAIRealtimeLLMContext.upgrade_to_realtime(context) user = OpenAIRealtimeUserContextAggregator(context) assistant = OpenAIRealtimeAssistantContextAggregator( - user, expect_stripped_words=assistant_expect_stripped_words + context, expect_stripped_words=assistant_expect_stripped_words ) return OpenAIContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/tests/utils.py b/src/pipecat/tests/utils.py index ff92164a4..55bda9cea 100644 --- a/src/pipecat/tests/utils.py +++ b/src/pipecat/tests/utils.py @@ -5,6 +5,7 @@ # import asyncio +from dataclasses import dataclass from typing import Any, Awaitable, Callable, Dict, Sequence, Tuple from pipecat.frames.frames import ( @@ -12,6 +13,7 @@ from pipecat.frames.frames import ( Frame, HeartbeatFrame, StartFrame, + SystemFrame, ) from pipecat.observers.base_observer import BaseObserver from pipecat.pipeline.pipeline import Pipeline @@ -20,6 +22,15 @@ from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +@dataclass +class SleepFrame(SystemFrame): + """This frame is used by test framework to introduce some sleep time before + the next frame is pushed. This is useful to control system frames vs data or + control frames.""" + + sleep: float = 0.1 + + class HeartbeatsObserver(BaseObserver): def __init__( self, @@ -44,7 +55,11 @@ class HeartbeatsObserver(BaseObserver): class QueuedFrameProcessor(FrameProcessor): def __init__( - self, queue: asyncio.Queue, queue_direction: FrameDirection, ignore_start: bool = True + self, + *, + queue: asyncio.Queue, + queue_direction: FrameDirection, + ignore_start: bool = True, ): super().__init__() self._queue = queue @@ -72,21 +87,35 @@ async def run_test( ) -> Tuple[Sequence[Frame], Sequence[Frame]]: received_up = asyncio.Queue() received_down = asyncio.Queue() - source = QueuedFrameProcessor(received_up, FrameDirection.UPSTREAM, ignore_start) - sink = QueuedFrameProcessor(received_down, FrameDirection.DOWNSTREAM, ignore_start) + source = QueuedFrameProcessor( + queue=received_up, + queue_direction=FrameDirection.UPSTREAM, + ignore_start=ignore_start, + ) + sink = QueuedFrameProcessor( + queue=received_down, + queue_direction=FrameDirection.DOWNSTREAM, + ignore_start=ignore_start, + ) pipeline = Pipeline([source, processor, sink]) task = PipelineTask(pipeline, params=PipelineParams(start_metadata=start_metadata)) - for frame in frames_to_send: - await task.queue_frame(frame) + async def push_frames(): + # Just give a little head start to the runner. + await asyncio.sleep(0.01) + for frame in frames_to_send: + if isinstance(frame, SleepFrame): + await asyncio.sleep(frame.sleep) + else: + await task.queue_frame(frame) - if send_end_frame: - await task.queue_frame(EndFrame()) + if send_end_frame: + await task.queue_frame(EndFrame()) runner = PipelineRunner() - await runner.run(task) + await asyncio.gather(runner.run(task), push_frames()) # # Down frames @@ -98,6 +127,7 @@ async def run_test( received_down_frames.append(frame) print("received DOWN frames =", received_down_frames) + print("expected DOWN frames =", expected_down_frames) assert len(received_down_frames) == len(expected_down_frames) @@ -113,6 +143,7 @@ async def run_test( received_up_frames.append(frame) print("received UP frames =", received_up_frames) + print("expected UP frames =", expected_up_frames) assert len(received_up_frames) == len(expected_up_frames) diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 6bfe86001..42eb162da 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -14,6 +14,8 @@ from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState from pipecat.frames.frames import ( BotInterruptionFrame, CancelFrame, + EmulateUserStartedSpeakingFrame, + EmulateUserStoppedSpeakingFrame, EndFrame, FilterUpdateSettingsFrame, Frame, @@ -112,9 +114,13 @@ class BaseInputTransport(FrameProcessor): await self.cancel(frame) await self.push_frame(frame, direction) elif isinstance(frame, BotInterruptionFrame): - logger.debug("Bot interruption") - await self._start_interruption() - await self.push_frame(StartInterruptionFrame()) + await self._handle_bot_interruption(frame) + elif isinstance(frame, EmulateUserStartedSpeakingFrame): + logger.debug("Emulating user started speaking") + await self._handle_user_interruption(UserStartedSpeakingFrame()) + elif isinstance(frame, EmulateUserStoppedSpeakingFrame): + logger.debug("Emulating user stopped speaking") + await self._handle_user_interruption(UserStoppedSpeakingFrame()) # All other system frames elif isinstance(frame, SystemFrame): await self.push_frame(frame, direction) @@ -137,7 +143,13 @@ class BaseInputTransport(FrameProcessor): # Handle interruptions # - async def _handle_interruptions(self, frame: Frame): + async def _handle_bot_interruption(self, frame: BotInterruptionFrame): + logger.debug("Bot interruption") + if self.interruptions_allowed: + await self._start_interruption() + await self.push_frame(StartInterruptionFrame()) + + async def _handle_user_interruption(self, frame: Frame): if isinstance(frame, UserStartedSpeakingFrame): logger.debug("User started speaking") # Make sure we notify about interruptions quickly out-of-band. @@ -183,7 +195,7 @@ class BaseInputTransport(FrameProcessor): frame = UserStoppedSpeakingFrame() if frame: - await self._handle_interruptions(frame) + await self._handle_user_interruption(frame) vad_state = new_vad_state return vad_state diff --git a/src/pipecat/transports/base_output.py b/src/pipecat/transports/base_output.py index 6bf4a9ff3..1b6d1e833 100644 --- a/src/pipecat/transports/base_output.py +++ b/src/pipecat/transports/base_output.py @@ -170,6 +170,8 @@ class BaseOutputTransport(FrameProcessor): # TODO(aleix): Images and audio should support presentation timestamps. elif frame.pts: await self._sink_clock_queue.put((frame.pts, frame.id, frame)) + elif direction == FrameDirection.UPSTREAM: + await self.push_frame(frame, direction) else: await self._sink_queue.put(frame) diff --git a/test-requirements.txt b/test-requirements.txt index 999c7a6ba..7566de351 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -7,8 +7,10 @@ deepgram-sdk~=3.5.0 fal-client~=0.4.1 fastapi~=0.115.0 faster-whisper~=1.0.3 -google-cloud-texttospeech~=2.21.1 -google-generativeai~=0.8.3 +google-cloud-speech~=2.31.0 +google-cloud-texttospeech~=2.25.0 +google-genai~=1.2.0 +google-generativeai~=0.8.4 langchain~=0.2.14 livekit~=0.13.1 lmnt~=1.1.4 diff --git a/tests/test_aggregators.py b/tests/test_aggregators.py index ab7210e49..48650977e 100644 --- a/tests/test_aggregators.py +++ b/tests/test_aggregators.py @@ -29,12 +29,12 @@ class TestSentenceAggregator(unittest.IsolatedAsyncioTestCase): for word in sentence.split(" "): frames_to_send.append(TextFrame(text=word + " ")) - expected_returned_frames = [TextFrame, TextFrame, TextFrame] + expected_down_frames = [TextFrame, TextFrame, TextFrame] (received_down, _) = await run_test( aggregator, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) assert received_down[-3].text == "Hello, world. " assert received_down[-2].text == "How are you? " @@ -59,7 +59,7 @@ class TestGatedAggregator(unittest.IsolatedAsyncioTestCase): LLMFullResponseEndFrame(), ] - expected_returned_frames = [ + expected_down_frames = [ OutputImageRawFrame, LLMFullResponseStartFrame, TextFrame, @@ -72,5 +72,5 @@ class TestGatedAggregator(unittest.IsolatedAsyncioTestCase): (received_down, _) = await run_test( gated_aggregator, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) diff --git a/tests/test_context_aggregators.py b/tests/test_context_aggregators.py new file mode 100644 index 000000000..d4b8c35ce --- /dev/null +++ b/tests/test_context_aggregators.py @@ -0,0 +1,671 @@ +# +# Copyright (c) 2024-2025 Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import unittest + +import google.ai.generativelanguage as glm + +from pipecat.frames.frames import ( + EmulateUserStartedSpeakingFrame, + EmulateUserStoppedSpeakingFrame, + InterimTranscriptionFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + OpenAILLMContextAssistantTimestampFrame, + StartInterruptionFrame, + TextFrame, + TranscriptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, +) +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantContextAggregator, + LLMUserContextAggregator, +) +from pipecat.processors.aggregators.openai_llm_context import ( + OpenAILLMContext, + OpenAILLMContextFrame, +) +from pipecat.services.anthropic import ( + AnthropicAssistantContextAggregator, + AnthropicLLMContext, + AnthropicUserContextAggregator, +) +from pipecat.services.google.google import ( + GoogleAssistantContextAggregator, + GoogleLLMContext, + GoogleUserContextAggregator, +) +from pipecat.services.openai import OpenAIAssistantContextAggregator, OpenAIUserContextAggregator +from pipecat.tests.utils import SleepFrame, run_test + +AGGREGATION_TIMEOUT = 0.1 +AGGREGATION_SLEEP = 0.15 +BOT_INTERRUPTION_TIMEOUT = 0.2 +BOT_INTERRUPTION_SLEEP = 0.25 + + +class BaseTestUserContextAggregator: + CONTEXT_CLASS = None # To be set in subclasses + AGGREGATOR_CLASS = None # To be set in subclasses + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] + + def check_message_content(self, context: OpenAILLMContext, index: int, content: str): + assert context.messages[index]["content"] == content + + def check_message_multi_content( + self, context: OpenAILLMContext, content_index: int, index: int, content: str + ): + assert context.messages[index]["content"] == content + + async def test_se(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) + frames_to_send = [UserStartedSpeakingFrame(), UserStoppedSpeakingFrame()] + expected_down_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + async def test_ste(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) + frames_to_send = [ + UserStartedSpeakingFrame(), + TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + *self.EXPECTED_CONTEXT_FRAMES, + ] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_message_content(context, 0, "Hello!") + + async def test_site(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) + frames_to_send = [ + UserStartedSpeakingFrame(), + InterimTranscriptionFrame(text="Hello", user_id="cat", timestamp=""), + TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + *self.EXPECTED_CONTEXT_FRAMES, + ] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_message_content(context, 0, "Hello Pipecat!") + + async def test_st1iest2e(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) + frames_to_send = [ + UserStartedSpeakingFrame(), + TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), + InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + UserStartedSpeakingFrame(), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + *self.EXPECTED_CONTEXT_FRAMES, + ] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_message_content(context, 0, "Hello Pipecat! How are you?") + + async def test_siet(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + UserStartedSpeakingFrame(), + InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + *self.EXPECTED_CONTEXT_FRAMES, + ] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_message_content(context, 0, "How are you?") + + async def test_sieit(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + UserStartedSpeakingFrame(), + InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + InterimTranscriptionFrame(text="are you?", user_id="cat", timestamp=""), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + *self.EXPECTED_CONTEXT_FRAMES, + ] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_message_content(context, 0, "How are you?") + + async def test_set(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + UserStartedSpeakingFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + *self.EXPECTED_CONTEXT_FRAMES, + ] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_message_content(context, 0, "How are you?") + + async def test_seit(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + UserStartedSpeakingFrame(), + UserStoppedSpeakingFrame(), + InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + *self.EXPECTED_CONTEXT_FRAMES, + ] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_message_content(context, 0, "How are you?") + + async def test_st1et2(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + UserStartedSpeakingFrame(), + TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + *self.EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, + ] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_message_multi_content(context, 0, 0, "Hello Pipecat!") + self.check_message_multi_content(context, 0, 1, "How are you?") + + async def test_set1t2(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + UserStartedSpeakingFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + *self.EXPECTED_CONTEXT_FRAMES, + ] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_message_content(context, 0, "Hello Pipecat! How are you?") + + async def test_siet1it2(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + UserStartedSpeakingFrame(), + InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), + InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + *self.EXPECTED_CONTEXT_FRAMES, + ] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_message_content(context, 0, "Hello Pipecat! How are you?") + + async def test_t(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] + expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + expected_up_frames=expected_up_frames, + ) + self.check_message_content(context, 0, "Hello!") + + async def test_it(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, aggregation_timeout=AGGREGATION_TIMEOUT) + frames_to_send = [ + InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""), + SleepFrame(), + TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] + expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + expected_up_frames=expected_up_frames, + ) + self.check_message_content(context, 0, "Hello Pipecat!") + + async def test_sie_delay_it(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS( + context, + aggregation_timeout=AGGREGATION_TIMEOUT, + bot_interruption_timeout=BOT_INTERRUPTION_TIMEOUT, + ) + frames_to_send = [ + UserStartedSpeakingFrame(), + InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), + SleepFrame(), + UserStoppedSpeakingFrame(), + SleepFrame(BOT_INTERRUPTION_SLEEP), + InterimTranscriptionFrame(text="are you?", user_id="cat", timestamp=""), + TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), + SleepFrame(sleep=AGGREGATION_SLEEP), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + *self.EXPECTED_CONTEXT_FRAMES, + ] + expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + expected_up_frames=expected_up_frames, + ) + self.check_message_content(context, 0, "How are you?") + + +class BaseTestAssistantContextAggreagator: + CONTEXT_CLASS = None # To be set in subclasses + AGGREGATOR_CLASS = None # To be set in subclasses + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] + + def check_message_content(self, context: OpenAILLMContext, index: int, content: str): + assert context.messages[index]["content"] == content + + def check_message_multi_content( + self, context: OpenAILLMContext, content_index: int, index: int, content: str + ): + assert context.messages[index]["content"] == content + + async def test_empty(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) + frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()] + expected_down_frames = [] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + async def test_single_text(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) + frames_to_send = [ + LLMFullResponseStartFrame(), + TextFrame(text="Hello Pipecat!"), + LLMFullResponseEndFrame(), + ] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_message_content(context, 0, "Hello Pipecat!") + + async def test_multiple_text(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False) + frames_to_send = [ + LLMFullResponseStartFrame(), + TextFrame(text="Hello "), + TextFrame(text="Pipecat. "), + TextFrame(text="How are "), + TextFrame(text="you?"), + LLMFullResponseEndFrame(), + ] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_message_content(context, 0, "Hello Pipecat. How are you?") + + async def test_multiple_text_stripped(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context) + frames_to_send = [ + LLMFullResponseStartFrame(), + TextFrame(text="Hello"), + TextFrame(text="Pipecat."), + TextFrame(text="How are"), + TextFrame(text="you?"), + LLMFullResponseEndFrame(), + ] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_message_content(context, 0, "Hello Pipecat. How are you?") + + async def test_multiple_llm_responses(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False) + frames_to_send = [ + LLMFullResponseStartFrame(), + TextFrame(text="Hello "), + TextFrame(text="Pipecat."), + LLMFullResponseEndFrame(), + LLMFullResponseStartFrame(), + TextFrame(text="How are "), + TextFrame(text="you?"), + LLMFullResponseEndFrame(), + ] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES, *self.EXPECTED_CONTEXT_FRAMES] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_message_multi_content(context, 0, 0, "Hello Pipecat.") + self.check_message_multi_content(context, 0, 1, "How are you?") + + async def test_multiple_llm_responses_interruption(self): + assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" + + context = self.CONTEXT_CLASS() + aggregator = self.AGGREGATOR_CLASS(context, expect_stripped_words=False) + frames_to_send = [ + LLMFullResponseStartFrame(), + TextFrame(text="Hello "), + TextFrame(text="Pipecat."), + LLMFullResponseEndFrame(), + SleepFrame(AGGREGATION_SLEEP), + StartInterruptionFrame(), + LLMFullResponseStartFrame(), + TextFrame(text="How are "), + TextFrame(text="you?"), + LLMFullResponseEndFrame(), + ] + expected_down_frames = [ + *self.EXPECTED_CONTEXT_FRAMES, + StartInterruptionFrame, + *self.EXPECTED_CONTEXT_FRAMES, + ] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + self.check_message_multi_content(context, 0, 0, "Hello Pipecat.") + self.check_message_multi_content(context, 0, 1, "How are you?") + + +# +# LLMUserContextAggregator, LLMAssistantContextAggregator +# + + +class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase): + CONTEXT_CLASS = OpenAILLMContext + AGGREGATOR_CLASS = LLMUserContextAggregator + + +class TestLLMAssistantContextAggregator( + BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = OpenAILLMContext + AGGREGATOR_CLASS = LLMAssistantContextAggregator + + +# +# OpenAI +# + + +class TestOpenAIUserContextAggregator( + BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = OpenAILLMContext + AGGREGATOR_CLASS = OpenAIUserContextAggregator + + +class TestOpenAIAssistantContextAggregator( + BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = OpenAILLMContext + AGGREGATOR_CLASS = OpenAIAssistantContextAggregator + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] + + +# +# Anthropic +# + + +class TestAnthropicUserContextAggregator( + BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = AnthropicLLMContext + AGGREGATOR_CLASS = AnthropicUserContextAggregator + + def check_message_multi_content( + self, context: OpenAILLMContext, content_index: int, index: int, content: str + ): + messages = context.messages[content_index] + assert messages["content"][index]["text"] == content + + +class TestAnthropicAssistantContextAggregator( + BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = AnthropicLLMContext + AGGREGATOR_CLASS = AnthropicAssistantContextAggregator + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] + + def check_message_multi_content( + self, context: OpenAILLMContext, content_index: int, index: int, content: str + ): + messages = context.messages[content_index] + assert messages["content"][index]["text"] == content + + +# +# Google +# + + +class TestGoogleUserContextAggregator( + BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = GoogleLLMContext + AGGREGATOR_CLASS = GoogleUserContextAggregator + + def check_message_content(self, context: OpenAILLMContext, index: int, content: str): + obj = glm.Content.to_dict(context.messages[index]) + assert obj["parts"][0]["text"] == content + + def check_message_multi_content( + self, context: OpenAILLMContext, content_index: int, index: int, content: str + ): + obj = glm.Content.to_dict(context.messages[index]) + assert obj["parts"][0]["text"] == content + + +class TestGoogleAssistantContextAggregator( + BaseTestAssistantContextAggreagator, unittest.IsolatedAsyncioTestCase +): + CONTEXT_CLASS = GoogleLLMContext + AGGREGATOR_CLASS = GoogleAssistantContextAggregator + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] + + def check_message_content(self, context: OpenAILLMContext, index: int, content: str): + obj = glm.Content.to_dict(context.messages[index]) + assert obj["parts"][0]["text"] == content + + def check_message_multi_content( + self, context: OpenAILLMContext, content_index: int, index: int, content: str + ): + obj = glm.Content.to_dict(context.messages[index]) + assert obj["parts"][0]["text"] == content diff --git a/tests/test_filters.py b/tests/test_filters.py index 9d3b0f003..a47903232 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -25,11 +25,11 @@ class TestIdentifyFilter(unittest.IsolatedAsyncioTestCase): async def test_identity(self): filter = IdentityFilter() frames_to_send = [UserStartedSpeakingFrame(), UserStoppedSpeakingFrame()] - expected_returned_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame] + expected_down_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame] await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) @@ -37,32 +37,32 @@ class TestFrameFilter(unittest.IsolatedAsyncioTestCase): async def test_text_frame(self): filter = FrameFilter(types=(TextFrame,)) frames_to_send = [TextFrame(text="Hello Pipecat!")] - expected_returned_frames = [TextFrame] + expected_down_frames = [TextFrame] await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) async def test_end_frame(self): filter = FrameFilter(types=(EndFrame,)) frames_to_send = [EndFrame()] - expected_returned_frames = [EndFrame] + expected_down_frames = [EndFrame] await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, send_end_frame=False, ) async def test_system_frame(self): filter = FrameFilter(types=()) frames_to_send = [UserStartedSpeakingFrame()] - expected_returned_frames = [UserStartedSpeakingFrame] + expected_down_frames = [UserStartedSpeakingFrame] await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) @@ -73,11 +73,11 @@ class TestFunctionFilter(unittest.IsolatedAsyncioTestCase): filter = FunctionFilter(filter=passthrough) frames_to_send = [TextFrame(text="Hello Pipecat!")] - expected_returned_frames = [TextFrame] + expected_down_frames = [TextFrame] await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) async def test_no_passthrough(self): @@ -86,11 +86,11 @@ class TestFunctionFilter(unittest.IsolatedAsyncioTestCase): filter = FunctionFilter(filter=no_passthrough) frames_to_send = [TextFrame(text="Hello Pipecat!")] - expected_returned_frames = [] + expected_down_frames = [] await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) @@ -98,11 +98,11 @@ class TestWakeCheckFilter(unittest.IsolatedAsyncioTestCase): async def test_no_wake_word(self): filter = WakeCheckFilter(wake_phrases=["Hey, Pipecat"]) frames_to_send = [TranscriptionFrame(user_id="test", text="Phrase 1", timestamp="")] - expected_returned_frames = [] + expected_down_frames = [] await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) async def test_wake_word(self): @@ -111,10 +111,10 @@ class TestWakeCheckFilter(unittest.IsolatedAsyncioTestCase): TranscriptionFrame(user_id="test", text="Hey, Pipecat", timestamp=""), TranscriptionFrame(user_id="test", text="Phrase 1", timestamp=""), ] - expected_returned_frames = [TranscriptionFrame, TranscriptionFrame] + expected_down_frames = [TranscriptionFrame, TranscriptionFrame] (received_down, _) = await run_test( filter, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) assert received_down[-1].text == "Phrase 1" diff --git a/tests/test_langchain.py b/tests/test_langchain.py index c94bc1c01..6534f6bb0 100644 --- a/tests/test_langchain.py +++ b/tests/test_langchain.py @@ -10,23 +10,22 @@ from langchain.prompts import ChatPromptTemplate from langchain_core.language_models import FakeStreamingListLLM from pipecat.frames.frames import ( - EndFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, + LLMMessagesFrame, TextFrame, TranscriptionFrame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame, ) from pipecat.pipeline.pipeline import Pipeline -from pipecat.pipeline.runner import PipelineRunner -from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.aggregators.llm_response import ( LLMAssistantResponseAggregator, LLMUserResponseAggregator, ) from pipecat.processors.frame_processor import FrameProcessor from pipecat.processors.frameworks.langchain import LangchainProcessor +from pipecat.tests.utils import SleepFrame, run_test class TestLangchain(unittest.IsolatedAsyncioTestCase): @@ -64,31 +63,26 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase): self.mock_proc = self.MockProcessor("token_collector") tma_in = LLMUserResponseAggregator(messages) - tma_out = LLMAssistantResponseAggregator(messages) + tma_out = LLMAssistantResponseAggregator(messages, expect_stripped_words=False) - pipeline = Pipeline( - [ - tma_in, - proc, - self.mock_proc, - tma_out, - ] + pipeline = Pipeline([tma_in, proc, self.mock_proc, tma_out]) + + frames_to_send = [ + UserStartedSpeakingFrame(), + TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"), + SleepFrame(), + UserStoppedSpeakingFrame(), + ] + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + LLMMessagesFrame, + ] + await run_test( + pipeline, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, ) - task = PipelineTask(pipeline, PipelineParams(allow_interruptions=False)) - await task.queue_frames( - [ - UserStartedSpeakingFrame(), - TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"), - UserStoppedSpeakingFrame(), - EndFrame(), - ] - ) - - runner = PipelineRunner() - await runner.run(task) self.assertEqual("".join(self.mock_proc.token), self.expected_response) - # TODO: Address this issue - # This next one would fail with: - # AssertionError: ' H e l l o d e a r h u m a n' != 'Hello dear human' - # self.assertEqual(tma_out.messages[-1]["content"], self.expected_response) + self.assertEqual(tma_out.messages[-1]["content"], self.expected_response) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index caac89c11..0aff922b2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -21,11 +21,11 @@ class TestPipeline(unittest.IsolatedAsyncioTestCase): pipeline = Pipeline([IdentityFilter()]) frames_to_send = [TextFrame(text="Hello from Pipecat!")] - expected_returned_frames = [TextFrame] + expected_down_frames = [TextFrame] await run_test( pipeline, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) async def test_pipeline_multiple(self): @@ -36,22 +36,22 @@ class TestPipeline(unittest.IsolatedAsyncioTestCase): pipeline = Pipeline([identity1, identity2, identity3]) frames_to_send = [TextFrame(text="Hello from Pipecat!")] - expected_returned_frames = [TextFrame] + expected_down_frames = [TextFrame] await run_test( pipeline, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) async def test_pipeline_start_metadata(self): pipeline = Pipeline([IdentityFilter()]) frames_to_send = [] - expected_returned_frames = [StartFrame] + expected_down_frames = [StartFrame] (received_down, _) = await run_test( pipeline, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ignore_start=False, start_metadata={"foo": "bar"}, ) @@ -63,11 +63,11 @@ class TestParallelPipeline(unittest.IsolatedAsyncioTestCase): pipeline = ParallelPipeline([IdentityFilter()]) frames_to_send = [TextFrame(text="Hello from Pipecat!")] - expected_returned_frames = [TextFrame] + expected_down_frames = [TextFrame] await run_test( pipeline, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, ) async def test_parallel_multiple(self): @@ -75,11 +75,11 @@ class TestParallelPipeline(unittest.IsolatedAsyncioTestCase): pipeline = ParallelPipeline([IdentityFilter()], [IdentityFilter()]) frames_to_send = [TextFrame(text="Hello from Pipecat!")] - expected_returned_frames = [TextFrame] + expected_down_frames = [TextFrame] await run_test( pipeline, frames_to_send=frames_to_send, - expected_down_frames=expected_returned_frames, + expected_down_frames=expected_down_frames, )