Files
pipecat/tests/test_user_turn_start_strategy.py
Aleix Conchillo Flaqué c2a0735975 MinWordsUserTurnStartStrategy: don't aggregate transcriptions
If we aggregate transcriptions we will get incorrect interruptions. For example,
if we have a strategy with min_words=3 and we say "One" and pause, then "Two"
and pause, and then "Three", this would trigger the start of the turn when it
shouldn't. We should only look at the incoming transcription text and not
aggregate it with the previous ones.
2026-01-16 11:16:06 -08:00

202 lines
6.7 KiB
Python

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
InterimTranscriptionFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
VADUserStartedSpeakingFrame,
VADUserStoppedSpeakingFrame,
)
from pipecat.turns.user_start import (
ExternalUserTurnStartStrategy,
MinWordsUserTurnStartStrategy,
TranscriptionUserTurnStartStrategy,
VADUserTurnStartStrategy,
)
class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
    """Tests for ``MinWordsUserTurnStartStrategy``.

    Each test feeds frames to the strategy through ``process_frame`` and checks
    whether the ``on_user_turn_started`` event handler has been invoked.
    """

    @staticmethod
    def _record_turn_starts(strategy):
        """Register an ``on_user_turn_started`` handler on ``strategy``.

        Returns:
            A list that is empty until the handler runs; one ``True`` is
            appended per invocation, so the list's truthiness tells whether a
            turn start has been signalled.
        """
        started = []

        @strategy.event_handler("on_user_turn_started")
        async def on_user_turn_started(strategy, params):
            started.append(True)

        return started

    async def test_bot_speaking_transcriptions(self):
        """With ``min_words=2`` and the bot speaking, one word must not start a turn.

        A single transcription containing at least two words is expected to
        start it, both before and after ``reset()``.
        """
        strategy = MinWordsUserTurnStartStrategy(min_words=2)
        started = self._record_turn_starts(strategy)

        await strategy.process_frame(BotStartedSpeakingFrame())
        await strategy.process_frame(TranscriptionFrame(text="Hello", user_id="cat", timestamp=""))
        self.assertFalse(started)

        await strategy.process_frame(
            TranscriptionFrame(text="Hello there!", user_id="cat", timestamp="")
        )
        self.assertTrue(started)

        # Reset and check again
        started.clear()
        await strategy.reset()

        await strategy.process_frame(BotStartedSpeakingFrame())
        await strategy.process_frame(TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""))
        self.assertFalse(started)

        await strategy.process_frame(
            TranscriptionFrame(text="How are you?", user_id="cat", timestamp="")
        )
        self.assertTrue(started)

    async def test_bot_speaking_single_words(self):
        """Separate one-word transcriptions must never add up to ``min_words=3``.

        Three consecutive single-word transcriptions are sent while the bot is
        speaking; the turn is expected not to start after any of them.
        """
        strategy = MinWordsUserTurnStartStrategy(min_words=3)
        started = self._record_turn_starts(strategy)

        await strategy.process_frame(BotStartedSpeakingFrame())
        await strategy.process_frame(TranscriptionFrame(text="One", user_id="cat", timestamp=""))
        self.assertFalse(started)

        await strategy.process_frame(TranscriptionFrame(text="Two", user_id="cat", timestamp=""))
        self.assertFalse(started)

        await strategy.process_frame(TranscriptionFrame(text="Three", user_id="cat", timestamp=""))
        self.assertFalse(started)

    async def test_bot_speaking_interim_transcriptions(self):
        """Interim transcriptions are held to the same two-word threshold."""
        strategy = MinWordsUserTurnStartStrategy(min_words=2)
        started = self._record_turn_starts(strategy)

        await strategy.process_frame(BotStartedSpeakingFrame())
        await strategy.process_frame(
            InterimTranscriptionFrame(text="Hello", user_id="cat", timestamp="")
        )
        self.assertFalse(started)

        # The bot-started frame is sent a second time before the longer interim text.
        await strategy.process_frame(BotStartedSpeakingFrame())
        await strategy.process_frame(
            InterimTranscriptionFrame(text="Hello there!", user_id="cat", timestamp="")
        )
        self.assertTrue(started)

    async def test_bot_speaking_all_transcriptions(self):
        """A short interim frame followed by a two-word final frame starts the turn."""
        strategy = MinWordsUserTurnStartStrategy(min_words=2)
        started = self._record_turn_starts(strategy)

        await strategy.process_frame(BotStartedSpeakingFrame())
        await strategy.process_frame(
            InterimTranscriptionFrame(text="Hello", user_id="cat", timestamp="")
        )
        self.assertFalse(started)

        await strategy.process_frame(
            TranscriptionFrame(text="Hello there!", user_id="cat", timestamp="")
        )
        self.assertTrue(started)

    async def test_bot_not_speaking_transcriptions(self):
        """Without a bot-started frame, one word starts the turn despite ``min_words=2``."""
        strategy = MinWordsUserTurnStartStrategy(min_words=2)
        started = self._record_turn_starts(strategy)

        await strategy.process_frame(TranscriptionFrame(text="Hello", user_id="cat", timestamp=""))
        self.assertTrue(started)

    async def test_bot_not_speaking_interim_transcriptions(self):
        """Without a bot-started frame, a one-word interim frame starts the turn."""
        strategy = MinWordsUserTurnStartStrategy(min_words=2)
        started = self._record_turn_starts(strategy)

        await strategy.process_frame(
            InterimTranscriptionFrame(text="Hello", user_id="cat", timestamp="")
        )
        self.assertTrue(started)
class TestVADUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
    """Checks which VAD frames make ``VADUserTurnStartStrategy`` signal a turn start."""

    async def test_vad_strategy(self):
        """A VAD stop frame must not fire the event; a VAD start frame must."""
        strategy = VADUserTurnStartStrategy()

        # Filled by the handler below; empty means "no turn start signalled yet".
        turn_starts = []

        @strategy.event_handler("on_user_turn_started")
        async def on_user_turn_started(strategy, params):
            turn_starts.append(True)

        await strategy.process_frame(VADUserStoppedSpeakingFrame())
        self.assertFalse(turn_starts)

        await strategy.process_frame(VADUserStartedSpeakingFrame())
        self.assertTrue(turn_starts)
class TestTranscriptionUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
    """Checks that ``TranscriptionUserTurnStartStrategy`` reacts to transcriptions."""

    async def test_transcription_strategy(self):
        """A VAD start frame alone must not fire the event; a transcription must."""
        strategy = TranscriptionUserTurnStartStrategy()

        # Filled by the handler below; empty means "no turn start signalled yet".
        turn_starts = []

        @strategy.event_handler("on_user_turn_started")
        async def on_user_turn_started(strategy, params):
            turn_starts.append(True)

        await strategy.process_frame(VADUserStartedSpeakingFrame())
        self.assertFalse(turn_starts)

        final_text = TranscriptionFrame(text="Hello!", user_id="", timestamp="now")
        await strategy.process_frame(final_text)
        self.assertTrue(turn_starts)
class TestExternalUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
    """Checks that ``ExternalUserTurnStartStrategy`` reacts to ``UserStartedSpeakingFrame``."""

    async def test_external_strategy(self):
        """A VAD start frame must not fire the event; ``UserStartedSpeakingFrame`` must."""
        strategy = ExternalUserTurnStartStrategy()

        # Filled by the handler below; empty means "no turn start signalled yet".
        turn_starts = []

        @strategy.event_handler("on_user_turn_started")
        async def on_user_turn_started(strategy, params):
            turn_starts.append(True)

        await strategy.process_frame(VADUserStartedSpeakingFrame())
        self.assertFalse(turn_starts)

        await strategy.process_frame(UserStartedSpeakingFrame())
        self.assertTrue(turn_starts)