turns: add TranscriptionUserTurnStartStrategy
This commit is contained in:
@@ -0,0 +1,51 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""User turn start strategy based on transcriptions."""
|
||||
|
||||
from pipecat.frames.frames import BotStartedSpeakingFrame, Frame, TranscriptionFrame
|
||||
from pipecat.turns.user.base_user_turn_start_strategy import BaseUserTurnStartStrategy
|
||||
|
||||
|
||||
class TranscriptionUserTurnStartStrategy(BaseUserTurnStartStrategy):
|
||||
"""User turn start strategy based on transcriptions.
|
||||
|
||||
This strategy signals the start of a user turn when a transcription is
|
||||
received while the bot is speaking. It is useful as a fallback in scenarios
|
||||
where VAD-based detection fails (for example, when the user speaks very
|
||||
softly) but the STT service still produces transcriptions.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the base interruption strategy."""
|
||||
super().__init__()
|
||||
self._bot_speaking = False
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the interruption strategy."""
|
||||
await super().reset()
|
||||
self._bot_speaking = False
|
||||
|
||||
async def process_frame(self, frame: Frame):
|
||||
"""Process an incoming frame to detect the start of a user turn.
|
||||
|
||||
Args:
|
||||
frame: The frame to be processed.
|
||||
"""
|
||||
await super().process_frame(frame)
|
||||
|
||||
if isinstance(frame, BotStartedSpeakingFrame):
|
||||
await self._handle_bot_started_speaking(frame)
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
await self._handle_transcription(frame)
|
||||
|
||||
async def _handle_bot_started_speaking(self, _: BotStartedSpeakingFrame):
|
||||
self._bot_speaking = True
|
||||
|
||||
async def _handle_transcription(self, _: TranscriptionFrame):
|
||||
if self._bot_speaking:
|
||||
await self.trigger_user_turn_started()
|
||||
@@ -7,12 +7,16 @@
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
InterimTranscriptionFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.turns.user.min_words_user_turn_start_strategy import MinWordsUserTurnStartStrategy
|
||||
from pipecat.turns.user.transcription_user_turn_start_strategy import (
|
||||
TranscriptionUserTurnStartStrategy,
|
||||
)
|
||||
from pipecat.turns.user.vad_user_turn_start_strategy import VADUserTurnStartStrategy
|
||||
|
||||
|
||||
@@ -104,3 +108,24 @@ class TestVADUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
await strategy.process_frame(VADUserStartedSpeakingFrame())
|
||||
self.assertTrue(should_start)
|
||||
|
||||
|
||||
class TestTranscriptionUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_transcription_strategy(self):
|
||||
strategy = TranscriptionUserTurnStartStrategy()
|
||||
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(strategy):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
await strategy.process_frame(TranscriptionFrame(text="Hello!", user_id="", timestamp="now"))
|
||||
self.assertFalse(should_start)
|
||||
|
||||
await strategy.process_frame(BotStartedSpeakingFrame())
|
||||
self.assertFalse(should_start)
|
||||
|
||||
await strategy.process_frame(TranscriptionFrame(text="Hello!", user_id="", timestamp="now"))
|
||||
self.assertTrue(should_start)
|
||||
|
||||
Reference in New Issue
Block a user