Files
pipecat/tests/test_context_aggregators_universal.py
2025-12-23 15:51:06 -08:00

236 lines
6.9 KiB
Python

#
# Copyright (c) 2024-2025 Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
from pipecat.frames.frames import (
InterruptionFrame,
LLMContextFrame,
LLMMessagesAppendFrame,
LLMMessagesUpdateFrame,
LLMRunFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
VADUserStartedSpeakingFrame,
VADUserStoppedSpeakingFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
LLMUserAggregator,
LLMUserAggregatorParams,
)
from pipecat.tests.utils import SleepFrame, run_test
from pipecat.turns.bot.transcription_bot_turn_start_strategy import (
TranscriptionBotTurnStartStrategy,
)
from pipecat.turns.turn_start_strategies import TurnStartStrategies
# Passed as `user_turn_end_timeout` to LLMUserAggregatorParams in the timeout
# tests below, and used to size the SleepFrame delays around it.
# NOTE(review): presumably expressed in seconds -- confirm against SleepFrame.
USER_TURN_END_TIMEOUT = 0.2
# Passed as `timeout` to TranscriptionBotTurnStartStrategy and used as the
# final SleepFrame delay in test_user_turn_end_timeout_transcription.
TRANSCRIPTION_TIMEOUT = 0.1
class TestUserAggregator(unittest.IsolatedAsyncioTestCase):
    """Tests for LLMUserAggregator running alone in a Pipeline.

    Each test builds a single-processor pipeline, feeds it frames through
    ``run_test`` and then checks the frames pushed downstream, the resulting
    context messages, or the events that were fired.
    """

    async def test_llm_run(self):
        """An LLMRunFrame makes the aggregator push an LLMContextFrame."""
        context = LLMContext()

        pipeline = Pipeline([LLMUserAggregator(context)])

        frames_to_send = [LLMRunFrame()]
        expected_down_frames = [LLMContextFrame]
        await run_test(
            pipeline,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )

    async def test_llm_messages_append(self):
        """An LLMMessagesAppendFrame adds its messages to the context."""
        context = LLMContext()

        pipeline = Pipeline([LLMUserAggregator(context)])

        frames_to_send = [
            LLMMessagesAppendFrame(
                messages=[
                    {
                        "role": "user",
                        "content": "Hi there!",
                    }
                ]
            )
        ]
        await run_test(
            pipeline,
            frames_to_send=frames_to_send,
        )
        self.assertEqual(context.messages[0]["content"], "Hi there!")

    async def test_llm_messages_append_run(self):
        """Appending with run_llm=True also pushes an LLMContextFrame."""
        context = LLMContext()

        pipeline = Pipeline([LLMUserAggregator(context)])

        frames_to_send = [
            LLMMessagesAppendFrame(
                messages=[
                    {
                        "role": "user",
                        "content": "Hi there!",
                    }
                ],
                run_llm=True,
            )
        ]
        expected_down_frames = [LLMContextFrame]
        await run_test(
            pipeline,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )
        self.assertEqual(context.messages[0]["content"], "Hi there!")

    async def test_llm_messages_update(self):
        """An LLMMessagesUpdateFrame stores its messages in the context."""
        context = LLMContext()

        pipeline = Pipeline([LLMUserAggregator(context)])

        frames_to_send = [
            LLMMessagesUpdateFrame(
                messages=[
                    {
                        "role": "user",
                        "content": "Hi there!",
                    }
                ]
            )
        ]
        await run_test(
            pipeline,
            frames_to_send=frames_to_send,
        )
        self.assertEqual(context.messages[0]["content"], "Hi there!")

    async def test_llm_messages_update_run(self):
        """Updating with run_llm=True also pushes an LLMContextFrame."""
        context = LLMContext()

        pipeline = Pipeline([LLMUserAggregator(context)])

        frames_to_send = [
            LLMMessagesUpdateFrame(
                messages=[
                    {
                        "role": "user",
                        "content": "Hi there!",
                    }
                ],
                run_llm=True,
            )
        ]
        # Without this expectation the test would be indistinguishable from
        # test_llm_messages_update: the run_llm=True path must be observed
        # downstream, mirroring test_llm_messages_append_run.
        expected_down_frames = [LLMContextFrame]
        await run_test(
            pipeline,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )
        self.assertEqual(context.messages[0]["content"], "Hi there!")

    async def test_default_turn_start_strategies(self):
        """Check the downstream frame order for a simple VAD + transcription turn.

        No strategies are configured explicitly, so this exercises whatever
        the aggregator uses by default.
        """
        context = LLMContext()
        user_aggregator = LLMUserAggregator(context)

        pipeline = Pipeline([user_aggregator])

        frames_to_send = [
            VADUserStartedSpeakingFrame(),
            TranscriptionFrame(text="Hello!", user_id="", timestamp="now"),
            SleepFrame(),
            VADUserStoppedSpeakingFrame(),
        ]
        expected_down_frames = [
            VADUserStartedSpeakingFrame,
            UserStartedSpeakingFrame,
            InterruptionFrame,
            VADUserStoppedSpeakingFrame,
            UserStoppedSpeakingFrame,
            LLMContextFrame,
        ]
        await run_test(
            pipeline,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )

    async def test_user_turn_end_timeout_no_transcription(self):
        """The turn-end timeout event fires when no transcription arrives."""
        context = LLMContext()
        user_aggregator = LLMUserAggregator(
            context,
            params=LLMUserAggregatorParams(user_turn_end_timeout=USER_TURN_END_TIMEOUT),
        )

        timeout = False

        @user_aggregator.event_handler("on_user_turn_end_timeout")
        async def on_user_turn_end_timeout(aggregator):
            nonlocal timeout
            timeout = True

        pipeline = Pipeline([user_aggregator])

        frames_to_send = [
            VADUserStartedSpeakingFrame(),
            VADUserStoppedSpeakingFrame(),
            # Wait slightly longer than the configured timeout so the event
            # has a chance to fire before the pipeline finishes.
            SleepFrame(sleep=USER_TURN_END_TIMEOUT + 0.1),
        ]
        await run_test(
            pipeline,
            frames_to_send=frames_to_send,
        )

        self.assertTrue(timeout)

    async def test_user_turn_end_timeout_transcription(self):
        """A transcription starts the bot turn before the turn-end timeout fires."""
        context = LLMContext()
        user_aggregator = LLMUserAggregator(
            context,
            params=LLMUserAggregatorParams(user_turn_end_timeout=USER_TURN_END_TIMEOUT),
        )

        timeout = False
        bot_turn = False

        @user_aggregator.event_handler("on_bot_turn_started")
        async def on_bot_turn_started(aggregator, strategy):
            nonlocal bot_turn
            bot_turn = True

        @user_aggregator.event_handler("on_user_turn_end_timeout")
        async def on_user_turn_end_timeout(aggregator):
            nonlocal timeout
            timeout = True

        pipeline = Pipeline([user_aggregator])

        frames_to_send = [
            VADUserStartedSpeakingFrame(),
            VADUserStoppedSpeakingFrame(),
            # Deliver the transcription just before the timeout would elapse.
            SleepFrame(sleep=USER_TURN_END_TIMEOUT - 0.1),
            TranscriptionFrame(text="Hello!", user_id="", timestamp="now"),
            SleepFrame(sleep=USER_TURN_END_TIMEOUT - 0.1),
            SleepFrame(sleep=TRANSCRIPTION_TIMEOUT),
        ]
        await run_test(
            pipeline,
            frames_to_send=frames_to_send,
            pipeline_params=PipelineParams(
                turn_start_strategies=TurnStartStrategies(
                    bot=[TranscriptionBotTurnStartStrategy(timeout=TRANSCRIPTION_TIMEOUT)],
                )
            ),
        )

        # The transcription strategy should kick-in before the user turn end timeout.
        self.assertTrue(bot_turn)
        self.assertFalse(timeout)