Revert "fix interruption task frame context ordering"

This commit is contained in:
Aleix Conchillo Flaqué
2025-11-05 12:14:03 -08:00
committed by GitHub
parent 11b101e8a6
commit d844829538
5 changed files with 127 additions and 346 deletions

View File

@@ -9,12 +9,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added an asyncio event `finished_event` field to `InterruptionFrame`. When
assigned, the asyncio event will be set when the frame reaches the end of the
pipeline. You can use this field to know when an interruption made it all the
way to the end of a pipeline. The field has also been added to
`InterruptionTaskFrame`.
- Added support for loading external observers. You can now register custom
pipeline observers by setting the `PIPECAT_OBSERVER_FILES` environment
variable. This variable should contain a colon-separated list of Python files
@@ -34,7 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `CancelFrame` and `CancelTaskFrame` have an optional `reason` field to
indicate why the pipeline is being canceled. This can also be specified when
you cancel a task with `PipelineTask.cancel(reason="cancellation reason")`.
you cancel a task with `PipelineTask.cancel(reason="your cancellation
reason")`.
- Added `include_prob_metrics` parameter to Whisper STT services to enable access
to probability metrics from transcription results.
@@ -60,9 +55,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Fixed an issue that would cause wrong user/assistant context ordering when
using interruption strategies.
- Fixed an issue where the `SmallWebRTCRequest` dataclass in runner would scrub
arbitrary request data from client due to camelCase typing. This fixes data
passthrough for JS clients where `APIRequest` is used.

View File

