diff --git a/tests/test_pattern_pair_aggregator.py b/tests/test_pattern_pair_aggregator.py new file mode 100644 index 000000000..e1086e577 --- /dev/null +++ b/tests/test_pattern_pair_aggregator.py @@ -0,0 +1,147 @@ +# +# Copyright (c) 2024-2025 Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import unittest +from unittest.mock import Mock + +from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator + + +class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.aggregator = PatternPairAggregator() + self.test_handler = Mock() + + # Add a test pattern + self.aggregator.add_pattern_pair( + pattern_id="test_pattern", + start_pattern="", + end_pattern="", + remove_match=True, + ) + + # Register the mock handler + self.aggregator.on_pattern_match("test_pattern", self.test_handler) + + async def test_pattern_match_and_removal(self): + # First part doesn't complete the pattern + result = self.aggregator.aggregate("Hello pattern") + self.assertIsNone(result) + self.assertEqual(self.aggregator.text, "Hello pattern") + + # Second part completes the pattern and includes an exclamation point + result = self.aggregator.aggregate(" content!") + + # Verify the handler was called with correct PatternMatch object + self.test_handler.assert_called_once() + call_args = self.test_handler.call_args[0][0] + self.assertIsInstance(call_args, PatternMatch) + self.assertEqual(call_args.pattern_id, "test_pattern") + self.assertEqual(call_args.full_match, "pattern content") + self.assertEqual(call_args.content, "pattern content") + + # The exclamation point should be treated as a sentence boundary, + # so the result should include just text up to and including "!" + self.assertEqual(result, "Hello !") + + # Next sentence should be processed separately + result = self.aggregator.aggregate(" This is another sentence.") + self.assertEqual(result, " This is another sentence.") + + # Buffer should be empty after returning a complete sentence + self.assertEqual(self.aggregator.text, "") + + async def test_incomplete_pattern(self): + # Add text with incomplete pattern + result = self.aggregator.aggregate("Hello pattern content") + + # No complete pattern yet, so nothing should be returned + self.assertIsNone(result) + + # The handler should not be called yet + self.test_handler.assert_not_called() + + # Buffer should contain the incomplete text + self.assertEqual(self.aggregator.text, "Hello pattern content") + + # Reset and confirm buffer is cleared + self.aggregator.reset() + self.assertEqual(self.aggregator.text, "") + + async def test_multiple_patterns(self): + # Set up multiple patterns and handlers + voice_handler = Mock() + emphasis_handler = Mock() + + self.aggregator.add_pattern_pair( + pattern_id="voice", start_pattern="", end_pattern="", remove_match=True + ) + + self.aggregator.add_pattern_pair( + pattern_id="emphasis", + start_pattern="", + end_pattern="", + remove_match=False, # Keep emphasis tags + ) + + self.aggregator.on_pattern_match("voice", voice_handler) + self.aggregator.on_pattern_match("emphasis", emphasis_handler) + + # Test with multiple patterns in one text block + text = "Hello female I am very excited to meet you!" + result = self.aggregator.aggregate(text) + + # Both handlers should be called with correct data + voice_handler.assert_called_once() + voice_match = voice_handler.call_args[0][0] + self.assertEqual(voice_match.pattern_id, "voice") + self.assertEqual(voice_match.content, "female") + + emphasis_handler.assert_called_once() + emphasis_match = emphasis_handler.call_args[0][0] + self.assertEqual(emphasis_match.pattern_id, "emphasis") + self.assertEqual(emphasis_match.content, "very") + + # Voice pattern should be removed, emphasis pattern should remain + self.assertEqual(result, "Hello I am very excited to meet you!") + + # Buffer should be empty + self.assertEqual(self.aggregator.text, "") + + async def test_handle_interruption(self): + # Start with incomplete pattern + result = self.aggregator.aggregate("Hello pattern") + self.assertIsNone(result) + + # Simulate interruption + self.aggregator.handle_interruption() + + # Buffer should be cleared + self.assertEqual(self.aggregator.text, "") + + # Handler should not have been called + self.test_handler.assert_not_called() + + async def test_pattern_across_sentences(self): + # Test pattern that spans multiple sentences + result = self.aggregator.aggregate("Hello This is sentence one.") + + # First sentence contains start of pattern but no end, so no complete pattern yet + self.assertIsNone(result) + + # Add second part with pattern end + result = self.aggregator.aggregate(" This is sentence two. Final sentence.") + + # Handler should be called with entire content + self.test_handler.assert_called_once() + call_args = self.test_handler.call_args[0][0] + self.assertEqual(call_args.content, "This is sentence one. This is sentence two.") + + # Pattern should be removed, resulting in text with sentences merged + self.assertEqual(result, "Hello Final sentence.") + + # Buffer should be empty + self.assertEqual(self.aggregator.text, "")