Revert "fix interruption task frame context ordering"
This commit is contained in:
committed by
GitHub
parent
11b101e8a6
commit
d844829538
12
CHANGELOG.md
12
CHANGELOG.md
@@ -9,12 +9,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added an asyncio event `finished_event` field to `InterruptionFrame`. When
|
||||
assigned, the asyncio event will be set when the frame reaches the end of the
|
||||
pipeline. You can use this field to know when an interruption made it all the
|
||||
way to the end of a pipeline. The field has been also added to
|
||||
`InterruptionTaskFrame`.
|
||||
|
||||
- Added support for loading external observers. You can now register custom
|
||||
pipeline observers by setting the `PIPECAT_OBSERVER_FILES` environment
|
||||
variable. This variable should contain a colon-separated list of Python files
|
||||
@@ -34,7 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- `CancelFrame` and `CancelTaskFrame` have an optional `reason` field to
|
||||
indicate why the pipeline is being canceled. This can be also specified when
|
||||
you cancel a task with `PipelineTask.cancel(reason="cancellation reason")`.
|
||||
you cancel a task with `PipelineTask.cancel(reason="cancellation your
|
||||
reason")`.
|
||||
|
||||
- Added `include_prob_metrics` parameter to Whisper STT services to enable access
|
||||
to probability metrics from transcription results.
|
||||
@@ -60,9 +55,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue that would cause wrong user/assistant context ordering when
|
||||
using interruption strategies.
|
||||
|
||||
- Fixed an issue where the `SmallWebRTCRequest` dataclass in runner would scrub
|
||||
arbitrary request data from client due to camelCase typing. This fixes data
|
||||
passthrough for JS clients where `APIRequest` is used.
|
||||
|
||||
@@ -11,7 +11,6 @@ including data frames, system frames, and control frames for audio, video, text,
|
||||
and LLM processing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -860,13 +859,9 @@ class InterruptionFrame(SystemFrame):
|
||||
speaking (i.e. is interrupting). This is similar to
|
||||
UserStartedSpeakingFrame except that it should be pushed concurrently
|
||||
with other frames (so the order is not guaranteed).
|
||||
|
||||
Parameters:
|
||||
finished_event: If not None, the event will be set when the frame
|
||||
reaches the end of the pipeline.
|
||||
"""
|
||||
|
||||
finished_event: Optional[asyncio.Event] = None
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1427,13 +1422,9 @@ class InterruptionTaskFrame(TaskFrame):
|
||||
same actions as if the user interrupted except that the
|
||||
UserStartedSpeakingFrame and UserStoppedSpeakingFrame won't be generated.
|
||||
This frame should be pushed upstream.
|
||||
|
||||
Parameters:
|
||||
finished_event: If not None, the event will be set when the generated
|
||||
InterruptionFrame reaches the end of the pipeline.
|
||||
"""
|
||||
|
||||
finished_event: Optional[asyncio.Event] = None
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -745,7 +745,7 @@ class PipelineTask(BasePipelineTask):
|
||||
# pipeline. This is in case the push task is blocked waiting for a
|
||||
# pipeline-ending frame to finish traversing the pipeline.
|
||||
logger.debug(f"{self}: received interruption task frame {frame}")
|
||||
await self._pipeline.queue_frame(InterruptionFrame(finished_event=frame.finished_event))
|
||||
await self._pipeline.queue_frame(InterruptionFrame())
|
||||
elif isinstance(frame, ErrorFrame):
|
||||
await self._call_event_handler("on_pipeline_error", frame)
|
||||
if frame.fatal:
|
||||
@@ -786,10 +786,6 @@ class PipelineTask(BasePipelineTask):
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, HeartbeatFrame):
|
||||
await self._heartbeat_queue.put(frame)
|
||||
elif isinstance(frame, InterruptionFrame) and frame.finished_event:
|
||||
# This should unblock any code waiting for the interruption to
|
||||
# complete.
|
||||
frame.finished_event.set()
|
||||
|
||||
async def _heartbeat_push_handler(self):
|
||||
"""Push heartbeat frames at regular intervals."""
|
||||
|
||||
@@ -93,7 +93,7 @@ class FrameProcessorQueue(asyncio.PriorityQueue):
|
||||
self.__high_counter = 0
|
||||
self.__low_counter = 0
|
||||
|
||||
async def put(self, item: Tuple[Frame, FrameDirection, Optional[FrameCallback]]):
|
||||
async def put(self, item: Tuple[Frame, FrameDirection, FrameCallback]):
|
||||
"""Put an item into the priority queue.
|
||||
|
||||
System frames (`SystemFrame`) have higher priority than any other
|
||||
@@ -228,9 +228,11 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
# To interrupt a pipeline, we push an `InterruptionTaskFrame` upstream.
|
||||
# Then we wait for the corresponding `InterruptionFrame` to travel from
|
||||
# start to end of the pipeline. When it reaches the end we will be
|
||||
# notified through the assigned event.
|
||||
# the start of the pipeline back to the processor that sent the
|
||||
# `InterruptionTaskFrame`. This wait is handled using the following
|
||||
# event.
|
||||
self._wait_for_interruption = False
|
||||
self._wait_interruption_event = asyncio.Event()
|
||||
|
||||
# Frame processor events.
|
||||
self._register_event_handler("on_before_process_frame", sync=True)
|
||||
@@ -565,6 +567,10 @@ class FrameProcessor(BaseObject):
|
||||
if self._cancelling:
|
||||
return
|
||||
|
||||
# If we are waiting for an interruption we will bypass all queued system
|
||||
# frames and we will process the frame right away. This is because a
|
||||
# previous system frame might be waiting for the interruption frame and
|
||||
# it's blocking the input task.
|
||||
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
return
|
||||
@@ -655,27 +661,31 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
await self._call_event_handler("on_after_push_frame", frame)
|
||||
|
||||
# If we are waiting for an interruption and we get an interruption, then
|
||||
# we can unblock `push_interruption_task_frame_and_wait()`.
|
||||
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
|
||||
self._wait_interruption_event.set()
|
||||
|
||||
async def push_interruption_task_frame_and_wait(self):
|
||||
"""Interrupt the pipeline and wait for the interruption to complete.
|
||||
|
||||
This function sends an `InterruptionTaskFrame` upstream with an
|
||||
associated asyncio event. It then waits for the generated
|
||||
`InterruptionFrame` to reach the end of the pipeline where the event
|
||||
will be set.
|
||||
"""Push an interruption task frame upstream and wait for the interruption.
|
||||
|
||||
This function sends an `InterruptionTaskFrame` upstream to the pipeline
|
||||
task and waits to receive the corresponding `InterruptionFrame`. When
|
||||
the function finishes it is guaranteed that the `InterruptionFrame` has
|
||||
been pushed downstream.
|
||||
"""
|
||||
self._wait_for_interruption = True
|
||||
|
||||
finished_event = asyncio.Event()
|
||||
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
await self.push_frame(
|
||||
InterruptionTaskFrame(finished_event=finished_event), FrameDirection.UPSTREAM
|
||||
)
|
||||
# Wait for an `InterruptionFrame` to come to this processor and be
|
||||
# pushed. Take a look at `push_frame()` to see how we first push the
|
||||
# `InterruptionFrame` and then we set the event in order to maintain
|
||||
# frame ordering.
|
||||
await self._wait_interruption_event.wait()
|
||||
|
||||
# Wait for the event to be set. This event is set when the
|
||||
# `InterruptionFrame` pushed by the pipeline task reaches the end of the
|
||||
# pipeline.
|
||||
await finished_event.wait()
|
||||
# Clean the event.
|
||||
self._wait_interruption_event.clear()
|
||||
|
||||
self._wait_for_interruption = False
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import unittest
|
||||
from typing import Any, Optional
|
||||
@@ -31,23 +30,17 @@ from pipecat.frames.frames import (
|
||||
SpeechControlParamsFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
TTSTextFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineParams
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregator,
|
||||
LLMUserAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMAssistantAggregator
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
@@ -80,11 +73,8 @@ AGGREGATION_SLEEP = 0.15
|
||||
|
||||
class BaseTestUserContextAggregator:
|
||||
CONTEXT_CLASS = None # To be set in subclasses
|
||||
CONTEXT_FRAME_CLASS = None # To be set in subclasses
|
||||
USER_AGGREGATOR_CLASS = None # To be set in subclasses
|
||||
USER_EXPECTED_CONTEXT_FRAMES = None
|
||||
ASSISTANT_AGGREGATOR_CLASS = None # To be set in subclasses
|
||||
ASSISTANT_EXPECTED_CONTEXT_FRAMES = None # To be set in subclasses
|
||||
AGGREGATOR_CLASS = None # To be set in subclasses
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
|
||||
|
||||
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
|
||||
assert context.messages[index]["content"] == content
|
||||
@@ -96,12 +86,10 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_se(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(context)
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [UserStartedSpeakingFrame(), UserStoppedSpeakingFrame()]
|
||||
expected_down_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame]
|
||||
await run_test(
|
||||
@@ -112,12 +100,10 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_ste(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(context)
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""),
|
||||
@@ -126,7 +112,7 @@ class BaseTestUserContextAggregator:
|
||||
]
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
UserStoppedSpeakingFrame,
|
||||
]
|
||||
await run_test(
|
||||
@@ -138,12 +124,10 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_site(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(context)
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
InterimTranscriptionFrame(text="Hello", user_id="cat", timestamp=""),
|
||||
@@ -153,7 +137,7 @@ class BaseTestUserContextAggregator:
|
||||
]
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
UserStoppedSpeakingFrame,
|
||||
]
|
||||
await run_test(
|
||||
@@ -165,12 +149,10 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_st1iest2e(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(context)
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
|
||||
@@ -186,7 +168,7 @@ class BaseTestUserContextAggregator:
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
UserStoppedSpeakingFrame,
|
||||
]
|
||||
await run_test(
|
||||
@@ -198,12 +180,10 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_siet(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
@@ -217,7 +197,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -228,12 +208,10 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_sieit(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
@@ -248,7 +226,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -259,12 +237,10 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_set(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
@@ -276,7 +252,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -287,12 +263,10 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_seit(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
@@ -305,7 +279,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -316,12 +290,10 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_st1et2(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
@@ -336,9 +308,9 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
SpeechControlParamsFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -350,12 +322,10 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_set1t2(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
@@ -368,7 +338,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -379,12 +349,10 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_siet1it2(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
@@ -400,7 +368,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -411,12 +379,10 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_t(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context
|
||||
) # No aggregation timeout; this tests VAD emulation
|
||||
|
||||
@@ -427,7 +393,7 @@ class BaseTestUserContextAggregator:
|
||||
]
|
||||
expected_down_frames = [
|
||||
SpeechControlParamsFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
|
||||
|
||||
@@ -441,12 +407,10 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_t_with_turn_analyzer(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(turn_emulated_vad_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
|
||||
@@ -460,7 +424,7 @@ class BaseTestUserContextAggregator:
|
||||
]
|
||||
expected_down_frames = [
|
||||
SpeechControlParamsFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
|
||||
|
||||
@@ -474,12 +438,10 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_it(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context
|
||||
) # No aggregation timeout; this tests VAD emulation
|
||||
frames_to_send = [
|
||||
@@ -489,7 +451,7 @@ class BaseTestUserContextAggregator:
|
||||
TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""),
|
||||
SleepFrame(sleep=AGGREGATION_SLEEP),
|
||||
]
|
||||
expected_down_frames = [SpeechControlParamsFrame, *self.USER_EXPECTED_CONTEXT_FRAMES]
|
||||
expected_down_frames = [SpeechControlParamsFrame, *self.EXPECTED_CONTEXT_FRAMES]
|
||||
expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -501,12 +463,10 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_sie_delay_it(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT)
|
||||
)
|
||||
frames_to_send = [
|
||||
@@ -522,7 +482,7 @@ class BaseTestUserContextAggregator:
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -533,12 +493,7 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_min_words_interruption_strategy_one_word(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.CONTEXT_FRAME_CLASS is not None, "CONTEXT_FRAME_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
|
||||
CONTEXT_FRAME_CLASS = self.CONTEXT_FRAME_CLASS
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
class ContextProcessor(FrameProcessor):
|
||||
def __init__(self):
|
||||
@@ -548,13 +503,13 @@ class BaseTestUserContextAggregator:
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, CONTEXT_FRAME_CLASS):
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
self.context_received = True
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(context)
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
context_processor = ContextProcessor()
|
||||
pipeline = Pipeline([aggregator, context_processor])
|
||||
|
||||
@@ -582,12 +537,7 @@ class BaseTestUserContextAggregator:
|
||||
|
||||
async def test_min_words_interruption_strategy_two_words(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.CONTEXT_FRAME_CLASS is not None, "CONTEXT_FRAME_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
|
||||
CONTEXT_FRAME_CLASS = self.CONTEXT_FRAME_CLASS
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
class ContextProcessor(FrameProcessor):
|
||||
def __init__(self):
|
||||
@@ -597,7 +547,7 @@ class BaseTestUserContextAggregator:
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, CONTEXT_FRAME_CLASS):
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
self.context_received = True
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
self.context_received = False
|
||||
@@ -605,7 +555,7 @@ class BaseTestUserContextAggregator:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.USER_AGGREGATOR_CLASS(context)
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
context_processor = ContextProcessor()
|
||||
pipeline = Pipeline([aggregator, context_processor])
|
||||
|
||||
@@ -622,7 +572,7 @@ class BaseTestUserContextAggregator:
|
||||
UserStartedSpeakingFrame,
|
||||
InterruptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
@@ -638,77 +588,11 @@ class BaseTestUserContextAggregator:
|
||||
# interruption then we have an issue.
|
||||
assert context_processor.context_received
|
||||
|
||||
async def test_interruption_strategy_context_order(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.USER_AGGREGATOR_CLASS is not None, (
|
||||
"USER_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
|
||||
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
|
||||
class DelayedProcessor(FrameProcessor):
|
||||
"""Force a delay in interruption frames.
|
||||
|
||||
This might give time to the assistant aggregator to update the
|
||||
context before the user aggregator (which shouldn't really happen)
|
||||
and reveal any issues in context ordering.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await asyncio.sleep(0.3)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
user_aggregator = self.USER_AGGREGATOR_CLASS(
|
||||
context, params=LLMUserAggregatorParams(aggregation_timeout=1.0)
|
||||
)
|
||||
assistant_aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context)
|
||||
pipeline = Pipeline([user_aggregator, DelayedProcessor(), assistant_aggregator])
|
||||
|
||||
frames_to_send = [
|
||||
# Aggregate assistant content.
|
||||
BotStartedSpeakingFrame(),
|
||||
LLMFullResponseStartFrame(),
|
||||
TTSTextFrame(text="Hello, I'm your assistant"),
|
||||
SleepFrame(),
|
||||
# Interrupt the bot. Assistant content should be added first to the
|
||||
# context, followed by user content.
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Can you tell me", user_id="cat", timestamp=""),
|
||||
SleepFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES,
|
||||
InterruptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
*self.USER_EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
pipeline_params=PipelineParams(
|
||||
interruption_strategies=[MinWordsInterruptionStrategy(min_words=2)]
|
||||
),
|
||||
)
|
||||
self.check_message_content(context, -1, "Can you tell me")
|
||||
self.check_message_content(context, -2, "Hello, I'm your assistant")
|
||||
|
||||
|
||||
class BaseTestAssistantContextAggregator:
|
||||
CONTEXT_CLASS = None # To be set in subclasses
|
||||
USER_AGGREGATOR_CLASS = None # To be set in subclasses
|
||||
ASSISTANT_AGGREGATOR_CLASS = None # To be set in subclasses
|
||||
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [] # To be set in subclasses
|
||||
AGGREGATOR_CLASS = None # To be set in subclasses
|
||||
EXPECTED_CONTEXT_FRAMES = None # To be set in subclasses
|
||||
|
||||
def create_assistant_aggregator_params(
|
||||
self, **kwargs
|
||||
@@ -728,12 +612,10 @@ class BaseTestAssistantContextAggregator:
|
||||
|
||||
async def test_empty(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
|
||||
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context)
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()]
|
||||
expected_down_frames = []
|
||||
await run_test(
|
||||
@@ -744,18 +626,16 @@ class BaseTestAssistantContextAggregator:
|
||||
|
||||
async def test_single_text(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
|
||||
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context)
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello Pipecat!"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -765,12 +645,10 @@ class BaseTestAssistantContextAggregator:
|
||||
|
||||
async def test_multiple_text(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
|
||||
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=self.create_assistant_aggregator_params(expect_stripped_words=False)
|
||||
)
|
||||
frames_to_send = [
|
||||
@@ -781,7 +659,7 @@ class BaseTestAssistantContextAggregator:
|
||||
TextFrame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -791,12 +669,10 @@ class BaseTestAssistantContextAggregator:
|
||||
|
||||
async def test_multiple_text_stripped(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
|
||||
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context)
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame(text="Hello"),
|
||||
@@ -805,7 +681,7 @@ class BaseTestAssistantContextAggregator:
|
||||
TextFrame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -815,12 +691,10 @@ class BaseTestAssistantContextAggregator:
|
||||
|
||||
async def test_multiple_llm_responses(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
|
||||
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=self.create_assistant_aggregator_params(expect_stripped_words=False)
|
||||
)
|
||||
frames_to_send = [
|
||||
@@ -833,10 +707,7 @@ class BaseTestAssistantContextAggregator:
|
||||
TextFrame(text="you?"),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES, *self.EXPECTED_CONTEXT_FRAMES]
|
||||
await run_test(
|
||||
aggregator,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -847,12 +718,10 @@ class BaseTestAssistantContextAggregator:
|
||||
|
||||
async def test_multiple_llm_responses_interruption(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
|
||||
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(
|
||||
aggregator = self.AGGREGATOR_CLASS(
|
||||
context, params=self.create_assistant_aggregator_params(expect_stripped_words=False)
|
||||
)
|
||||
frames_to_send = [
|
||||
@@ -868,9 +737,9 @@ class BaseTestAssistantContextAggregator:
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
InterruptionFrame,
|
||||
*self.ASSISTANT_EXPECTED_CONTEXT_FRAMES,
|
||||
*self.EXPECTED_CONTEXT_FRAMES,
|
||||
]
|
||||
await run_test(
|
||||
aggregator,
|
||||
@@ -882,12 +751,10 @@ class BaseTestAssistantContextAggregator:
|
||||
|
||||
async def test_function_call(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
|
||||
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context)
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="get_weather",
|
||||
@@ -913,9 +780,7 @@ class BaseTestAssistantContextAggregator:
|
||||
|
||||
async def test_function_call_on_context_updated(self):
|
||||
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
|
||||
assert self.ASSISTANT_AGGREGATOR_CLASS is not None, (
|
||||
"ASSISTANT_AGGREGATOR_CLASS must be set in a subclass"
|
||||
)
|
||||
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
|
||||
|
||||
context_updated = False
|
||||
|
||||
@@ -924,7 +789,7 @@ class BaseTestAssistantContextAggregator:
|
||||
context_updated = True
|
||||
|
||||
context = self.CONTEXT_CLASS()
|
||||
aggregator = self.ASSISTANT_AGGREGATOR_CLASS(context)
|
||||
aggregator = self.AGGREGATOR_CLASS(context)
|
||||
frames_to_send = [
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="get_weather",
|
||||
@@ -959,14 +824,7 @@ class BaseTestAssistantContextAggregator:
|
||||
|
||||
class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase):
|
||||
CONTEXT_CLASS = OpenAILLMContext
|
||||
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
|
||||
USER_AGGREGATOR_CLASS = LLMUserContextAggregator
|
||||
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
|
||||
ASSISTANT_AGGREGATOR_CLASS = LLMAssistantContextAggregator
|
||||
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
]
|
||||
AGGREGATOR_CLASS = LLMUserContextAggregator
|
||||
|
||||
|
||||
#
|
||||
@@ -978,14 +836,7 @@ class TestAnthropicUserContextAggregator(
|
||||
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = AnthropicLLMContext
|
||||
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
|
||||
USER_AGGREGATOR_CLASS = AnthropicUserContextAggregator
|
||||
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
|
||||
ASSISTANT_AGGREGATOR_CLASS = AnthropicAssistantContextAggregator
|
||||
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
]
|
||||
AGGREGATOR_CLASS = AnthropicUserContextAggregator
|
||||
|
||||
def check_message_multi_content(
|
||||
self, context: OpenAILLMContext, content_index: int, index: int, content: str
|
||||
@@ -998,14 +849,8 @@ class TestAnthropicAssistantContextAggregator(
|
||||
BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = AnthropicLLMContext
|
||||
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
|
||||
USER_AGGREGATOR_CLASS = AnthropicUserContextAggregator
|
||||
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
|
||||
ASSISTANT_AGGREGATOR_CLASS = AnthropicAssistantContextAggregator
|
||||
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
]
|
||||
AGGREGATOR_CLASS = AnthropicAssistantContextAggregator
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
|
||||
def check_message_multi_content(
|
||||
self, context: OpenAILLMContext, content_index: int, index: int, content: str
|
||||
@@ -1026,14 +871,7 @@ class TestAWSBedrockUserContextAggregator(
|
||||
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = AWSBedrockLLMContext
|
||||
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
|
||||
USER_AGGREGATOR_CLASS = AWSBedrockUserContextAggregator
|
||||
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
|
||||
ASSISTANT_AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator
|
||||
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
]
|
||||
AGGREGATOR_CLASS = AWSBedrockUserContextAggregator
|
||||
|
||||
def check_message_multi_content(
|
||||
self, context: OpenAILLMContext, content_index: int, index: int, content: str
|
||||
@@ -1046,14 +884,8 @@ class TestAWSBedrockAssistantContextAggregator(
|
||||
BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = AWSBedrockLLMContext
|
||||
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
|
||||
USER_AGGREGATOR_CLASS = AWSBedrockUserContextAggregator
|
||||
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
|
||||
ASSISTANT_AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator
|
||||
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
]
|
||||
AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
|
||||
def check_message_multi_content(
|
||||
self, context: OpenAILLMContext, content_index: int, index: int, content: str
|
||||
@@ -1076,14 +908,7 @@ class TestGoogleUserContextAggregator(
|
||||
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = GoogleLLMContext
|
||||
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
|
||||
USER_AGGREGATOR_CLASS = GoogleUserContextAggregator
|
||||
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
|
||||
ASSISTANT_AGGREGATOR_CLASS = GoogleAssistantContextAggregator
|
||||
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
]
|
||||
AGGREGATOR_CLASS = GoogleUserContextAggregator
|
||||
|
||||
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
|
||||
obj = context.messages[index].to_json_dict()
|
||||
@@ -1100,14 +925,8 @@ class TestGoogleAssistantContextAggregator(
|
||||
BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = GoogleLLMContext
|
||||
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
|
||||
USER_AGGREGATOR_CLASS = GoogleUserContextAggregator
|
||||
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
|
||||
ASSISTANT_AGGREGATOR_CLASS = GoogleAssistantContextAggregator
|
||||
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
]
|
||||
AGGREGATOR_CLASS = GoogleAssistantContextAggregator
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
|
||||
def check_message_content(self, context: OpenAILLMContext, index: int, content: str):
|
||||
obj = context.messages[index].to_json_dict()
|
||||
@@ -1133,53 +952,26 @@ class TestOpenAIUserContextAggregator(
|
||||
BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = OpenAILLMContext
|
||||
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
|
||||
USER_AGGREGATOR_CLASS = OpenAIUserContextAggregator
|
||||
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
|
||||
ASSISTANT_AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
|
||||
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
]
|
||||
AGGREGATOR_CLASS = OpenAIUserContextAggregator
|
||||
|
||||
|
||||
class TestOpenAIAssistantContextAggregator(
|
||||
BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = OpenAILLMContext
|
||||
CONTEXT_FRAME_CLASS = OpenAILLMContextFrame
|
||||
USER_AGGREGATOR_CLASS = OpenAIUserContextAggregator
|
||||
USER_EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame]
|
||||
ASSISTANT_AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
|
||||
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
]
|
||||
AGGREGATOR_CLASS = OpenAIAssistantContextAggregator
|
||||
EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame]
|
||||
|
||||
|
||||
#
|
||||
# Universal
|
||||
#
|
||||
|
||||
|
||||
class TestLLMUserAggregator(BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase):
|
||||
CONTEXT_CLASS = LLMContext
|
||||
CONTEXT_FRAME_CLASS = LLMContextFrame
|
||||
USER_AGGREGATOR_CLASS = LLMUserAggregator
|
||||
USER_EXPECTED_CONTEXT_FRAMES = [LLMContextFrame]
|
||||
ASSISTANT_AGGREGATOR_CLASS = LLMAssistantAggregator
|
||||
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [LLMContextFrame, LLMContextAssistantTimestampFrame]
|
||||
|
||||
|
||||
class TestLLMAssistantAggregator(
|
||||
BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
CONTEXT_CLASS = LLMContext
|
||||
CONTEXT_FRAME_CLASS = LLMContextFrame
|
||||
USER_AGGREGATOR_CLASS = LLMUserAggregator
|
||||
USER_EXPECTED_CONTEXT_FRAMES = [LLMContextFrame]
|
||||
ASSISTANT_AGGREGATOR_CLASS = LLMAssistantAggregator
|
||||
ASSISTANT_EXPECTED_CONTEXT_FRAMES = [LLMContextFrame, LLMContextAssistantTimestampFrame]
|
||||
CONTEXT_CLASS = OpenAILLMContext
|
||||
AGGREGATOR_CLASS = LLMAssistantAggregator
|
||||
EXPECTED_CONTEXT_FRAMES = [LLMContextFrame, LLMContextAssistantTimestampFrame]
|
||||
|
||||
# Override to remove 'expect_stripped_words' parameter, which is deprecated
|
||||
# for LLMAssistantAggregator
|
||||
|
||||
Reference in New Issue
Block a user