diff --git a/CHANGELOG.md b/CHANGELOG.md index 076a2c93e..c58fb5cf5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,12 +9,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Added an asyncio event `finished_event` field to `InterruptionFrame`. When - assigned, the asyncio event will be set when the frame reaches the end of the - pipeline. You can use this field to know when an interruption made it all the - way to the end of a pipeline. The field has been also added to - `InterruptionTaskFrame`. - - Added support for loading external observers. You can now register custom pipeline observers by setting the `PIPECAT_OBSERVER_FILES` environment variable. This variable should contain a colon-separated list of Python files @@ -34,7 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `CancelFrame` and `CancelTaskFrame` have an optional `reason` field to indicate why the pipeline is being canceled. This can be also specified when - you cancel a task with `PipelineTask.cancel(reason="cancellation reason")`. + you cancel a task with `PipelineTask.cancel(reason="cancellation +reason")`. - Added `include_prob_metrics` parameter to Whisper STT services to enable access to probability metrics from transcription results. @@ -60,9 +55,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- Fixed an issue that would cause wrong user/assistant context ordering when - using interruption strategies. - - Fixed an issue where the `SmallWebRTCRequest` dataclass in runner would scrub arbitrary request data from client due to camelCase typing. This fixes data passthrough for JS clients where `APIRequest` is used. 
diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 791de0f34..5db02c856 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -11,7 +11,6 @@ including data frames, system frames, and control frames for audio, video, text, and LLM processing. """ -import asyncio from dataclasses import dataclass, field from typing import ( TYPE_CHECKING, @@ -860,13 +859,9 @@ class InterruptionFrame(SystemFrame): speaking (i.e. is interrupting). This is similar to UserStartedSpeakingFrame except that it should be pushed concurrently with other frames (so the order is not guaranteed). - - Parameters: - finished_event: If not None, the event will be set when the frame - reaches the end of the pipeline. """ - finished_event: Optional[asyncio.Event] = None + pass @dataclass @@ -1427,13 +1422,9 @@ class InterruptionTaskFrame(TaskFrame): same actions as if the user interrupted except that the UserStartedSpeakingFrame and UserStoppedSpeakingFrame won't be generated. This frame should be pushed upstream. - - Parameters: - finished_event: If not None, the event will be set when the generated - InterruptionFrame reaches the end of the pipeline. """ - finished_event: Optional[asyncio.Event] = None + pass @dataclass diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index ad51db0ca..90976b52c 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -745,7 +745,7 @@ class PipelineTask(BasePipelineTask): # pipeline. This is in case the push task is blocked waiting for a # pipeline-ending frame to finish traversing the pipeline. 
logger.debug(f"{self}: received interruption task frame {frame}") - await self._pipeline.queue_frame(InterruptionFrame(finished_event=frame.finished_event)) + await self._pipeline.queue_frame(InterruptionFrame()) elif isinstance(frame, ErrorFrame): await self._call_event_handler("on_pipeline_error", frame) if frame.fatal: @@ -786,10 +786,6 @@ class PipelineTask(BasePipelineTask): self._pipeline_end_event.set() elif isinstance(frame, HeartbeatFrame): await self._heartbeat_queue.put(frame) - elif isinstance(frame, InterruptionFrame) and frame.finished_event: - # This should unblock any code waiting for the interruption to - # complete. - frame.finished_event.set() async def _heartbeat_push_handler(self): """Push heartbeat frames at regular intervals.""" diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index 225629ad9..1ca3333b5 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -93,7 +93,7 @@ class FrameProcessorQueue(asyncio.PriorityQueue): self.__high_counter = 0 self.__low_counter = 0 - async def put(self, item: Tuple[Frame, FrameDirection, Optional[FrameCallback]]): + async def put(self, item: Tuple[Frame, FrameDirection, FrameCallback]): """Put an item into the priority queue. System frames (`SystemFrame`) have higher priority than any other @@ -228,9 +228,11 @@ class FrameProcessor(BaseObject): # To interrupt a pipeline, we push an `InterruptionTaskFrame` upstream. # Then we wait for the corresponding `InterruptionFrame` to travel from - # start to end of the pipeline. When it reaches the end we will be - # notified through the assigned event. + # the start of the pipeline back to the processor that sent the + # `InterruptionTaskFrame`. This wait is handled using the following + # event. self._wait_for_interruption = False + self._wait_interruption_event = asyncio.Event() # Frame processor events. 
self._register_event_handler("on_before_process_frame", sync=True) @@ -565,6 +567,10 @@ class FrameProcessor(BaseObject): if self._cancelling: return + # If we are waiting for an interruption we will bypass all queued system + # frames and we will process the frame right away. This is because a + # previous system frame might be waiting for the interruption frame and + # it's blocking the input task. if self._wait_for_interruption and isinstance(frame, InterruptionFrame): await self.__process_frame(frame, direction, callback) return @@ -655,27 +661,31 @@ class FrameProcessor(BaseObject): await self._call_event_handler("on_after_push_frame", frame) + # If we are waiting for an interruption and we get an interruption, then + # we can unblock `push_interruption_task_frame_and_wait()`. + if self._wait_for_interruption and isinstance(frame, InterruptionFrame): + self._wait_interruption_event.set() + async def push_interruption_task_frame_and_wait(self): - """Interrupt the pipeline and wait for the interruption to complete. - - This function sends an `InterruptionTaskFrame` upstream with an - associated asyncio event. It then waits for the generated - `InterruptionFrame` to reach the end of the pipeline where the event - will be set. + """Push an interruption task frame upstream and wait for the interruption. + This function sends an `InterruptionTaskFrame` upstream to the pipeline + task and waits to receive the corresponding `InterruptionFrame`. When + the function finishes it is guaranteed that the `InterruptionFrame` has + been pushed downstream. """ self._wait_for_interruption = True - finished_event = asyncio.Event() + await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM) - await self.push_frame( - InterruptionTaskFrame(finished_event=finished_event), FrameDirection.UPSTREAM - ) + # Wait for an `InterruptionFrame` to come to this processor and be + # pushed. 
Take a look at `push_frame()` to see how we first push the + # `InterruptionFrame` and then we set the event in order to maintain + # frame ordering. + await self._wait_interruption_event.wait() - # Wait for the event to be set. This event is set when the - # `InterruptionFrame` pushed by the pipeline task reaches the end of the - # pipeline. - await finished_event.wait() + # Clean the event. + self._wait_interruption_event.clear() self._wait_for_interruption = False diff --git a/tests/test_context_aggregators.py b/tests/test_context_aggregators.py index cbab10de2..6196032a3 100644 --- a/tests/test_context_aggregators.py +++ b/tests/test_context_aggregators.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import asyncio import json import unittest from typing import Any, Optional @@ -31,23 +30,17 @@ from pipecat.frames.frames import ( SpeechControlParamsFrame, TextFrame, TranscriptionFrame, - TTSTextFrame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame, ) from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.task import PipelineParams -from pipecat.processors.aggregators.llm_context import LLMContext from pipecat.processors.aggregators.llm_response import ( LLMAssistantAggregatorParams, - LLMAssistantContextAggregator, LLMUserAggregatorParams, LLMUserContextAggregator, ) -from pipecat.processors.aggregators.llm_response_universal import ( - LLMAssistantAggregator, - LLMUserAggregator, -) +from pipecat.processors.aggregators.llm_response_universal import LLMAssistantAggregator from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, @@ -80,11 +73,8 @@ AGGREGATION_SLEEP = 0.15 class BaseTestUserContextAggregator: CONTEXT_CLASS = None # To be set in subclasses - CONTEXT_FRAME_CLASS = None # To be set in subclasses - USER_AGGREGATOR_CLASS = None # To be set in subclasses - USER_EXPECTED_CONTEXT_FRAMES = None - ASSISTANT_AGGREGATOR_CLASS = None # To be set in subclasses - 
ASSISTANT_EXPECTED_CONTEXT_FRAMES = None # To be set in subclasses + AGGREGATOR_CLASS = None # To be set in subclasses + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] def check_message_content(self, context: OpenAILLMContext, index: int, content: str): assert context.messages[index]["content"] == content @@ -96,12 +86,10 @@ class BaseTestUserContextAggregator: async def test_se(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS(context) + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [UserStartedSpeakingFrame(), UserStoppedSpeakingFrame()] expected_down_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame] await run_test( @@ -112,12 +100,10 @@ class BaseTestUserContextAggregator: async def test_ste(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS(context) + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [ UserStartedSpeakingFrame(), TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""), @@ -126,7 +112,7 @@ class BaseTestUserContextAggregator: ] expected_down_frames = [ UserStartedSpeakingFrame, - *self.USER_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, UserStoppedSpeakingFrame, ] await run_test( @@ -138,12 +124,10 @@ class BaseTestUserContextAggregator: async def test_site(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not 
None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS(context) + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [ UserStartedSpeakingFrame(), InterimTranscriptionFrame(text="Hello", user_id="cat", timestamp=""), @@ -153,7 +137,7 @@ class BaseTestUserContextAggregator: ] expected_down_frames = [ UserStartedSpeakingFrame, - *self.USER_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, UserStoppedSpeakingFrame, ] await run_test( @@ -165,12 +149,10 @@ class BaseTestUserContextAggregator: async def test_st1iest2e(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS(context) + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [ UserStartedSpeakingFrame(), TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), @@ -186,7 +168,7 @@ class BaseTestUserContextAggregator: UserStartedSpeakingFrame, UserStoppedSpeakingFrame, UserStartedSpeakingFrame, - *self.USER_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, UserStoppedSpeakingFrame, ] await run_test( @@ -198,12 +180,10 @@ class BaseTestUserContextAggregator: async def test_siet(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS( + aggregator = self.AGGREGATOR_CLASS( context, 
params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) ) frames_to_send = [ @@ -217,7 +197,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - *self.USER_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -228,12 +208,10 @@ class BaseTestUserContextAggregator: async def test_sieit(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS( + aggregator = self.AGGREGATOR_CLASS( context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) ) frames_to_send = [ @@ -248,7 +226,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - *self.USER_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -259,12 +237,10 @@ class BaseTestUserContextAggregator: async def test_set(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS( + aggregator = self.AGGREGATOR_CLASS( context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) ) frames_to_send = [ @@ -276,7 +252,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - *self.USER_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -287,12 +263,10 @@ class 
BaseTestUserContextAggregator: async def test_seit(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS( + aggregator = self.AGGREGATOR_CLASS( context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) ) frames_to_send = [ @@ -305,7 +279,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - *self.USER_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -316,12 +290,10 @@ class BaseTestUserContextAggregator: async def test_st1et2(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS( + aggregator = self.AGGREGATOR_CLASS( context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) ) frames_to_send = [ @@ -336,9 +308,9 @@ class BaseTestUserContextAggregator: expected_down_frames = [ SpeechControlParamsFrame, UserStartedSpeakingFrame, - *self.USER_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, UserStoppedSpeakingFrame, - *self.USER_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -350,12 +322,10 @@ class BaseTestUserContextAggregator: async def test_set1t2(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) + assert 
self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS( + aggregator = self.AGGREGATOR_CLASS( context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) ) frames_to_send = [ @@ -368,7 +338,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - *self.USER_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -379,12 +349,10 @@ class BaseTestUserContextAggregator: async def test_siet1it2(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS( + aggregator = self.AGGREGATOR_CLASS( context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) ) frames_to_send = [ @@ -400,7 +368,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - *self.USER_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -411,12 +379,10 @@ class BaseTestUserContextAggregator: async def test_t(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS( + aggregator = self.AGGREGATOR_CLASS( context ) # No aggregation timeout; this tests VAD emulation @@ -427,7 +393,7 @@ class BaseTestUserContextAggregator: ] expected_down_frames = [ SpeechControlParamsFrame, - 
*self.USER_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, ] expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame] @@ -441,12 +407,10 @@ class BaseTestUserContextAggregator: async def test_t_with_turn_analyzer(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS( + aggregator = self.AGGREGATOR_CLASS( context, params=LLMUserAggregatorParams(turn_emulated_vad_timeout=AGGREGATION_TIMEOUT) ) @@ -460,7 +424,7 @@ class BaseTestUserContextAggregator: ] expected_down_frames = [ SpeechControlParamsFrame, - *self.USER_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, ] expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame] @@ -474,12 +438,10 @@ class BaseTestUserContextAggregator: async def test_it(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS( + aggregator = self.AGGREGATOR_CLASS( context ) # No aggregation timeout; this tests VAD emulation frames_to_send = [ @@ -489,7 +451,7 @@ class BaseTestUserContextAggregator: TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), SleepFrame(sleep=AGGREGATION_SLEEP), ] - expected_down_frames = [SpeechControlParamsFrame, *self.USER_EXPECTED_CONTEXT_FRAMES] + expected_down_frames = [SpeechControlParamsFrame, *self.EXPECTED_CONTEXT_FRAMES] expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame] await 
run_test( aggregator, @@ -501,12 +463,10 @@ class BaseTestUserContextAggregator: async def test_sie_delay_it(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS( + aggregator = self.AGGREGATOR_CLASS( context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) ) frames_to_send = [ @@ -522,7 +482,7 @@ class BaseTestUserContextAggregator: expected_down_frames = [ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, - *self.USER_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -533,12 +493,7 @@ class BaseTestUserContextAggregator: async def test_min_words_interruption_strategy_one_word(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.CONTEXT_FRAME_CLASS is not None, "CONTEXT_FRAME_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) - - CONTEXT_FRAME_CLASS = self.CONTEXT_FRAME_CLASS + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" class ContextProcessor(FrameProcessor): def __init__(self): @@ -548,13 +503,13 @@ class BaseTestUserContextAggregator: async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) - if isinstance(frame, CONTEXT_FRAME_CLASS): + if isinstance(frame, OpenAILLMContextFrame): self.context_received = True await self.push_frame(frame, direction) context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS(context) + aggregator = self.AGGREGATOR_CLASS(context) context_processor = ContextProcessor() pipeline = Pipeline([aggregator, 
context_processor]) @@ -582,12 +537,7 @@ class BaseTestUserContextAggregator: async def test_min_words_interruption_strategy_two_words(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.CONTEXT_FRAME_CLASS is not None, "CONTEXT_FRAME_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) - - CONTEXT_FRAME_CLASS = self.CONTEXT_FRAME_CLASS + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" class ContextProcessor(FrameProcessor): def __init__(self): @@ -597,7 +547,7 @@ class BaseTestUserContextAggregator: async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) - if isinstance(frame, CONTEXT_FRAME_CLASS): + if isinstance(frame, OpenAILLMContextFrame): self.context_received = True elif isinstance(frame, InterruptionFrame): self.context_received = False @@ -605,7 +555,7 @@ class BaseTestUserContextAggregator: await self.push_frame(frame, direction) context = self.CONTEXT_CLASS() - aggregator = self.USER_AGGREGATOR_CLASS(context) + aggregator = self.AGGREGATOR_CLASS(context) context_processor = ContextProcessor() pipeline = Pipeline([aggregator, context_processor]) @@ -622,7 +572,7 @@ class BaseTestUserContextAggregator: UserStartedSpeakingFrame, InterruptionFrame, UserStoppedSpeakingFrame, - *self.USER_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( pipeline, @@ -638,77 +588,11 @@ class BaseTestUserContextAggregator: # interruption then we have an issue. 
assert context_processor.context_received - async def test_interruption_strategy_context_order(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.USER_AGGREGATOR_CLASS is not None, ( - "USER_AGGREGATOR_CLASS must be set in a subclass" - ) - assert self.ASSISTANT_AGGREGATOR_CLASS is not None, ( - "ASSISTANT_AGGREGATOR_CLASS must be set in a subclass" - ) - - class DelayedProcessor(FrameProcessor): - """Force a delay in interruption frames. - - This might give time to the assistant aggregator to update the - context before the user aggregator (which shouldn't really happen) - and reveal any issues in context ordering. - """ - - def __init__(self): - super().__init__() - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - if isinstance(frame, InterruptionFrame): - await asyncio.sleep(0.3) - await self.push_frame(frame, direction) - - context = self.CONTEXT_CLASS() - user_aggregator = self.USER_AGGREGATOR_CLASS( - context, params=LLMUserAggregatorParams(aggregation_timeout=1.0) - ) - assistant_aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context) - pipeline = Pipeline([user_aggregator, DelayedProcessor(), assistant_aggregator]) - - frames_to_send = [ - # Aggregate assistant content. - BotStartedSpeakingFrame(), - LLMFullResponseStartFrame(), - TTSTextFrame(text="Hello, I'm your assistant"), - SleepFrame(), - # Interrupt the bot. Assistant content should be added first to the - # context, followed by user content. 
- UserStartedSpeakingFrame(), - TranscriptionFrame(text="Can you tell me", user_id="cat", timestamp=""), - SleepFrame(), - UserStoppedSpeakingFrame(), - ] - expected_down_frames = [ - BotStartedSpeakingFrame, - UserStartedSpeakingFrame, - *self.ASSISTANT_EXPECTED_CONTEXT_FRAMES, - InterruptionFrame, - UserStoppedSpeakingFrame, - *self.USER_EXPECTED_CONTEXT_FRAMES, - ] - await run_test( - pipeline, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - pipeline_params=PipelineParams( - interruption_strategies=[MinWordsInterruptionStrategy(min_words=2)] - ), - ) - self.check_message_content(context, -1, "Can you tell me") - self.check_message_content(context, -2, "Hello, I'm your assistant") - class BaseTestAssistantContextAggregator: CONTEXT_CLASS = None # To be set in subclasses - USER_AGGREGATOR_CLASS = None # To be set in subclasses - ASSISTANT_AGGREGATOR_CLASS = None # To be set in subclasses - ASSISTANT_EXPECTED_CONTEXT_FRAMES = [] # To be set in subclasses + AGGREGATOR_CLASS = None # To be set in subclasses + EXPECTED_CONTEXT_FRAMES = None # To be set in subclasses def create_assistant_aggregator_params( self, **kwargs @@ -728,12 +612,10 @@ class BaseTestAssistantContextAggregator: async def test_empty(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.ASSISTANT_AGGREGATOR_CLASS is not None, ( - "ASSISTANT_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context) + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()] expected_down_frames = [] await run_test( @@ -744,18 +626,16 @@ class BaseTestAssistantContextAggregator: async def test_single_text(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert 
self.ASSISTANT_AGGREGATOR_CLASS is not None, ( - "ASSISTANT_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context) + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [ LLMFullResponseStartFrame(), TextFrame(text="Hello Pipecat!"), LLMFullResponseEndFrame(), ] - expected_down_frames = [*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] await run_test( aggregator, frames_to_send=frames_to_send, @@ -765,12 +645,10 @@ class BaseTestAssistantContextAggregator: async def test_multiple_text(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.ASSISTANT_AGGREGATOR_CLASS is not None, ( - "ASSISTANT_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.ASSISTANT_AGGREGATOR_CLASS( + aggregator = self.AGGREGATOR_CLASS( context, params=self.create_assistant_aggregator_params(expect_stripped_words=False) ) frames_to_send = [ @@ -781,7 +659,7 @@ class BaseTestAssistantContextAggregator: TextFrame(text="you?"), LLMFullResponseEndFrame(), ] - expected_down_frames = [*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] await run_test( aggregator, frames_to_send=frames_to_send, @@ -791,12 +669,10 @@ class BaseTestAssistantContextAggregator: async def test_multiple_text_stripped(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.ASSISTANT_AGGREGATOR_CLASS is not None, ( - "ASSISTANT_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = 
self.ASSISTANT_AGGREGATOR_CLASS(context) + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [ LLMFullResponseStartFrame(), TextFrame(text="Hello"), @@ -805,7 +681,7 @@ class BaseTestAssistantContextAggregator: TextFrame(text="you?"), LLMFullResponseEndFrame(), ] - expected_down_frames = [*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] await run_test( aggregator, frames_to_send=frames_to_send, @@ -815,12 +691,10 @@ class BaseTestAssistantContextAggregator: async def test_multiple_llm_responses(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.ASSISTANT_AGGREGATOR_CLASS is not None, ( - "ASSISTANT_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.ASSISTANT_AGGREGATOR_CLASS( + aggregator = self.AGGREGATOR_CLASS( context, params=self.create_assistant_aggregator_params(expect_stripped_words=False) ) frames_to_send = [ @@ -833,10 +707,7 @@ class BaseTestAssistantContextAggregator: TextFrame(text="you?"), LLMFullResponseEndFrame(), ] - expected_down_frames = [ - *self.ASSISTANT_EXPECTED_CONTEXT_FRAMES, - *self.ASSISTANT_EXPECTED_CONTEXT_FRAMES, - ] + expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES, *self.EXPECTED_CONTEXT_FRAMES] await run_test( aggregator, frames_to_send=frames_to_send, @@ -847,12 +718,10 @@ class BaseTestAssistantContextAggregator: async def test_multiple_llm_responses_interruption(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.ASSISTANT_AGGREGATOR_CLASS is not None, ( - "ASSISTANT_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.ASSISTANT_AGGREGATOR_CLASS( + aggregator = self.AGGREGATOR_CLASS( context, 
params=self.create_assistant_aggregator_params(expect_stripped_words=False) ) frames_to_send = [ @@ -868,9 +737,9 @@ class BaseTestAssistantContextAggregator: LLMFullResponseEndFrame(), ] expected_down_frames = [ - *self.ASSISTANT_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, InterruptionFrame, - *self.ASSISTANT_EXPECTED_CONTEXT_FRAMES, + *self.EXPECTED_CONTEXT_FRAMES, ] await run_test( aggregator, @@ -882,12 +751,10 @@ class BaseTestAssistantContextAggregator: async def test_function_call(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.ASSISTANT_AGGREGATOR_CLASS is not None, ( - "ASSISTANT_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context = self.CONTEXT_CLASS() - aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context) + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [ FunctionCallInProgressFrame( function_name="get_weather", @@ -913,9 +780,7 @@ class BaseTestAssistantContextAggregator: async def test_function_call_on_context_updated(self): assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.ASSISTANT_AGGREGATOR_CLASS is not None, ( - "ASSISTANT_AGGREGATOR_CLASS must be set in a subclass" - ) + assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" context_updated = False @@ -924,7 +789,7 @@ class BaseTestAssistantContextAggregator: context_updated = True context = self.CONTEXT_CLASS() - aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context) + aggregator = self.AGGREGATOR_CLASS(context) frames_to_send = [ FunctionCallInProgressFrame( function_name="get_weather", @@ -959,14 +824,7 @@ class BaseTestAssistantContextAggregator: class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase): CONTEXT_CLASS = OpenAILLMContext - CONTEXT_FRAME_CLASS = OpenAILLMContextFrame - 
USER_AGGREGATOR_CLASS = LLMUserContextAggregator - USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] - ASSISTANT_AGGREGATOR_CLASS = LLMAssistantContextAggregator - ASSISTANT_EXPECTED_CONTEXT_FRAMES = [ - OpenAILLMContextFrame, - OpenAILLMContextAssistantTimestampFrame, - ] + AGGREGATOR_CLASS = LLMUserContextAggregator # @@ -978,14 +836,7 @@ class TestAnthropicUserContextAggregator( BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase ): CONTEXT_CLASS = AnthropicLLMContext - CONTEXT_FRAME_CLASS = OpenAILLMContextFrame - USER_AGGREGATOR_CLASS = AnthropicUserContextAggregator - USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] - ASSISTANT_AGGREGATOR_CLASS = AnthropicAssistantContextAggregator - ASSISTANT_EXPECTED_CONTEXT_FRAMES = [ - OpenAILLMContextFrame, - OpenAILLMContextAssistantTimestampFrame, - ] + AGGREGATOR_CLASS = AnthropicUserContextAggregator def check_message_multi_content( self, context: OpenAILLMContext, content_index: int, index: int, content: str @@ -998,14 +849,8 @@ class TestAnthropicAssistantContextAggregator( BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase ): CONTEXT_CLASS = AnthropicLLMContext - CONTEXT_FRAME_CLASS = OpenAILLMContextFrame - USER_AGGREGATOR_CLASS = AnthropicUserContextAggregator - USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] - ASSISTANT_AGGREGATOR_CLASS = AnthropicAssistantContextAggregator - ASSISTANT_EXPECTED_CONTEXT_FRAMES = [ - OpenAILLMContextFrame, - OpenAILLMContextAssistantTimestampFrame, - ] + AGGREGATOR_CLASS = AnthropicAssistantContextAggregator + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] def check_message_multi_content( self, context: OpenAILLMContext, content_index: int, index: int, content: str @@ -1026,14 +871,7 @@ class TestAWSBedrockUserContextAggregator( BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase ): CONTEXT_CLASS = AWSBedrockLLMContext - CONTEXT_FRAME_CLASS = OpenAILLMContextFrame - 
USER_AGGREGATOR_CLASS = AWSBedrockUserContextAggregator - USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] - ASSISTANT_AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator - ASSISTANT_EXPECTED_CONTEXT_FRAMES = [ - OpenAILLMContextFrame, - OpenAILLMContextAssistantTimestampFrame, - ] + AGGREGATOR_CLASS = AWSBedrockUserContextAggregator def check_message_multi_content( self, context: OpenAILLMContext, content_index: int, index: int, content: str @@ -1046,14 +884,8 @@ class TestAWSBedrockAssistantContextAggregator( BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase ): CONTEXT_CLASS = AWSBedrockLLMContext - CONTEXT_FRAME_CLASS = OpenAILLMContextFrame - USER_AGGREGATOR_CLASS = AWSBedrockUserContextAggregator - USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] - ASSISTANT_AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator - ASSISTANT_EXPECTED_CONTEXT_FRAMES = [ - OpenAILLMContextFrame, - OpenAILLMContextAssistantTimestampFrame, - ] + AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] def check_message_multi_content( self, context: OpenAILLMContext, content_index: int, index: int, content: str @@ -1076,14 +908,7 @@ class TestGoogleUserContextAggregator( BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase ): CONTEXT_CLASS = GoogleLLMContext - CONTEXT_FRAME_CLASS = OpenAILLMContextFrame - USER_AGGREGATOR_CLASS = GoogleUserContextAggregator - USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] - ASSISTANT_AGGREGATOR_CLASS = GoogleAssistantContextAggregator - ASSISTANT_EXPECTED_CONTEXT_FRAMES = [ - OpenAILLMContextFrame, - OpenAILLMContextAssistantTimestampFrame, - ] + AGGREGATOR_CLASS = GoogleUserContextAggregator def check_message_content(self, context: OpenAILLMContext, index: int, content: str): obj = context.messages[index].to_json_dict() @@ -1100,14 +925,8 @@ class TestGoogleAssistantContextAggregator( 
BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase ): CONTEXT_CLASS = GoogleLLMContext - CONTEXT_FRAME_CLASS = OpenAILLMContextFrame - USER_AGGREGATOR_CLASS = GoogleUserContextAggregator - USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] - ASSISTANT_AGGREGATOR_CLASS = GoogleAssistantContextAggregator - ASSISTANT_EXPECTED_CONTEXT_FRAMES = [ - OpenAILLMContextFrame, - OpenAILLMContextAssistantTimestampFrame, - ] + AGGREGATOR_CLASS = GoogleAssistantContextAggregator + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] def check_message_content(self, context: OpenAILLMContext, index: int, content: str): obj = context.messages[index].to_json_dict() @@ -1133,53 +952,26 @@ class TestOpenAIUserContextAggregator( BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase ): CONTEXT_CLASS = OpenAILLMContext - CONTEXT_FRAME_CLASS = OpenAILLMContextFrame - USER_AGGREGATOR_CLASS = OpenAIUserContextAggregator - USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] - ASSISTANT_AGGREGATOR_CLASS = OpenAIAssistantContextAggregator - ASSISTANT_EXPECTED_CONTEXT_FRAMES = [ - OpenAILLMContextFrame, - OpenAILLMContextAssistantTimestampFrame, - ] + AGGREGATOR_CLASS = OpenAIUserContextAggregator class TestOpenAIAssistantContextAggregator( BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase ): CONTEXT_CLASS = OpenAILLMContext - CONTEXT_FRAME_CLASS = OpenAILLMContextFrame - USER_AGGREGATOR_CLASS = OpenAIUserContextAggregator - USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] - ASSISTANT_AGGREGATOR_CLASS = OpenAIAssistantContextAggregator - ASSISTANT_EXPECTED_CONTEXT_FRAMES = [ - OpenAILLMContextFrame, - OpenAILLMContextAssistantTimestampFrame, - ] + AGGREGATOR_CLASS = OpenAIAssistantContextAggregator + EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] # # Universal # - - -class TestLLMUserAggregator(BaseTestUserContextAggregator, 
unittest.IsolatedAsyncioTestCase): - CONTEXT_CLASS = LLMContext - CONTEXT_FRAME_CLASS = LLMContextFrame - USER_AGGREGATOR_CLASS = LLMUserAggregator - USER_EXPECTED_CONTEXT_FRAMES = [LLMContextFrame] - ASSISTANT_AGGREGATOR_CLASS = LLMAssistantAggregator - ASSISTANT_EXPECTED_CONTEXT_FRAMES = [LLMContextFrame, LLMContextAssistantTimestampFrame] - - class TestLLMAssistantAggregator( BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase ): - CONTEXT_CLASS = LLMContext - CONTEXT_FRAME_CLASS = LLMContextFrame - USER_AGGREGATOR_CLASS = LLMUserAggregator - USER_EXPECTED_CONTEXT_FRAMES = [LLMContextFrame] - ASSISTANT_AGGREGATOR_CLASS = LLMAssistantAggregator - ASSISTANT_EXPECTED_CONTEXT_FRAMES = [LLMContextFrame, LLMContextAssistantTimestampFrame] + CONTEXT_CLASS = OpenAILLMContext + AGGREGATOR_CLASS = LLMAssistantAggregator + EXPECTED_CONTEXT_FRAMES = [LLMContextFrame, LLMContextAssistantTimestampFrame] # Override to remove 'expect_stripped_words' parameter, which is deprecated # for LLMAssistantAggregator