diff --git a/examples/foundational/35-pattern-pair-voice-switching.py b/examples/foundational/35-pattern-pair-voice-switching.py index 013c3eecd..550520317 100644 --- a/examples/foundational/35-pattern-pair-voice-switching.py +++ b/examples/foundational/35-pattern-pair-voice-switching.py @@ -111,7 +111,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): start_pattern="", end_pattern="", type="voice", - remove_match=True, + action="remove", # Remove tags from final text ) # Register handler for voice switching diff --git a/src/pipecat/extensions/ivr/ivr_navigator.py b/src/pipecat/extensions/ivr/ivr_navigator.py index c4e31e12a..fac760ab3 100644 --- a/src/pipecat/extensions/ivr/ivr_navigator.py +++ b/src/pipecat/extensions/ivr/ivr_navigator.py @@ -114,19 +114,15 @@ class IVRProcessor(FrameProcessor): def _setup_xml_patterns(self): """Set up XML pattern detection and handlers.""" # Register DTMF pattern - self._aggregator.add_pattern_pair( - "dtmf", "", "", type="dtmf", remove_match=True - ) + self._aggregator.add_pattern_pair("dtmf", "", "", type="dtmf", action="remove") self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action) # Register mode pattern - self._aggregator.add_pattern_pair( - "mode", "", "", type="mode", remove_match=True - ) + self._aggregator.add_pattern_pair("mode", "", "", type="mode", action="remove") self._aggregator.on_pattern_match("mode", self._handle_mode_action) # Register IVR pattern - self._aggregator.add_pattern_pair("ivr", "", "", type="ivr", remove_match=True) + self._aggregator.add_pattern_pair("ivr", "", "", type="ivr", action="remove") self._aggregator.on_pattern_match("ivr", self._handle_ivr_action) async def process_frame(self, frame: Frame, direction: FrameDirection): @@ -163,7 +159,7 @@ class IVRProcessor(FrameProcessor): Args: match: The pattern match containing DTMF content. """ - value = match.content + value = match.text logger.debug(f"DTMF detected: {value}") try: @@ -184,7 +180,7 @@ class IVRProcessor(FrameProcessor): Args: match: The pattern match containing IVR status content. """ - status = match.content + status = match.text logger.trace(f"IVR status detected: {status}") # Convert string to enum, with validation @@ -215,7 +211,7 @@ class IVRProcessor(FrameProcessor): Args: match: The pattern match containing mode content. """ - mode = match.content + mode = match.text logger.debug(f"Mode detected: {mode}") if mode == "conversation": await self._handle_conversation() diff --git a/src/pipecat/utils/text/pattern_pair_aggregator.py b/src/pipecat/utils/text/pattern_pair_aggregator.py index fe32f5b51..c188c3974 100644 --- a/src/pipecat/utils/text/pattern_pair_aggregator.py +++ b/src/pipecat/utils/text/pattern_pair_aggregator.py @@ -12,7 +12,7 @@ support for custom handlers and configurable pattern removal. """ import re -from typing import Awaitable, Callable, List, Optional, Tuple +from typing import Awaitable, Callable, List, Literal, Optional, Tuple from loguru import logger @@ -83,9 +83,9 @@ class PatternPairAggregator(BaseTextAggregator): Returns: The text that has been accumulated in the buffer. """ - start, curtype = self._match_start_of_pattern(self._text) - if curtype: - return Aggregation(self._text, curtype) + pattern_start = self._match_start_of_pattern(self._text) + if pattern_start: + return Aggregation(self._text, pattern_start[1].get("type", "sentence")) return Aggregation(self._text, "sentence") def add_pattern_pair( @@ -94,7 +94,7 @@ class PatternPairAggregator(BaseTextAggregator): start_pattern: str, end_pattern: str, type: str, - remove_match: bool = True, + action: Literal["remove", "keep", "aggregate"] = "remove", ) -> "PatternPairAggregator": """Add a pattern pair to detect in the text. @@ -108,7 +108,12 @@ class PatternPairAggregator(BaseTextAggregator): end_pattern: Pattern that marks the end of content. type: The type of aggregation the matched content represents (e.g., 'code', 'speaker', 'custom'). - remove_match: Whether to remove the matched content from the text returned. + action: What to do when a complete pattern is matched: + - "remove": Remove the matched pattern from the text. + - "keep": Keep the matched pattern in the text and treat it as + normal text. + - "aggregate": Return the matched pattern as a separate + aggregation object. Returns: Self for method chaining. @@ -117,7 +122,7 @@ class PatternPairAggregator(BaseTextAggregator): "start": start_pattern, "end": end_pattern, "type": type, - "remove_match": remove_match, + "action": action, } return self @@ -162,7 +167,7 @@ class PatternPairAggregator(BaseTextAggregator): # Escape special regex characters in the patterns start = re.escape(pattern_info["start"]) end = re.escape(pattern_info["end"]) - remove_match = pattern_info["remove_match"] + action = pattern_info["action"] match_type = pattern_info["type"] # Create regex to match from start pattern to end pattern @@ -190,7 +195,7 @@ class PatternPairAggregator(BaseTextAggregator): logger.error(f"Error in pattern handler for {pattern_id}: {e}") # Remove the pattern from the text if configured - if remove_match: + if action == "remove": processed_text = processed_text.replace(full_match, "", 1) # modified = True else: @@ -198,7 +203,7 @@ class PatternPairAggregator(BaseTextAggregator): return all_matches, processed_text - def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, str]]: + def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, dict]]: """Check if text contains incomplete pattern pairs. Determines whether the text contains any start patterns without @@ -225,9 +230,9 @@ class PatternPairAggregator(BaseTextAggregator): # Which is why we base the return on the first found. if start_count > end_count: start_index = text.find(start) - return [start_index, pattern_info["type"]] + return [start_index, pattern_info] - return None, None + return None async def aggregate(self, text: str) -> Optional[PatternMatch]: """Aggregate text and process pattern pairs. @@ -258,17 +263,22 @@ class PatternPairAggregator(BaseTextAggregator): logger.warning( f"Multiple patterns matched: {[p.pattern_id for p in patterns]}. Only the first pattern will be returned." ) - self._text = "" - return patterns[0] + # If the pattern found is set to be aggregated, return it + action = self._patterns[patterns[0].pattern_id].get("action", "remove") + if action == "aggregate": + self._text = "" + print(f"Returning pattern: {patterns[0]}") + return patterns[0] # Check if we have incomplete patterns - start, curtype = self._match_start_of_pattern(self._text) - if start is not None: - # Still waiting for complete patterns - if start == 0: + pattern_start = self._match_start_of_pattern(self._text) + if pattern_start is not None: + # If the start pattern is at the beginning or should not be separately aggregated, return None + if pattern_start[0] == 0 or pattern_start[1].get("action", "remove") != "aggregate": return None - result = self._text[:start] - self._text = self._text[start:] + # Otherwise, strip the text up to the start pattern and return it + result = self._text[: pattern_start[0]] + self._text = self._text[pattern_start[0] :] return PatternMatch(f"_sentence", result, result, "sentence") # Find sentence boundary if no incomplete patterns diff --git a/tests/test_pattern_pair_aggregator.py b/tests/test_pattern_pair_aggregator.py index d0e13ffc6..310ccf635 100644 --- a/tests/test_pattern_pair_aggregator.py +++ b/tests/test_pattern_pair_aggregator.py @@ -14,6 +14,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): def setUp(self): self.aggregator = PatternPairAggregator() self.test_handler = AsyncMock() + self.code_handler = AsyncMock() # Add a test pattern self.aggregator.add_pattern_pair( @@ -21,22 +22,24 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): start_pattern="", end_pattern="", type="test", - remove_match=True, + action="remove", ) self.aggregator.add_pattern_pair( pattern_id="code_pattern", start_pattern="", end_pattern="", type="code", - remove_match=False, + action="aggregate", ) # Register the mock handler self.aggregator.on_pattern_match("test_pattern", self.test_handler) + self.aggregator.on_pattern_match("code_pattern", self.code_handler) async def test_pattern_match_and_removal(self): # First part doesn't complete the pattern result = await self.aggregator.aggregate("Hello pattern") + print(f"result: {result}") self.assertIsNone(result) self.assertEqual(self.aggregator.text.text, "Hello pattern") self.assertEqual(self.aggregator.text.type, "test") @@ -50,7 +53,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): self.assertIsInstance(call_args, PatternMatch) self.assertEqual(call_args.pattern_id, "test_pattern") self.assertEqual(call_args.full_match, "pattern content") - self.assertEqual(call_args.content, "pattern content") + self.assertEqual(call_args.text, "pattern content") # The exclamation point should be treated as a sentence boundary, # so the result should include just text up to and including "!" @@ -64,6 +67,33 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): # Buffer should be empty after returning a complete sentence self.assertEqual(self.aggregator.text.text, "") + async def test_pattern_match_and_aggregate(self): + # First part doesn't complete the pattern + result = await self.aggregator.aggregate("Here is code pattern") + print(f"result: {result}") + self.assertEqual(result.text, "Here is code ") + self.assertEqual(self.aggregator.text.text, "pattern") + self.assertEqual(self.aggregator.text.type, "code") + + # Second part completes the pattern and includes an exclamation point + result = await self.aggregator.aggregate(" content") + + # Verify the handler was called with correct PatternMatch object + self.code_handler.assert_called_once() + call_args = self.code_handler.call_args[0][0] + self.assertIsInstance(call_args, PatternMatch) + self.assertEqual(call_args.pattern_id, "code_pattern") + self.assertEqual(call_args.full_match, "pattern content") + self.assertEqual(call_args.text, "pattern content") + + # Next sentence should be processed separately + result = await self.aggregator.aggregate(" This is another sentence.") + self.assertEqual(result.text, " This is another sentence.") + self.assertEqual(result.type, "sentence") + + # Buffer should be empty after returning a complete sentence + self.assertEqual(self.aggregator.text.text, "") + async def test_incomplete_pattern(self): # Add text with incomplete pattern result = await self.aggregator.aggregate("Hello pattern content") @@ -88,14 +118,19 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): emphasis_handler = AsyncMock() self.aggregator.add_pattern_pair( - pattern_id="voice", start_pattern="", end_pattern="", remove_match=True + pattern_id="voice", + start_pattern="", + end_pattern="", + type="voice", + action="remove", ) self.aggregator.add_pattern_pair( pattern_id="emphasis", start_pattern="", end_pattern="", - remove_match=False, # Keep emphasis tags + type="emphasis", + action="keep", # Keep emphasis tags ) self.aggregator.on_pattern_match("voice", voice_handler) @@ -109,15 +144,15 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): voice_handler.assert_called_once() voice_match = voice_handler.call_args[0][0] self.assertEqual(voice_match.pattern_id, "voice") - self.assertEqual(voice_match.content, "female") + self.assertEqual(voice_match.text, "female") emphasis_handler.assert_called_once() emphasis_match = emphasis_handler.call_args[0][0] self.assertEqual(emphasis_match.pattern_id, "emphasis") - self.assertEqual(emphasis_match.content, "very") + self.assertEqual(emphasis_match.text, "very") # Voice pattern should be removed, emphasis pattern should remain - self.assertEqual(result, "Hello I am very excited to meet you!") + self.assertEqual(result.text, "Hello I am very excited to meet you!") # Buffer should be empty self.assertEqual(self.aggregator.text.text, "") @@ -149,10 +184,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): # Handler should be called with entire content self.test_handler.assert_called_once() call_args = self.test_handler.call_args[0][0] - self.assertEqual(call_args.content, "This is sentence one. This is sentence two.") + self.assertEqual(call_args.text, "This is sentence one. This is sentence two.") # Pattern should be removed, resulting in text with sentences merged - self.assertEqual(result, "Hello Final sentence.") + self.assertEqual(result.text, "Hello Final sentence.") # Buffer should be empty self.assertEqual(self.aggregator.text.text, "") diff --git a/tests/test_piper_tts.py b/tests/test_piper_tts.py index 75893f93f..209e6e76c 100644 --- a/tests/test_piper_tts.py +++ b/tests/test_piper_tts.py @@ -74,6 +74,7 @@ async def test_run_piper_tts_success(aiohttp_client): ] expected_returned_frames = [ + TTSTextFrame, TTSStartedFrame, TTSAudioRawFrame, TTSAudioRawFrame, @@ -121,7 +122,7 @@ async def test_run_piper_tts_error(aiohttp_client): TTSSpeakFrame(text="Error case."), ] - expected_down_frames = [TTSStoppedFrame, TTSTextFrame] + expected_down_frames = [TTSTextFrame, TTSStoppedFrame, TTSTextFrame] expected_up_frames = [ErrorFrame] diff --git a/tests/test_simple_text_aggregator.py b/tests/test_simple_text_aggregator.py index ff6dd1847..007549458 100644 --- a/tests/test_simple_text_aggregator.py +++ b/tests/test_simple_text_aggregator.py @@ -15,15 +15,20 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase): async def test_reset_aggregations(self): assert await self.aggregator.aggregate("Hello ") == None - assert self.aggregator.text == "Hello " + assert self.aggregator.text.text == "Hello " await self.aggregator.reset() - assert self.aggregator.text == "" + assert self.aggregator.text.text == "" async def test_simple_sentence(self): assert await self.aggregator.aggregate("Hello ") == None - assert await self.aggregator.aggregate("Pipecat!") == "Hello Pipecat!" - assert self.aggregator.text == "" + aggregate = await self.aggregator.aggregate("Pipecat!") + assert aggregate.text == "Hello Pipecat!" + assert aggregate.type == "sentence" + assert self.aggregator.text.text == "" async def test_multiple_sentences(self): - assert await self.aggregator.aggregate("Hello Pipecat! How are ") == "Hello Pipecat!" - assert await self.aggregator.aggregate("you?") == " How are you?" + aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ") + assert aggregate.text == "Hello Pipecat!" + assert self.aggregator.text.text == " How are " + aggregate = await self.aggregator.aggregate("you?") + assert aggregate.text == " How are you?"