diff --git a/tests/test_bot_turn_start_strategy.py b/tests/test_bot_turn_start_strategy.py index 3908cb974..8e4396fc2 100644 --- a/tests/test_bot_turn_start_strategy.py +++ b/tests/test_bot_turn_start_strategy.py @@ -10,10 +10,13 @@ import unittest from pipecat.frames.frames import ( InterimTranscriptionFrame, TranscriptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, VADUserStartedSpeakingFrame, VADUserStoppedSpeakingFrame, ) from pipecat.turns.bot import TranscriptionBotTurnStartStrategy +from pipecat.turns.bot.external_bot_turn_start_strategy import ExternalBotTurnStartStrategy from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams AGGREGATION_TIMEOUT = 0.1 @@ -472,3 +475,35 @@ class TestTranscriptionBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase): # at least the aggregation timeout. await asyncio.sleep(AGGREGATION_TIMEOUT + 0.1) self.assertTrue(should_start) + + +class TestExternalBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase): + async def test_external_strategy(self): + strategy = ExternalBotTurnStartStrategy() + + should_start = None + + @strategy.event_handler("on_bot_turn_started") + async def on_bot_turn_started(strategy, enable_user_speaking_frames): + nonlocal should_start + should_start = True + + await strategy.process_frame(VADUserStartedSpeakingFrame()) + self.assertFalse(should_start) + + await strategy.process_frame(UserStartedSpeakingFrame()) + self.assertFalse(should_start) + + await strategy.process_frame(UserStoppedSpeakingFrame()) + self.assertFalse(should_start) + + await strategy.process_frame(UserStartedSpeakingFrame()) + self.assertFalse(should_start) + + await strategy.process_frame( + TranscriptionFrame(text="How are you?", user_id="cat", timestamp="") + ) + self.assertFalse(should_start) + + await strategy.process_frame(UserStoppedSpeakingFrame()) + self.assertTrue(should_start) diff --git a/tests/test_user_turn_start_strategy.py b/tests/test_user_turn_start_strategy.py index da1c46e5a..1212e075a 100644 --- a/tests/test_user_turn_start_strategy.py +++ b/tests/test_user_turn_start_strategy.py @@ -10,10 +10,12 @@ from pipecat.frames.frames import ( BotStartedSpeakingFrame, InterimTranscriptionFrame, TranscriptionFrame, + UserStartedSpeakingFrame, VADUserStartedSpeakingFrame, VADUserStoppedSpeakingFrame, ) from pipecat.turns.user import ( + ExternalUserTurnStartStrategy, MinWordsUserTurnStartStrategy, TranscriptionUserTurnStartStrategy, VADUserTurnStartStrategy, @@ -162,3 +164,21 @@ class TestTranscriptionUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase): await strategy.process_frame(TranscriptionFrame(text="Hello!", user_id="", timestamp="now")) self.assertTrue(should_start) + + +class TestExternalUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase): + async def test_external_strategy(self): + strategy = ExternalUserTurnStartStrategy() + + should_start = None + + @strategy.event_handler("on_user_turn_started") + async def on_user_turn_started(strategy, enable_user_speaking_frames): + nonlocal should_start + should_start = True + + await strategy.process_frame(VADUserStartedSpeakingFrame()) + self.assertFalse(should_start) + + await strategy.process_frame(UserStartedSpeakingFrame()) + self.assertTrue(should_start)