@@ -11,7 +11,6 @@ including data frames, system frames, and control frames for audio, video, text,
and LLM processing.
"""
import asyncio
from dataclasses import dataclass, field
from typing import (
TYPE_CHECKING,
@@ -860,13 +859,9 @@ class InterruptionFrame(SystemFrame):
speaking (i.e. is interrupting). This is similar to
UserStartedSpeakingFrame except that it should be pushed concurrently
with other frames (so the order is not guaranteed).
Parameters:
finished_event: If not None, the event will be set when the frame
reaches the end of the pipeline.
"""
finished_event: Optional[asyncio.Event] = None
pass
@dataclass
@@ -1427,13 +1422,9 @@ class InterruptionTaskFrame(TaskFrame):
same actions as if the user interrupted except that the
UserStartedSpeakingFrame and UserStoppedSpeakingFrame won't be generated.
This frame should be pushed upstream.
Parameters:
finished_event: If not None, the event will be set when the generated
InterruptionFrame reaches the end of the pipeline.
"""
finished_event: Optional[asyncio.Event] = None
pass
@dataclass

View File

@@ -745,7 +745,7 @@ class PipelineTask(BasePipelineTask):
# pipeline. This is in case the push task is blocked waiting for a
# pipeline-ending frame to finish traversing the pipeline.
logger.debug(f"{self}: received interruption task frame {frame}")
await self._pipeline.queue_frame(InterruptionFrame(finished_event=frame.finished_event))
await self._pipeline.queue_frame(InterruptionFrame())
elif isinstance(frame, ErrorFrame):
await self._call_event_handler("on_pipeline_error", frame)
if frame.fatal:
@@ -786,10 +786,6 @@ class PipelineTask(BasePipelineTask):
self._pipeline_end_event.set()
elif isinstance(frame, HeartbeatFrame):
await self._heartbeat_queue.put(frame)
elif isinstance(frame, InterruptionFrame) and frame.finished_event:
# This should unblock any code waiting for the interruption to
# complete.
frame.finished_event.set()
async def _heartbeat_push_handler(self):
"""Push heartbeat frames at regular intervals."""

View File

@@ -93,7 +93,7 @@ class FrameProcessorQueue(asyncio.PriorityQueue):
self.__high_counter = 0
self.__low_counter = 0
async def put(self, item: Tuple[Frame, FrameDirection, Optional[FrameCallback]]):
async def put(self, item: Tuple[Frame, FrameDirection, FrameCallback]):
"""Put an item into the priority queue.
System frames (`SystemFrame`) have higher priority than any other
@@ -228,9 +228,11 @@ class FrameProcessor(BaseObject):
# To interrupt a pipeline, we push an `InterruptionTaskFrame` upstream.
# Then we wait for the corresponding `InterruptionFrame` to travel from
# start to end of the pipeline. When it reaches the end we will be
# notified through the assigned event.
# the start of the pipeline back to the processor that sent the
# `InterruptionTaskFrame`. This wait is handled using the following
# event.
self._wait_for_interruption = False
self._wait_interruption_event = asyncio.Event()
# Frame processor events.
self._register_event_handler("on_before_process_frame", sync=True)
@@ -565,6 +567,10 @@ class FrameProcessor(BaseObject):
if self._cancelling:
return
# If we are waiting for an interruption we will bypass all queued system
# frames and we will process the frame right away. This is because a
# previous system frame might be waiting for the interruption frame and
# it's blocking the input task.
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
await self.__process_frame(frame, direction, callback)
return
@@ -655,27 +661,31 @@ class FrameProcessor(BaseObject):
await self._call_event_handler("on_after_push_frame", frame)
# If we are waiting for an interruption and we get an interruption, then
# we can unblock `push_interruption_task_frame_and_wait()`.
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
self._wait_interruption_event.set()
async def push_interruption_task_frame_and_wait(self):
"""Interrupt the pipeline and wait for the interruption to complete.
This function sends an `InterruptionTaskFrame` upstream with an
associated asyncio event. It then waits for the generated
`InterruptionFrame` to reach the end of the pipeline where the event
will be set.
"""Push an interruption task frame upstream and wait for the interruption.
This function sends an `InterruptionTaskFrame` upstream to the pipeline
task and waits to receive the corresponding `InterruptionFrame`. When
the function finishes it is guaranteed that the `InterruptionFrame` has
been pushed downstream.
"""
self._wait_for_interruption = True
finished_event = asyncio.Event()
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
await self.push_frame(
InterruptionTaskFrame(finished_event=finished_event), FrameDirection.UPSTREAM
)
# Wait for an `InterruptionFrame` to come to this processor and be
# pushed. Take a look at `push_frame()` to see how we first push the
# `InterruptionFrame` and then we set the event in order to maintain
# frame ordering.
await self._wait_interruption_event.wait()
# Wait for the event to be set. This event is set when the
# `InterruptionFrame` pushed by the pipeline task reaches the end of the
# pipeline.
await finished_event.wait()
# Clean the event.
self._wait_interruption_event.clear()
self._wait_for_interruption = False

View File

@@ -4,7 +4,6 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import json
import unittest
from typing import Any, Optional
@@ -31,23 +30,17 @@ from pipecat.frames.frames import (
SpeechControlParamsFrame,
TextFrame,
TranscriptionFrame,
TTSTextFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response import (
LLMAssistantAggregatorParams,
LLMAssistantContextAggregator,
LLMUserAggregatorParams,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.llm_response_universal import (
LLMAssistantAggregator,
LLMUserAggregator,
)
from pipecat.processors.aggregators.llm_response_universal import LLMAssistantAggregator
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
@@ -80,11 +73,8 @@ AGGREGATION_SLEEP = 0.15
class BaseTestUserContextAggregator:
CONTEXT_CLASS = None # To be set in subclasses
CONTEXT_FRAME_CLASS = None # To be set in subclasses
USER_AGGREGATOR_CLASS = None # To be set in subclasses
USER_EXPECTED_CONTEXT_FRAMES = None
ASSISTANT_AGGREGATOR_CLASS = None # To be set in subclasses
ASSISTANT_EXPECTED_CONTEXT_FRAMES = None # To be set in subclasses
AGGREGATOR_CLASS = None # To be set in subclasses
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
assert context.messages[index]["content"] == content
@@ -96,12 +86,10 @@ class BaseTestUserContextAggregator:
async def test_se(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(context)
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [UserStartedSpeakingFrame(), UserStoppedSpeakingFrame()]
expected_down_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame]
await run_test(
@@ -112,12 +100,10 @@ class BaseTestUserContextAggregator:
async def test_ste(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(context)
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""),
@@ -126,7 +112,7 @@ class BaseTestUserContextAggregator:
]
expected_down_frames = [
UserStartedSpeakingFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
UserStoppedSpeakingFrame,
]
await run_test(
@@ -138,12 +124,10 @@ class BaseTestUserContextAggregator:
async def test_site(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(context)
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
UserStartedSpeakingFrame(),
InterimTranscriptionFrame(text="Hello", user_id="cat", timestamp=""),
@@ -153,7 +137,7 @@ class BaseTestUserContextAggregator:
]
expected_down_frames = [
UserStartedSpeakingFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
UserStoppedSpeakingFrame,
]
await run_test(
@@ -165,12 +149,10 @@ class BaseTestUserContextAggregator:
async def test_st1iest2e(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(context)
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
@@ -186,7 +168,7 @@ class BaseTestUserContextAggregator:
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
UserStartedSpeakingFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
UserStoppedSpeakingFrame,
]
await run_test(
@@ -198,12 +180,10 @@ class BaseTestUserContextAggregator:
async def test_siet(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
@@ -217,7 +197,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -228,12 +208,10 @@ class BaseTestUserContextAggregator:
async def test_sieit(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
@@ -248,7 +226,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -259,12 +237,10 @@ class BaseTestUserContextAggregator:
async def test_set(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
@@ -276,7 +252,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -287,12 +263,10 @@ class BaseTestUserContextAggregator:
async def test_seit(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
@@ -305,7 +279,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -316,12 +290,10 @@ class BaseTestUserContextAggregator:
async def test_st1et2(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
@@ -336,9 +308,9 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
SpeechControlParamsFrame,
UserStartedSpeakingFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
UserStoppedSpeakingFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -350,12 +322,10 @@ class BaseTestUserContextAggregator:
async def test_set1t2(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
@@ -368,7 +338,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -379,12 +349,10 @@ class BaseTestUserContextAggregator:
async def test_siet1it2(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
@@ -400,7 +368,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -411,12 +379,10 @@ class BaseTestUserContextAggregator:
async def test_t(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(
aggregator = self.AGGREGATOR_CLASS(
context
) # No aggregation timeout; this tests VAD emulation
@@ -427,7 +393,7 @@ class BaseTestUserContextAggregator:
]
expected_down_frames = [
SpeechControlParamsFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
]
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
@@ -441,12 +407,10 @@ class BaseTestUserContextAggregator:
async def test_t_with_turn_analyzer(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(turn_emulated_vad_timeout=AGGREGATION_TIMEOUT)
)
@@ -460,7 +424,7 @@ class BaseTestUserContextAggregator:
]
expected_down_frames = [
SpeechControlParamsFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
]
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
@@ -474,12 +438,10 @@ class BaseTestUserContextAggregator:
async def test_it(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(
aggregator = self.AGGREGATOR_CLASS(
context
) # No aggregation timeout; this tests VAD emulation
frames_to_send = [
@@ -489,7 +451,7 @@ class BaseTestUserContextAggregator:
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
SleepFrame(sleep=AGGREGATION_SLEEP),
]
expected_down_frames = [SpeechControlParamsFrame, *self.USER_EXPECTED_CONTEXT_FRAMES]
expected_down_frames = [SpeechControlParamsFrame, *self.EXPECTED_CONTEXT_FRAMES]
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
await run_test(
aggregator,
@@ -501,12 +463,10 @@ class BaseTestUserContextAggregator:
async def test_sie_delay_it(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(
aggregator = self.AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
)
frames_to_send = [
@@ -522,7 +482,7 @@ class BaseTestUserContextAggregator:
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -533,12 +493,7 @@ class BaseTestUserContextAggregator:
async def test_min_words_interruption_strategy_one_word(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.CONTEXT_FRAME_CLASS is not None, "CONTEXT_FRAME_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
CONTEXT_FRAME_CLASS = self.CONTEXT_FRAME_CLASS
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
class ContextProcessor(FrameProcessor):
def __init__(self):
@@ -548,13 +503,13 @@ class BaseTestUserContextAggregator:
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, CONTEXT_FRAME_CLASS):
if isinstance(frame, OpenAILLMContextFrame):
self.context_received = True
await self.push_frame(frame, direction)
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(context)
aggregator = self.AGGREGATOR_CLASS(context)
context_processor = ContextProcessor()
pipeline = Pipeline([aggregator, context_processor])
@@ -582,12 +537,7 @@ class BaseTestUserContextAggregator:
async def test_min_words_interruption_strategy_two_words(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.CONTEXT_FRAME_CLASS is not None, "CONTEXT_FRAME_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
CONTEXT_FRAME_CLASS = self.CONTEXT_FRAME_CLASS
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
class ContextProcessor(FrameProcessor):
def __init__(self):
@@ -597,7 +547,7 @@ class BaseTestUserContextAggregator:
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, CONTEXT_FRAME_CLASS):
if isinstance(frame, OpenAILLMContextFrame):
self.context_received = True
elif isinstance(frame, InterruptionFrame):
self.context_received = False
@@ -605,7 +555,7 @@ class BaseTestUserContextAggregator:
await self.push_frame(frame, direction)
context = self.CONTEXT_CLASS()
aggregator = self.USER_AGGREGATOR_CLASS(context)
aggregator = self.AGGREGATOR_CLASS(context)
context_processor = ContextProcessor()
pipeline = Pipeline([aggregator, context_processor])
@@ -622,7 +572,7 @@ class BaseTestUserContextAggregator:
UserStartedSpeakingFrame,
InterruptionFrame,
UserStoppedSpeakingFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
pipeline,
@@ -638,77 +588,11 @@ class BaseTestUserContextAggregator:
# interruption then we have an issue.
assert context_processor.context_received
async def test_interruption_strategy_context_order(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.USER_AGGREGATOR_CLASS is not None, (
"USER_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
)
class DelayedProcessor(FrameProcessor):
"""Force a delay in interruption frames.
This might give time to the assistant aggregator to update the
context before the user aggregator (which shouldn't really happen)
and reveal any issues in context ordering.
"""
def __init__(self):
super().__init__()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
await asyncio.sleep(0.3)
await self.push_frame(frame, direction)
context = self.CONTEXT_CLASS()
user_aggregator = self.USER_AGGREGATOR_CLASS(
context, params=LLMUserAggregatorParams(aggregation_timeout=1.0)
)
assistant_aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context)
pipeline = Pipeline([user_aggregator, DelayedProcessor(), assistant_aggregator])
frames_to_send = [
# Aggregate assistant content.
BotStartedSpeakingFrame(),
LLMFullResponseStartFrame(),
TTSTextFrame(text="Hello, I'm your assistant"),
SleepFrame(),
# Interrupt the bot. Assistant content should be added first to the
# context, followed by user content.
UserStartedSpeakingFrame(),
TranscriptionFrame(text="Can you tell me", user_id="cat", timestamp=""),
SleepFrame(),
UserStoppedSpeakingFrame(),
]
expected_down_frames = [
BotStartedSpeakingFrame,
UserStartedSpeakingFrame,
*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES,
InterruptionFrame,
UserStoppedSpeakingFrame,
*self.USER_EXPECTED_CONTEXT_FRAMES,
]
await run_test(
pipeline,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
pipeline_params=PipelineParams(
interruption_strategies=[MinWordsInterruptionStrategy(min_words=2)]
),
)
self.check_message_content(context, -1, "Can you tell me")
self.check_message_content(context, -2, "Hello, I'm your assistant")
class BaseTestAssistantContextAggregator:
CONTEXT_CLASS = None # To be set in subclasses
USER_AGGREGATOR_CLASS = None # To be set in subclasses
ASSISTANT_AGGREGATOR_CLASS = None # To be set in subclasses
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [] # To be set in subclasses
AGGREGATOR_CLASS = None # To be set in subclasses
EXPECTED_CONTEXT_FRAMES = None # To be set in subclasses
def create_assistant_aggregator_params(
self, **kwargs
@@ -728,12 +612,10 @@ class BaseTestAssistantContextAggregator:
async def test_empty(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context)
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()]
expected_down_frames = []
await run_test(
@@ -744,18 +626,16 @@ class BaseTestAssistantContextAggregator:
async def test_single_text(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context)
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
LLMFullResponseStartFrame(),
TextFrame(text="Hello Pipecat!"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
@@ -765,12 +645,10 @@ class BaseTestAssistantContextAggregator:
async def test_multiple_text(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(
aggregator = self.AGGREGATOR_CLASS(
context, params=self.create_assistant_aggregator_params(expect_stripped_words=False)
)
frames_to_send = [
@@ -781,7 +659,7 @@ class BaseTestAssistantContextAggregator:
TextFrame(text="you?"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
@@ -791,12 +669,10 @@ class BaseTestAssistantContextAggregator:
async def test_multiple_text_stripped(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context)
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
LLMFullResponseStartFrame(),
TextFrame(text="Hello"),
@@ -805,7 +681,7 @@ class BaseTestAssistantContextAggregator:
TextFrame(text="you?"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
@@ -815,12 +691,10 @@ class BaseTestAssistantContextAggregator:
async def test_multiple_llm_responses(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(
aggregator = self.AGGREGATOR_CLASS(
context, params=self.create_assistant_aggregator_params(expect_stripped_words=False)
)
frames_to_send = [
@@ -833,10 +707,7 @@ class BaseTestAssistantContextAggregator:
TextFrame(text="you?"),
LLMFullResponseEndFrame(),
]
expected_down_frames = [
*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES,
*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES,
]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES, *self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
@@ -847,12 +718,10 @@ class BaseTestAssistantContextAggregator:
async def test_multiple_llm_responses_interruption(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(
aggregator = self.AGGREGATOR_CLASS(
context, params=self.create_assistant_aggregator_params(expect_stripped_words=False)
)
frames_to_send = [
@@ -868,9 +737,9 @@ class BaseTestAssistantContextAggregator:
LLMFullResponseEndFrame(),
]
expected_down_frames = [
*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
InterruptionFrame,
*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES,
*self.EXPECTED_CONTEXT_FRAMES,
]
await run_test(
aggregator,
@@ -882,12 +751,10 @@ class BaseTestAssistantContextAggregator:
async def test_function_call(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context)
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
FunctionCallInProgressFrame(
function_name="get_weather",
@@ -913,9 +780,7 @@ class BaseTestAssistantContextAggregator:
async def test_function_call_on_context_updated(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
)
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context_updated = False
@@ -924,7 +789,7 @@ class BaseTestAssistantContextAggregator:
context_updated = True
context = self.CONTEXT_CLASS()
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context)
aggregator = self.AGGREGATOR_CLASS(context)
frames_to_send = [
FunctionCallInProgressFrame(
function_name="get_weather",
@@ -959,14 +824,7 @@ class BaseTestAssistantContextAggregator:
class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase):
CONTEXT_CLASS = OpenAILLMContext
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
USER_AGGREGATOR_CLASS = LLMUserContextAggregator
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
ASSISTANT_AGGREGATOR_CLASS = LLMAssistantContextAggregator
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
OpenAILLMContextFrame,
OpenAILLMContextAssistantTimestampFrame,
]
AGGREGATOR_CLASS = LLMUserContextAggregator
#
@@ -978,14 +836,7 @@ class TestAnthropicUserContextAggregator(
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = AnthropicLLMContext
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
USER_AGGREGATOR_CLASS = AnthropicUserContextAggregator
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
ASSISTANT_AGGREGATOR_CLASS = AnthropicAssistantContextAggregator
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
OpenAILLMContextFrame,
OpenAILLMContextAssistantTimestampFrame,
]
AGGREGATOR_CLASS = AnthropicUserContextAggregator
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
@@ -998,14 +849,8 @@ class TestAnthropicAssistantContextAggregator(
BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = AnthropicLLMContext
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
USER_AGGREGATOR_CLASS = AnthropicUserContextAggregator
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
ASSISTANT_AGGREGATOR_CLASS = AnthropicAssistantContextAggregator
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
OpenAILLMContextFrame,
OpenAILLMContextAssistantTimestampFrame,
]
AGGREGATOR_CLASS = AnthropicAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
@@ -1026,14 +871,7 @@ class TestAWSBedrockUserContextAggregator(
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = AWSBedrockLLMContext
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
USER_AGGREGATOR_CLASS = AWSBedrockUserContextAggregator
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
ASSISTANT_AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
OpenAILLMContextFrame,
OpenAILLMContextAssistantTimestampFrame,
]
AGGREGATOR_CLASS = AWSBedrockUserContextAggregator
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
@@ -1046,14 +884,8 @@ class TestAWSBedrockAssistantContextAggregator(
BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = AWSBedrockLLMContext
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
USER_AGGREGATOR_CLASS = AWSBedrockUserContextAggregator
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
ASSISTANT_AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
OpenAILLMContextFrame,
OpenAILLMContextAssistantTimestampFrame,
]
AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
def check_message_multi_content(
self, context: OpenAILLMContext, content_index: int, index: int, content: str
@@ -1076,14 +908,7 @@ class TestGoogleUserContextAggregator(
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = GoogleLLMContext
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
USER_AGGREGATOR_CLASS = GoogleUserContextAggregator
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
ASSISTANT_AGGREGATOR_CLASS = GoogleAssistantContextAggregator
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
OpenAILLMContextFrame,
OpenAILLMContextAssistantTimestampFrame,
]
AGGREGATOR_CLASS = GoogleUserContextAggregator
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
obj = context.messages[index].to_json_dict()
@@ -1100,14 +925,8 @@ class TestGoogleAssistantContextAggregator(
BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = GoogleLLMContext
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
USER_AGGREGATOR_CLASS = GoogleUserContextAggregator
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
ASSISTANT_AGGREGATOR_CLASS = GoogleAssistantContextAggregator
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
OpenAILLMContextFrame,
OpenAILLMContextAssistantTimestampFrame,
]
AGGREGATOR_CLASS = GoogleAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
obj = context.messages[index].to_json_dict()
@@ -1133,53 +952,26 @@ class TestOpenAIUserContextAggregator(
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = OpenAILLMContext
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
USER_AGGREGATOR_CLASS = OpenAIUserContextAggregator
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
ASSISTANT_AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
OpenAILLMContextFrame,
OpenAILLMContextAssistantTimestampFrame,
]
AGGREGATOR_CLASS = OpenAIUserContextAggregator
class TestOpenAIAssistantContextAggregator(
BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = OpenAILLMContext
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
USER_AGGREGATOR_CLASS = OpenAIUserContextAggregator
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
ASSISTANT_AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
OpenAILLMContextFrame,
OpenAILLMContextAssistantTimestampFrame,
]
AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
#
# Universal
#
class TestLLMUserAggregator(BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase):
CONTEXT_CLASS = LLMContext
CONTEXT_FRAME_CLASS = LLMContextFrame
USER_AGGREGATOR_CLASS = LLMUserAggregator
USER_EXPECTED_CONTEXT_FRAMES = [LLMContextFrame]
ASSISTANT_AGGREGATOR_CLASS = LLMAssistantAggregator
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [LLMContextFrame, LLMContextAssistantTimestampFrame]
class TestLLMAssistantAggregator(
BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase
):
CONTEXT_CLASS = LLMContext
CONTEXT_FRAME_CLASS = LLMContextFrame
USER_AGGREGATOR_CLASS = LLMUserAggregator
USER_EXPECTED_CONTEXT_FRAMES = [LLMContextFrame]
ASSISTANT_AGGREGATOR_CLASS = LLMAssistantAggregator
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [LLMContextFrame, LLMContextAssistantTimestampFrame]
CONTEXT_CLASS = OpenAILLMContext
AGGREGATOR_CLASS = LLMAssistantAggregator
EXPECTED_CONTEXT_FRAMES = [LLMContextFrame, LLMContextAssistantTimestampFrame]
# Override to remove 'expect_stripped_words' parameter, which is deprecated
# for LLMAssistantAggregator