Compare commits
1 Commit
pk/optiona
...
aleix/dont
| Author | SHA1 | Date | |
|---|---|---|---|
| | 61a7922bea | | |
@@ -12,7 +12,6 @@ from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.turns.user_start.base_user_turn_start_strategy import BaseUserTurnStartStrategy
|
||||
@@ -23,23 +22,19 @@ class MinWordsUserTurnStartStrategy(BaseUserTurnStartStrategy):
|
||||
|
||||
This strategy signals the start of a user turn once the user has spoken at
|
||||
least a specified number of words, as determined from transcription frames.
|
||||
Optionally, interim transcriptions can be used for earlier detection.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *, min_words: int, use_interim: bool = True, **kwargs):
|
||||
def __init__(self, *, min_words: int, **kwargs):
|
||||
"""Initialize the minimum words bot turn start strategy.
|
||||
|
||||
Args:
|
||||
min_words: Minimum number of spoken words required to trigger the
|
||||
start of a user turn.
|
||||
use_interim: Whether to consider interim transcription frames for
|
||||
earlier detection.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._min_words = min_words
|
||||
self._use_interim = use_interim
|
||||
self._bot_speaking = False
|
||||
self._text = ""
|
||||
|
||||
@@ -66,8 +61,6 @@ class MinWordsUserTurnStartStrategy(BaseUserTurnStartStrategy):
|
||||
await self._handle_bot_stopped_speaking(frame)
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
await self._handle_transcription(frame)
|
||||
elif isinstance(frame, InterimTranscriptionFrame) and self._use_interim:
|
||||
await self._handle_interim_transcription(frame)
|
||||
|
||||
async def _handle_bot_started_speaking(self, frame: BotStartedSpeakingFrame):
|
||||
"""Handle bot started speaking frame.
|
||||
@@ -95,7 +88,7 @@ class MinWordsUserTurnStartStrategy(BaseUserTurnStartStrategy):
|
||||
Args:
|
||||
frame: The transcription frame to be processed.
|
||||
"""
|
||||
self._text += frame.text
|
||||
self._text += f" {frame.text}"
|
||||
|
||||
min_words = self._min_words if self._bot_speaking else 1
|
||||
|
||||
@@ -109,22 +102,3 @@ class MinWordsUserTurnStartStrategy(BaseUserTurnStartStrategy):
|
||||
|
||||
if should_trigger:
|
||||
await self.trigger_user_turn_started()
|
||||
|
||||
async def _handle_interim_transcription(self, frame: InterimTranscriptionFrame):
|
||||
"""Handle an interim transcription frame and check word count.
|
||||
|
||||
Args:
|
||||
frame: The interim transcription frame to be processed.
|
||||
"""
|
||||
min_words = self._min_words if self._bot_speaking else 1
|
||||
|
||||
word_count = len(frame.text.split())
|
||||
should_trigger = word_count >= min_words
|
||||
|
||||
logger.debug(
|
||||
f"{self} interim=True should_trigger={should_trigger} num_spoken_words={word_count} "
|
||||
f"min_words={min_words} bot_speaking={self._bot_speaking}"
|
||||
)
|
||||
|
||||
if should_trigger:
|
||||
await self.trigger_user_turn_started()
|
||||
|
||||
@@ -24,7 +24,7 @@ from pipecat.turns.user_start import (
|
||||
|
||||
class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_bot_speaking_transcriptions(self):
|
||||
strategy = MinWordsUserTurnStartStrategy(min_words=2)
|
||||
strategy = MinWordsUserTurnStartStrategy(min_words=3)
|
||||
|
||||
should_start = None
|
||||
|
||||
@@ -38,7 +38,7 @@ class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertFalse(should_start)
|
||||
|
||||
await strategy.process_frame(
|
||||
TranscriptionFrame(text=" there!", user_id="cat", timestamp="")
|
||||
TranscriptionFrame(text=" there friend!", user_id="cat", timestamp="")
|
||||
)
|
||||
self.assertTrue(should_start)
|
||||
|
||||
@@ -50,54 +50,14 @@ class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
await strategy.process_frame(TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""))
|
||||
self.assertFalse(should_start)
|
||||
|
||||
await strategy.process_frame(TranscriptionFrame(text="How", user_id="cat", timestamp=""))
|
||||
self.assertFalse(should_start)
|
||||
|
||||
await strategy.process_frame(
|
||||
TranscriptionFrame(text="How are you?", user_id="cat", timestamp="")
|
||||
)
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_bot_speaking_interim_transcriptions(self):
|
||||
strategy = MinWordsUserTurnStartStrategy(min_words=2)
|
||||
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(strategy, params):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
await strategy.process_frame(BotStartedSpeakingFrame())
|
||||
await strategy.process_frame(
|
||||
InterimTranscriptionFrame(text="Hello", user_id="cat", timestamp="")
|
||||
)
|
||||
self.assertFalse(should_start)
|
||||
|
||||
await strategy.process_frame(BotStartedSpeakingFrame())
|
||||
await strategy.process_frame(
|
||||
InterimTranscriptionFrame(text="Hello there!", user_id="cat", timestamp="")
|
||||
)
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_bot_speaking_all_transcriptions(self):
|
||||
strategy = MinWordsUserTurnStartStrategy(min_words=2)
|
||||
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(strategy, params):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
await strategy.process_frame(BotStartedSpeakingFrame())
|
||||
await strategy.process_frame(
|
||||
InterimTranscriptionFrame(text="Hello", user_id="cat", timestamp="")
|
||||
)
|
||||
self.assertFalse(should_start)
|
||||
|
||||
await strategy.process_frame(
|
||||
TranscriptionFrame(text="Hello there!", user_id="cat", timestamp="")
|
||||
)
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_bot_not_speaking_transcriptions(self):
|
||||
strategy = MinWordsUserTurnStartStrategy(min_words=2)
|
||||
|
||||
@@ -111,21 +71,6 @@ class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
await strategy.process_frame(TranscriptionFrame(text="Hello", user_id="cat", timestamp=""))
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_bot_not_speaking_interim_transcriptions(self):
|
||||
strategy = MinWordsUserTurnStartStrategy(min_words=2)
|
||||
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(strategy, params):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
await strategy.process_frame(
|
||||
InterimTranscriptionFrame(text="Hello", user_id="cat", timestamp="")
|
||||
)
|
||||
self.assertTrue(should_start)
|
||||
|
||||
|
||||
class TestVADUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_vad_strategy(self):
|
||||
|
||||
Reference in New Issue
Block a user