From e9de9daf8cfca82369007ccb2a7973d5a3486e9e Mon Sep 17 00:00:00 2001 From: mattie ruth backman Date: Tue, 4 Nov 2025 16:22:34 -0500 Subject: [PATCH] Update PatternPairAggregator patterns to replace pattern_id with type to simplify the API --- .../35-pattern-pair-voice-switching.py | 5 +- src/pipecat/extensions/ivr/ivr_navigator.py | 12 +--- .../utils/text/pattern_pair_aggregator.py | 60 +++++++++---------- tests/test_pattern_pair_aggregator.py | 28 ++++----- 4 files changed, 46 insertions(+), 59 deletions(-) diff --git a/examples/foundational/35-pattern-pair-voice-switching.py b/examples/foundational/35-pattern-pair-voice-switching.py index 871f74f8d..e73435d72 100644 --- a/examples/foundational/35-pattern-pair-voice-switching.py +++ b/examples/foundational/35-pattern-pair-voice-switching.py @@ -111,10 +111,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): # Add pattern for voice switching pattern_aggregator.add_pattern_pair( - pattern_id="voice_tag", + type="voice", start_pattern="", end_pattern="", - type="voice", action=MatchAction.REMOVE, # Remove tags from final text ) @@ -130,7 +129,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): else: logger.warning(f"Unknown voice: {voice_name}") - pattern_aggregator.on_pattern_match("voice_tag", on_voice_tag) + pattern_aggregator.on_pattern_match("voice", on_voice_tag) stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) diff --git a/src/pipecat/extensions/ivr/ivr_navigator.py b/src/pipecat/extensions/ivr/ivr_navigator.py index 4b46c1ffb..7c0cb87aa 100644 --- a/src/pipecat/extensions/ivr/ivr_navigator.py +++ b/src/pipecat/extensions/ivr/ivr_navigator.py @@ -118,21 +118,15 @@ class IVRProcessor(FrameProcessor): def _setup_xml_patterns(self): """Set up XML pattern detection and handlers.""" # Register DTMF pattern - self._aggregator.add_pattern_pair( - "dtmf", "", "", type="dtmf", action=MatchAction.REMOVE - ) + self._aggregator.add_pattern_pair("dtmf", "", "", action=MatchAction.REMOVE) self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action) # Register mode pattern - self._aggregator.add_pattern_pair( - "mode", "", "", type="mode", action=MatchAction.REMOVE - ) + self._aggregator.add_pattern_pair("mode", "", "", action=MatchAction.REMOVE) self._aggregator.on_pattern_match("mode", self._handle_mode_action) # Register IVR pattern - self._aggregator.add_pattern_pair( - "ivr", "", "", type="ivr", action=MatchAction.REMOVE - ) + self._aggregator.add_pattern_pair("ivr", "", "", action=MatchAction.REMOVE) self._aggregator.on_pattern_match("ivr", self._handle_ivr_action) async def process_frame(self, frame: Frame, direction: FrameDirection): diff --git a/src/pipecat/utils/text/pattern_pair_aggregator.py b/src/pipecat/utils/text/pattern_pair_aggregator.py index 0551f7cc3..a31ad2390 100644 --- a/src/pipecat/utils/text/pattern_pair_aggregator.py +++ b/src/pipecat/utils/text/pattern_pair_aggregator.py @@ -44,27 +44,25 @@ class PatternMatch(Aggregation): content between the patterns. """ - def __init__(self, pattern_id: str, full_match: str, content: str, type: str): + def __init__(self, content: str, type: str, full_match: str): """Initialize a pattern match. Args: - pattern_id: The identifier of the matched pattern pair. + type: The type of the matched pattern pair. It should be representative + of the content type (e.g., 'sentence', 'code', 'speaker', 'custom'). full_match: The complete text including start and end patterns. content: The text content between the start and end patterns. - type: The type of aggregation the matched content represents - (e.g., 'code', 'speaker', 'custom'). """ super().__init__(text=content, type=type) - self.pattern_id = pattern_id self.full_match = full_match def __str__(self) -> str: """Return a string representation of the pattern match. Returns: - A descriptive string showing the pattern ID and content. + A descriptive string showing the pattern type and content. """ - return f"PatternMatch(id={self.pattern_id}, content={self.text}, full_match={self.full_match}, type={self.type})" + return f"PatternMatch(type={self.type}, content={self.text}, full_match={self.full_match})" class PatternPairAggregator(BaseTextAggregator): @@ -110,10 +108,9 @@ class PatternPairAggregator(BaseTextAggregator): def add_pattern_pair( self, - pattern_id: str, + type: str, start_pattern: str, end_pattern: str, - type: str, action: MatchAction = MatchAction.REMOVE, ) -> "PatternPairAggregator": """Add a pattern pair to detect in the text. @@ -123,11 +120,11 @@ class PatternPairAggregator(BaseTextAggregator): the end pattern, and treat the content between them as a match. Args: - pattern_id: Unique identifier for this pattern pair. + type: Identifier for this pattern pair. Should be unique and ideally descriptive. + (e.g., 'code', 'speaker', 'custom'). type can not be 'sentence' as that is + reserved for the default behavior. start_pattern: Pattern that marks the beginning of content. end_pattern: Pattern that marks the end of content. - type: The type of aggregation the matched content represents - (e.g., 'code', 'speaker', 'custom'). action: What to do when a complete pattern is matched: - MatchAction.REMOVE: Remove the matched pattern from the text. - MatchAction.KEEP: Keep the matched pattern in the text and treat it as @@ -139,7 +136,11 @@ class PatternPairAggregator(BaseTextAggregator): Returns: Self for method chaining. """ - self._patterns[pattern_id] = { + if type == "sentence": + raise ValueError( + "The aggregation type 'sentence' is reserved for default behavior and can not be used for custom patterns." + ) + self._patterns[type] = { "start": start_pattern, "end": end_pattern, "type": type, @@ -148,22 +149,22 @@ class PatternPairAggregator(BaseTextAggregator): return self def on_pattern_match( - self, pattern_id: str, handler: Callable[[PatternMatch], Awaitable[None]] + self, type: str, handler: Callable[[PatternMatch], Awaitable[None]] ) -> "PatternPairAggregator": """Register a handler for when a pattern pair is matched. The handler will be called whenever a complete match for the - specified pattern ID is found in the text. + specified type is found in the text. Args: - pattern_id: ID of the pattern pair to match. + type: The type of the pattern pair to trigger the handler. handler: Async function to call when pattern is matched. The function should accept a PatternMatch object. Returns: Self for method chaining. """ - self._handlers[pattern_id] = handler + self._handlers[type] = handler return self async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch], str]: @@ -184,12 +185,11 @@ class PatternPairAggregator(BaseTextAggregator): all_matches = [] processed_text = text - for pattern_id, pattern_info in self._patterns.items(): + for type, pattern_info in self._patterns.items(): # Escape special regex characters in the patterns start = re.escape(pattern_info["start"]) end = re.escape(pattern_info["end"]) action = pattern_info["action"] - match_type = pattern_info["type"] # Create regex to match from start pattern to end pattern # The .*? is non-greedy to handle nested patterns @@ -204,16 +204,14 @@ class PatternPairAggregator(BaseTextAggregator): full_match = match.group(0) # Full match including patterns # Create pattern match object - pattern_match = PatternMatch( - pattern_id=pattern_id, full_match=full_match, content=content, type=match_type - ) + pattern_match = PatternMatch(content=content, type=type, full_match=full_match) # Call the appropriate handler if registered - if pattern_id in self._handlers: + if type in self._handlers: try: - await self._handlers[pattern_id](pattern_match) + await self._handlers[type](pattern_match) except Exception as e: - logger.error(f"Error in pattern handler for {pattern_id}: {e}") + logger.error(f"Error in pattern handler for {type}: {e}") # Remove the pattern from the text if configured if action == MatchAction.REMOVE: @@ -233,10 +231,10 @@ class PatternPairAggregator(BaseTextAggregator): text: The text to check for incomplete patterns. Returns: - A tuple of (start_index, type) if an incomplete pattern is found, + A tuple of (start_index, pattern_info) if an incomplete pattern is found, or None if no patterns are found or all patterns are complete. """ - for pattern_id, pattern_info in self._patterns.items(): + for type, pattern_info in self._patterns.items(): start = pattern_info["start"] end = pattern_info["end"] @@ -280,10 +278,10 @@ class PatternPairAggregator(BaseTextAggregator): if len(patterns) > 0: if len(patterns) > 1: logger.warning( - f"Multiple patterns matched: {[p.pattern_id for p in patterns]}. Only the first pattern will be returned." + f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned." ) # If the pattern found is set to be aggregated, return it - action = self._patterns[patterns[0].pattern_id].get("action", MatchAction.REMOVE) + action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE) if action == MatchAction.AGGREGATE: self._text = "" return patterns[0] @@ -300,7 +298,7 @@ class PatternPairAggregator(BaseTextAggregator): # Otherwise, strip the text up to the start pattern and return it result = self._text[: pattern_start[0]] self._text = self._text[pattern_start[0] :] - return PatternMatch(f"_sentence", result, result, "sentence") + return PatternMatch(content=result, type="sentence", full_match=result) # Find sentence boundary if no incomplete patterns eos_marker = match_endofsentence(self._text) @@ -308,7 +306,7 @@ class PatternPairAggregator(BaseTextAggregator): # Extract text up to the sentence boundary result = self._text[:eos_marker] self._text = self._text[eos_marker:] - return PatternMatch(f"_sentence", result, result, "sentence") + return PatternMatch(content=result, type="sentence", full_match=result) # No complete sentence found yet return None diff --git a/tests/test_pattern_pair_aggregator.py b/tests/test_pattern_pair_aggregator.py index 5887a8197..ffbff4fbb 100644 --- a/tests/test_pattern_pair_aggregator.py +++ b/tests/test_pattern_pair_aggregator.py @@ -22,17 +22,15 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): # Add a test pattern self.aggregator.add_pattern_pair( - pattern_id="test_pattern", + type="test_pattern", start_pattern="", end_pattern="", - type="test", action=MatchAction.REMOVE, ) self.aggregator.add_pattern_pair( - pattern_id="code_pattern", + type="code_pattern", start_pattern="", end_pattern="", - type="code", action=MatchAction.AGGREGATE, ) @@ -45,7 +43,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): result = await self.aggregator.aggregate("Hello pattern") self.assertIsNone(result) self.assertEqual(self.aggregator.text.text, "Hello pattern") - self.assertEqual(self.aggregator.text.type, "test") + self.assertEqual(self.aggregator.text.type, "test_pattern") # Second part completes the pattern and includes an exclamation point result = await self.aggregator.aggregate(" content!") @@ -54,7 +52,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): self.test_handler.assert_called_once() call_args = self.test_handler.call_args[0][0] self.assertIsInstance(call_args, PatternMatch) - self.assertEqual(call_args.pattern_id, "test_pattern") + self.assertEqual(call_args.type, "test_pattern") self.assertEqual(call_args.full_match, "pattern content") self.assertEqual(call_args.text, "pattern content") @@ -75,7 +73,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): result = await self.aggregator.aggregate("Here is code pattern") self.assertEqual(result.text, "Here is code ") self.assertEqual(self.aggregator.text.text, "pattern") - self.assertEqual(self.aggregator.text.type, "code") + self.assertEqual(self.aggregator.text.type, "code_pattern") # Second part completes the pattern and includes an exclamation point result = await self.aggregator.aggregate(" content") @@ -84,11 +82,11 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): self.code_handler.assert_called_once() call_args = self.code_handler.call_args[0][0] self.assertIsInstance(call_args, PatternMatch) - self.assertEqual(call_args.pattern_id, "code_pattern") + self.assertEqual(call_args.type, "code_pattern") self.assertEqual(call_args.full_match, "pattern content") self.assertEqual(call_args.text, "pattern content") self.assertEqual(result.text, "pattern content") - self.assertEqual(result.type, "code") + self.assertEqual(result.type, "code_pattern") # Next sentence should be processed separately result = await self.aggregator.aggregate(" This is another sentence.") @@ -110,7 +108,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): # Buffer should contain the incomplete text self.assertEqual(self.aggregator.text.text, "Hello pattern content") - self.assertEqual(self.aggregator.text.type, "test") + self.assertEqual(self.aggregator.text.type, "test_pattern") # Reset and confirm buffer is cleared await self.aggregator.reset() @@ -122,18 +120,16 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): emphasis_handler = AsyncMock() self.aggregator.add_pattern_pair( - pattern_id="voice", + type="voice", start_pattern="", end_pattern="", - type="voice", action=MatchAction.REMOVE, ) self.aggregator.add_pattern_pair( - pattern_id="emphasis", + type="emphasis", start_pattern="", end_pattern="", - type="emphasis", action=MatchAction.KEEP, # Keep emphasis tags ) @@ -147,12 +143,12 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): # Both handlers should be called with correct data voice_handler.assert_called_once() voice_match = voice_handler.call_args[0][0] - self.assertEqual(voice_match.pattern_id, "voice") + self.assertEqual(voice_match.type, "voice") self.assertEqual(voice_match.text, "female") emphasis_handler.assert_called_once() emphasis_match = emphasis_handler.call_args[0][0] - self.assertEqual(emphasis_match.pattern_id, "emphasis") + self.assertEqual(emphasis_match.type, "emphasis") self.assertEqual(emphasis_match.text, "very") # Voice pattern should be removed, emphasis pattern should remain