MinWordsUserTurnStartStrategy: don't aggregate transcriptions
If we aggregate transcriptions, we will get incorrect interruptions. For example, if we have a strategy with min_words=3 and we say "One" and pause, then "Two" and pause, and then "Three", the aggregated text would trigger the start of the turn when it shouldn't. We should only look at the incoming transcription text and not aggregate it with the previous ones.
This commit is contained in:
1
changelog/3462.fixed.md
Normal file
1
changelog/3462.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed `MinWordsUserTurnStartStrategy` to not aggregate transcriptions, preventing incorrect turn starts when words are spoken with pauses between them.
|
||||
@@ -41,12 +41,10 @@ class MinWordsUserTurnStartStrategy(BaseUserTurnStartStrategy):
|
||||
self._min_words = min_words
|
||||
self._use_interim = use_interim
|
||||
self._bot_speaking = False
|
||||
self._text = ""
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the strategy to its initial state."""
|
||||
await super().reset()
|
||||
self._text = ""
|
||||
self._bot_speaking = False
|
||||
|
||||
async def process_frame(self, frame: Frame):
|
||||
@@ -67,7 +65,7 @@ class MinWordsUserTurnStartStrategy(BaseUserTurnStartStrategy):
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
await self._handle_transcription(frame)
|
||||
elif isinstance(frame, InterimTranscriptionFrame) and self._use_interim:
|
||||
await self._handle_interim_transcription(frame)
|
||||
await self._handle_transcription(frame)
|
||||
|
||||
async def _handle_bot_started_speaking(self, frame: BotStartedSpeakingFrame):
|
||||
"""Handle bot started speaking frame.
|
||||
@@ -89,41 +87,21 @@ class MinWordsUserTurnStartStrategy(BaseUserTurnStartStrategy):
|
||||
"""
|
||||
self._bot_speaking = False
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame | InterimTranscriptionFrame):
|
||||
"""Handle a completed transcription frame and check word count.
|
||||
|
||||
Args:
|
||||
frame: The transcription frame to be processed.
|
||||
"""
|
||||
self._text += frame.text
|
||||
|
||||
min_words = self._min_words if self._bot_speaking else 1
|
||||
|
||||
word_count = len(self._text.split())
|
||||
should_trigger = word_count >= min_words
|
||||
|
||||
logger.debug(
|
||||
f"{self} should_trigger={should_trigger} num_spoken_words={word_count} "
|
||||
f"min_words={min_words} bot_speaking={self._bot_speaking}"
|
||||
)
|
||||
|
||||
if should_trigger:
|
||||
await self.trigger_user_turn_started()
|
||||
|
||||
async def _handle_interim_transcription(self, frame: InterimTranscriptionFrame):
|
||||
"""Handle an interim transcription frame and check word count.
|
||||
|
||||
Args:
|
||||
frame: The interim transcription frame to be processed.
|
||||
"""
|
||||
min_words = self._min_words if self._bot_speaking else 1
|
||||
|
||||
word_count = len(frame.text.split())
|
||||
should_trigger = word_count >= min_words
|
||||
is_interim = isinstance(frame, InterimTranscriptionFrame)
|
||||
|
||||
logger.debug(
|
||||
f"{self} interim=True should_trigger={should_trigger} num_spoken_words={word_count} "
|
||||
f"min_words={min_words} bot_speaking={self._bot_speaking}"
|
||||
f"{self} should_trigger={should_trigger} num_spoken_words={word_count} "
|
||||
f"min_words={min_words} bot_speaking={self._bot_speaking} interim_transcription={is_interim}"
|
||||
)
|
||||
|
||||
if should_trigger:
|
||||
|
||||
@@ -84,7 +84,7 @@ class TestUserTurnController(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(should_start, 0)
|
||||
|
||||
await controller.process_frame(
|
||||
TranscriptionFrame(text=" two three!", user_id="cat", timestamp="")
|
||||
TranscriptionFrame(text="One two three!", user_id="cat", timestamp="")
|
||||
)
|
||||
self.assertEqual(should_start, 1)
|
||||
|
||||
@@ -92,13 +92,11 @@ class TestUserTurnController(unittest.IsolatedAsyncioTestCase):
|
||||
await asyncio.sleep(USER_TURN_STOP_TIMEOUT + 0.1)
|
||||
|
||||
await controller.process_frame(BotStartedSpeakingFrame())
|
||||
await controller.process_frame(
|
||||
TranscriptionFrame(text="Hello", user_id="cat", timestamp="")
|
||||
)
|
||||
await controller.process_frame(TranscriptionFrame(text="Hi!", user_id="cat", timestamp=""))
|
||||
self.assertEqual(should_start, 1)
|
||||
|
||||
await controller.process_frame(
|
||||
TranscriptionFrame(text=" there friend!", user_id="cat", timestamp="")
|
||||
TranscriptionFrame(text="How are you?", user_id="cat", timestamp="")
|
||||
)
|
||||
self.assertEqual(should_start, 2)
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertFalse(should_start)
|
||||
|
||||
await strategy.process_frame(
|
||||
TranscriptionFrame(text=" there!", user_id="cat", timestamp="")
|
||||
TranscriptionFrame(text="Hello there!", user_id="cat", timestamp="")
|
||||
)
|
||||
self.assertTrue(should_start)
|
||||
|
||||
@@ -55,6 +55,26 @@ class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_bot_speaking_singlw_words(self):
|
||||
strategy = MinWordsUserTurnStartStrategy(min_words=3)
|
||||
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(strategy, params):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
await strategy.process_frame(BotStartedSpeakingFrame())
|
||||
await strategy.process_frame(TranscriptionFrame(text="One", user_id="cat", timestamp=""))
|
||||
self.assertFalse(should_start)
|
||||
|
||||
await strategy.process_frame(TranscriptionFrame(text="Two", user_id="cat", timestamp=""))
|
||||
self.assertFalse(should_start)
|
||||
|
||||
await strategy.process_frame(TranscriptionFrame(text="Three", user_id="cat", timestamp=""))
|
||||
self.assertFalse(should_start)
|
||||
|
||||
async def test_bot_speaking_interim_transcriptions(self):
|
||||
strategy = MinWordsUserTurnStartStrategy(min_words=2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user