diff --git a/src/pipecat/turns/user/transcription_user_turn_start_strategy.py b/src/pipecat/turns/user/transcription_user_turn_start_strategy.py new file mode 100644 index 000000000..6a5a7c907 --- /dev/null +++ b/src/pipecat/turns/user/transcription_user_turn_start_strategy.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""User turn start strategy based on transcriptions.""" + +from pipecat.frames.frames import BotStartedSpeakingFrame, Frame, TranscriptionFrame +from pipecat.turns.user.base_user_turn_start_strategy import BaseUserTurnStartStrategy + + +class TranscriptionUserTurnStartStrategy(BaseUserTurnStartStrategy): + """User turn start strategy based on transcriptions. + + This strategy signals the start of a user turn when a transcription is + received while the bot is speaking. It is useful as a fallback in scenarios + where VAD-based detection fails (for example, when the user speaks very + softly) but the STT service still produces transcriptions. + + """ + + def __init__(self): + """Initialize the base interruption strategy.""" + super().__init__() + self._bot_speaking = False + + async def reset(self): + """Reset the interruption strategy.""" + await super().reset() + self._bot_speaking = False + + async def process_frame(self, frame: Frame): + """Process an incoming frame to detect the start of a user turn. + + Args: + frame: The frame to be processed. + """ + await super().process_frame(frame) + + if isinstance(frame, BotStartedSpeakingFrame): + await self._handle_bot_started_speaking(frame) + elif isinstance(frame, TranscriptionFrame): + await self._handle_transcription(frame) + + async def _handle_bot_started_speaking(self, _: BotStartedSpeakingFrame): + self._bot_speaking = True + + async def _handle_transcription(self, _: TranscriptionFrame): + if self._bot_speaking: + await self.trigger_user_turn_started() diff --git a/tests/test_user_turn_start_strategy.py b/tests/test_user_turn_start_strategy.py index 144dd15cc..9402c1793 100644 --- a/tests/test_user_turn_start_strategy.py +++ b/tests/test_user_turn_start_strategy.py @@ -7,12 +7,16 @@ import unittest from pipecat.frames.frames import ( + BotStartedSpeakingFrame, InterimTranscriptionFrame, TranscriptionFrame, VADUserStartedSpeakingFrame, VADUserStoppedSpeakingFrame, ) from pipecat.turns.user.min_words_user_turn_start_strategy import MinWordsUserTurnStartStrategy +from pipecat.turns.user.transcription_user_turn_start_strategy import ( + TranscriptionUserTurnStartStrategy, +) from pipecat.turns.user.vad_user_turn_start_strategy import VADUserTurnStartStrategy @@ -104,3 +108,24 @@ class TestVADUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase): await strategy.process_frame(VADUserStartedSpeakingFrame()) self.assertTrue(should_start) + + +class TestTranscriptionUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase): + async def test_transcription_strategy(self): + strategy = TranscriptionUserTurnStartStrategy() + + should_start = None + + @strategy.event_handler("on_user_turn_started") + async def on_user_turn_started(strategy): + nonlocal should_start + should_start = True + + await strategy.process_frame(TranscriptionFrame(text="Hello!", user_id="", timestamp="now")) + self.assertFalse(should_start) + + await strategy.process_frame(BotStartedSpeakingFrame()) + self.assertFalse(should_start) + + await strategy.process_frame(TranscriptionFrame(text="Hello!", user_id="", timestamp="now")) + self.assertTrue(should_start)