diff --git a/tests/test_transcript_processor.py b/tests/test_transcript_processor.py index 1c6db277f..5f80b3ca6 100644 --- a/tests/test_transcript_processor.py +++ b/tests/test_transcript_processor.py @@ -275,7 +275,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): # First update should be interrupted message first_message = received_updates[0].messages[0] self.assertEqual(first_message.role, "assistant") - self.assertEqual(first_message.content, "Hello world !") + self.assertEqual(first_message.content, "Hello world!") self.assertIsNotNone(first_message.timestamp) # Second update should be new response @@ -426,3 +426,299 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase): self.assertEqual(received_updates[0].content, "User message") self.assertEqual(received_updates[1].role, "assistant") self.assertEqual(received_updates[1].content, "Assistant message") + + async def test_text_fragments_with_spaces(self): + """Test aggregating text fragments with various spacing patterns""" + processor = AssistantTranscriptProcessor() + + # Track received updates + received_updates = [] + + @processor.event_handler("on_transcript_update") + async def handle_update(proc, frame: TranscriptionUpdateFrame): + received_updates.append(frame) + + # Test the specific pattern shared + frames_to_send = [ + BotStartedSpeakingFrame(), + SleepFrame(sleep=0.1), + TTSTextFrame(text="Hello"), + TTSTextFrame(text=" there"), + TTSTextFrame(text="!"), + TTSTextFrame(text=" How"), + TTSTextFrame(text="'s"), + TTSTextFrame(text=" it"), + TTSTextFrame(text=" going"), + TTSTextFrame(text="?"), + BotStoppedSpeakingFrame(), + ] + + expected_down_frames = [ + BotStartedSpeakingFrame, + BotStoppedSpeakingFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TranscriptionUpdateFrame, + ] + + # Run test + received_frames, _ = await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + # Verify result + self.assertEqual(len(received_updates), 1) + message = received_updates[0].messages[0] + self.assertEqual(message.role, "assistant") + # Should be properly joined without extra spaces + self.assertEqual(message.content, "Hello there! How's it going?") + + async def test_mixed_spacing_styles(self): + """Test handling mixed word-by-word and pre-spaced fragments""" + processor = AssistantTranscriptProcessor() + + received_updates = [] + + @processor.event_handler("on_transcript_update") + async def handle_update(proc, frame: TranscriptionUpdateFrame): + received_updates.append(frame) + + # Mix of spacing styles within the same utterance + frames_to_send = [ + BotStartedSpeakingFrame(), + SleepFrame(sleep=0.1), + # Word-by-word style + TTSTextFrame(text="First"), + TTSTextFrame(text="style."), + # Pre-spaced style + TTSTextFrame(text=" Second"), + TTSTextFrame(text=" style"), + TTSTextFrame(text="!"), + BotStoppedSpeakingFrame(), + ] + + expected_down_frames = [ + BotStartedSpeakingFrame, + BotStoppedSpeakingFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TranscriptionUpdateFrame, + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + self.assertEqual(len(received_updates), 1) + message = received_updates[0].messages[0] + self.assertEqual(message.content, "First style. Second style!") + + async def test_punctuation_handling(self): + """Test handling of various punctuation patterns""" + processor = AssistantTranscriptProcessor() + + received_updates = [] + + @processor.event_handler("on_transcript_update") + async def handle_update(proc, frame: TranscriptionUpdateFrame): + received_updates.append(frame) + + # Test various punctuation types + frames_to_send = [ + BotStartedSpeakingFrame(), + SleepFrame(sleep=0.1), + TTSTextFrame(text="Commas"), + TTSTextFrame(text=","), + TTSTextFrame(text="colons"), + TTSTextFrame(text=":"), + TTSTextFrame(text="semicolons"), + TTSTextFrame(text=";"), + TTSTextFrame(text="quotes"), + TTSTextFrame(text="'"), + TTSTextFrame(text="and"), + TTSTextFrame(text='"'), + TTSTextFrame(text="double quotes"), + TTSTextFrame(text="!"), + BotStoppedSpeakingFrame(), + ] + + expected_down_frames = [ + BotStartedSpeakingFrame, + BotStoppedSpeakingFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TranscriptionUpdateFrame, + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + self.assertEqual(len(received_updates), 1) + message = received_updates[0].messages[0] + self.assertEqual( + message.content, "Commas, colons: semicolons; quotes' and\" double quotes!" + ) + + async def test_complex_mixed_case(self): + """Test a complex mix of patterns to ensure robustness""" + processor = AssistantTranscriptProcessor() + + received_updates = [] + + @processor.event_handler("on_transcript_update") + async def handle_update(proc, frame: TranscriptionUpdateFrame): + received_updates.append(frame) + + # Complex mixed case with various patterns + frames_to_send = [ + BotStartedSpeakingFrame(), + SleepFrame(sleep=0.1), + # Pre-spaced fragments + TTSTextFrame(text="Hello"), + TTSTextFrame(text=" there"), + TTSTextFrame(text="!"), + # Sentence boundary + TTSTextFrame(text=" I'm"), + TTSTextFrame(text=" testing"), + TTSTextFrame(text=" spacing"), + TTSTextFrame(text="."), + # Word-by-word fragments + TTSTextFrame(text="Does"), + TTSTextFrame(text="this"), + TTSTextFrame(text="work"), + TTSTextFrame(text="correctly"), + TTSTextFrame(text="?"), + # Mixed punctuation and spacing + TTSTextFrame(text=" Let's"), + TTSTextFrame(text=" see:"), + TTSTextFrame(text="commas"), + TTSTextFrame(text=","), + TTSTextFrame(text=" semicolons"), + TTSTextFrame(text=";"), + TTSTextFrame(text=" and"), + TTSTextFrame(text=" quotes"), + TTSTextFrame(text="'"), + TTSTextFrame(text="!"), + BotStoppedSpeakingFrame(), + ] + + expected_down_frames = [ + BotStartedSpeakingFrame, + BotStoppedSpeakingFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TranscriptionUpdateFrame, + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + self.assertEqual(len(received_updates), 1) + message = received_updates[0].messages[0] + expected = "Hello there! I'm testing spacing. Does this work correctly? Let's see: commas, semicolons; and quotes'!" + self.assertEqual(message.content, expected) + + async def test_multiple_consecutive_punctuation(self): + """Test handling of multiple consecutive punctuation marks""" + processor = AssistantTranscriptProcessor() + + received_updates = [] + + @processor.event_handler("on_transcript_update") + async def handle_update(proc, frame: TranscriptionUpdateFrame): + received_updates.append(frame) + + frames_to_send = [ + BotStartedSpeakingFrame(), + SleepFrame(sleep=0.1), + TTSTextFrame(text="Wow"), + TTSTextFrame(text="!"), + TTSTextFrame(text="!"), + TTSTextFrame(text="!"), + TTSTextFrame(text=" That's"), + TTSTextFrame(text=" amazing"), + TTSTextFrame(text="..."), + TTSTextFrame(text=" Don't"), + TTSTextFrame(text=" you"), + TTSTextFrame(text=" think"), + TTSTextFrame(text="?"), + TTSTextFrame(text="?"), + BotStoppedSpeakingFrame(), + ] + + expected_down_frames = [ + BotStartedSpeakingFrame, + BotStoppedSpeakingFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TTSTextFrame, + TranscriptionUpdateFrame, + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + self.assertEqual(len(received_updates), 1) + message = received_updates[0].messages[0] + self.assertEqual(message.content, "Wow!!! That's amazing... Don't you think??")