MinWordsUserTurnStartStrategy: don't aggregate transcriptions

If we aggregate transcriptions we will get incorrect interruptions. For example,
if we have a strategy with min_words=3 and we say "One" and pause, then "Two"
and pause and then "Three", this would trigger the start of the turn when it
shouldn't. We should only look at the incoming transcription text and don't
aggregate it with the previous.
This commit is contained in:
Aleix Conchillo Flaqué
2026-01-15 10:32:23 -08:00
parent 41cb53f6c2
commit c2a0735975
4 changed files with 30 additions and 33 deletions

1
changelog/3462.fixed.md Normal file
View File

@@ -0,0 +1 @@
- Fixed `MinWordsUserTurnStartStrategy` to not aggregate transcriptions, preventing incorrect turn starts when words are spoken with pauses between them.

View File

@@ -41,12 +41,10 @@ class MinWordsUserTurnStartStrategy(BaseUserTurnStartStrategy):
self._min_words = min_words
self._use_interim = use_interim
self._bot_speaking = False
self._text = ""
async def reset(self):
"""Reset the strategy to its initial state."""
await super().reset()
self._text = ""
self._bot_speaking = False
async def process_frame(self, frame: Frame):
@@ -67,7 +65,7 @@ class MinWordsUserTurnStartStrategy(BaseUserTurnStartStrategy):
elif isinstance(frame, TranscriptionFrame):
await self._handle_transcription(frame)
elif isinstance(frame, InterimTranscriptionFrame) and self._use_interim:
await self._handle_interim_transcription(frame)
await self._handle_transcription(frame)
async def _handle_bot_started_speaking(self, frame: BotStartedSpeakingFrame):
"""Handle bot started speaking frame.
@@ -89,41 +87,21 @@ class MinWordsUserTurnStartStrategy(BaseUserTurnStartStrategy):
"""
self._bot_speaking = False
async def _handle_transcription(self, frame: TranscriptionFrame):
async def _handle_transcription(self, frame: TranscriptionFrame | InterimTranscriptionFrame):
"""Handle a completed transcription frame and check word count.
Args:
frame: The transcription frame to be processed.
"""
self._text += frame.text
min_words = self._min_words if self._bot_speaking else 1
word_count = len(self._text.split())
should_trigger = word_count >= min_words
logger.debug(
f"{self} should_trigger={should_trigger} num_spoken_words={word_count} "
f"min_words={min_words} bot_speaking={self._bot_speaking}"
)
if should_trigger:
await self.trigger_user_turn_started()
async def _handle_interim_transcription(self, frame: InterimTranscriptionFrame):
"""Handle an interim transcription frame and check word count.
Args:
frame: The interim transcription frame to be processed.
"""
min_words = self._min_words if self._bot_speaking else 1
word_count = len(frame.text.split())
should_trigger = word_count >= min_words
is_interim = isinstance(frame, InterimTranscriptionFrame)
logger.debug(
f"{self} interim=True should_trigger={should_trigger} num_spoken_words={word_count} "
f"min_words={min_words} bot_speaking={self._bot_speaking}"
f"{self} should_trigger={should_trigger} num_spoken_words={word_count} "
f"min_words={min_words} bot_speaking={self._bot_speaking} interim_transcription={is_interim}"
)
if should_trigger:

View File

@@ -84,7 +84,7 @@ class TestUserTurnController(unittest.IsolatedAsyncioTestCase):
self.assertEqual(should_start, 0)
await controller.process_frame(
TranscriptionFrame(text=" two three!", user_id="cat", timestamp="")
TranscriptionFrame(text="One two three!", user_id="cat", timestamp="")
)
self.assertEqual(should_start, 1)
@@ -92,13 +92,11 @@ class TestUserTurnController(unittest.IsolatedAsyncioTestCase):
await asyncio.sleep(USER_TURN_STOP_TIMEOUT + 0.1)
await controller.process_frame(BotStartedSpeakingFrame())
await controller.process_frame(
TranscriptionFrame(text="Hello", user_id="cat", timestamp="")
)
await controller.process_frame(TranscriptionFrame(text="Hi!", user_id="cat", timestamp=""))
self.assertEqual(should_start, 1)
await controller.process_frame(
TranscriptionFrame(text=" there friend!", user_id="cat", timestamp="")
TranscriptionFrame(text="How are you?", user_id="cat", timestamp="")
)
self.assertEqual(should_start, 2)

View File

@@ -38,7 +38,7 @@ class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
self.assertFalse(should_start)
await strategy.process_frame(
TranscriptionFrame(text=" there!", user_id="cat", timestamp="")
TranscriptionFrame(text="Hello there!", user_id="cat", timestamp="")
)
self.assertTrue(should_start)
@@ -55,6 +55,26 @@ class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
)
self.assertTrue(should_start)
async def test_bot_speaking_singlw_words(self):
strategy = MinWordsUserTurnStartStrategy(min_words=3)
should_start = None
@strategy.event_handler("on_user_turn_started")
async def on_user_turn_started(strategy, params):
nonlocal should_start
should_start = True
await strategy.process_frame(BotStartedSpeakingFrame())
await strategy.process_frame(TranscriptionFrame(text="One", user_id="cat", timestamp=""))
self.assertFalse(should_start)
await strategy.process_frame(TranscriptionFrame(text="Two", user_id="cat", timestamp=""))
self.assertFalse(should_start)
await strategy.process_frame(TranscriptionFrame(text="Three", user_id="cat", timestamp=""))
self.assertFalse(should_start)
async def test_bot_speaking_interim_transcriptions(self):
strategy = MinWordsUserTurnStartStrategy(min_words=2)