diff --git a/changelog/3462.fixed.md b/changelog/3462.fixed.md new file mode 100644 index 000000000..f9ede6a53 --- /dev/null +++ b/changelog/3462.fixed.md @@ -0,0 +1 @@ +- Fixed `MinWordsUserTurnStartStrategy` to not aggregate transcriptions, preventing incorrect turn starts when words are spoken with pauses between them. diff --git a/src/pipecat/turns/user_start/min_words_user_turn_start_strategy.py b/src/pipecat/turns/user_start/min_words_user_turn_start_strategy.py index 1f156cef9..57a0ca0d8 100644 --- a/src/pipecat/turns/user_start/min_words_user_turn_start_strategy.py +++ b/src/pipecat/turns/user_start/min_words_user_turn_start_strategy.py @@ -41,12 +41,10 @@ class MinWordsUserTurnStartStrategy(BaseUserTurnStartStrategy): self._min_words = min_words self._use_interim = use_interim self._bot_speaking = False - self._text = "" async def reset(self): """Reset the strategy to its initial state.""" await super().reset() - self._text = "" self._bot_speaking = False async def process_frame(self, frame: Frame): @@ -67,7 +65,7 @@ class MinWordsUserTurnStartStrategy(BaseUserTurnStartStrategy): elif isinstance(frame, TranscriptionFrame): await self._handle_transcription(frame) elif isinstance(frame, InterimTranscriptionFrame) and self._use_interim: - await self._handle_interim_transcription(frame) + await self._handle_transcription(frame) async def _handle_bot_started_speaking(self, frame: BotStartedSpeakingFrame): """Handle bot started speaking frame. @@ -89,41 +87,21 @@ class MinWordsUserTurnStartStrategy(BaseUserTurnStartStrategy): """ self._bot_speaking = False - async def _handle_transcription(self, frame: TranscriptionFrame): + async def _handle_transcription(self, frame: TranscriptionFrame | InterimTranscriptionFrame): """Handle a completed transcription frame and check word count. Args: frame: The transcription frame to be processed. """ - self._text += frame.text - - min_words = self._min_words if self._bot_speaking else 1 - - word_count = len(self._text.split()) - should_trigger = word_count >= min_words - - logger.debug( - f"{self} should_trigger={should_trigger} num_spoken_words={word_count} " - f"min_words={min_words} bot_speaking={self._bot_speaking}" - ) - - if should_trigger: - await self.trigger_user_turn_started() - - async def _handle_interim_transcription(self, frame: InterimTranscriptionFrame): - """Handle an interim transcription frame and check word count. - - Args: - frame: The interim transcription frame to be processed. - """ min_words = self._min_words if self._bot_speaking else 1 word_count = len(frame.text.split()) should_trigger = word_count >= min_words + is_interim = isinstance(frame, InterimTranscriptionFrame) logger.debug( - f"{self} interim=True should_trigger={should_trigger} num_spoken_words={word_count} " - f"min_words={min_words} bot_speaking={self._bot_speaking}" + f"{self} should_trigger={should_trigger} num_spoken_words={word_count} " + f"min_words={min_words} bot_speaking={self._bot_speaking} interim_transcription={is_interim}" ) if should_trigger: diff --git a/tests/test_user_turn_controller.py b/tests/test_user_turn_controller.py index 1c5cfcf85..d8f642594 100644 --- a/tests/test_user_turn_controller.py +++ b/tests/test_user_turn_controller.py @@ -84,7 +84,7 @@ class TestUserTurnController(unittest.IsolatedAsyncioTestCase): self.assertEqual(should_start, 0) await controller.process_frame( - TranscriptionFrame(text=" two three!", user_id="cat", timestamp="") + TranscriptionFrame(text="One two three!", user_id="cat", timestamp="") ) self.assertEqual(should_start, 1) @@ -92,13 +92,11 @@ class TestUserTurnController(unittest.IsolatedAsyncioTestCase): await asyncio.sleep(USER_TURN_STOP_TIMEOUT + 0.1) await controller.process_frame(BotStartedSpeakingFrame()) - await controller.process_frame( - TranscriptionFrame(text="Hello", user_id="cat", timestamp="") - ) + await controller.process_frame(TranscriptionFrame(text="Hi!", user_id="cat", timestamp="")) self.assertEqual(should_start, 1) await controller.process_frame( - TranscriptionFrame(text=" there friend!", user_id="cat", timestamp="") + TranscriptionFrame(text="How are you?", user_id="cat", timestamp="") ) self.assertEqual(should_start, 2) diff --git a/tests/test_user_turn_start_strategy.py b/tests/test_user_turn_start_strategy.py index 51856ff9c..dabd4593b 100644 --- a/tests/test_user_turn_start_strategy.py +++ b/tests/test_user_turn_start_strategy.py @@ -38,7 +38,7 @@ class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase): self.assertFalse(should_start) await strategy.process_frame( - TranscriptionFrame(text=" there!", user_id="cat", timestamp="") + TranscriptionFrame(text="Hello there!", user_id="cat", timestamp="") ) self.assertTrue(should_start) @@ -55,6 +55,26 @@ class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase): ) self.assertTrue(should_start) + async def test_bot_speaking_singlw_words(self): + strategy = MinWordsUserTurnStartStrategy(min_words=3) + + should_start = None + + @strategy.event_handler("on_user_turn_started") + async def on_user_turn_started(strategy, params): + nonlocal should_start + should_start = True + + await strategy.process_frame(BotStartedSpeakingFrame()) + await strategy.process_frame(TranscriptionFrame(text="One", user_id="cat", timestamp="")) + self.assertFalse(should_start) + + await strategy.process_frame(TranscriptionFrame(text="Two", user_id="cat", timestamp="")) + self.assertFalse(should_start) + + await strategy.process_frame(TranscriptionFrame(text="Three", user_id="cat", timestamp="")) + self.assertFalse(should_start) + async def test_bot_speaking_interim_transcriptions(self): strategy = MinWordsUserTurnStartStrategy(min_words=2)