diff --git a/examples/foundational/35-pattern-pair-voice-switching.py b/examples/foundational/35-pattern-pair-voice-switching.py
index 871f74f8d..e73435d72 100644
--- a/examples/foundational/35-pattern-pair-voice-switching.py
+++ b/examples/foundational/35-pattern-pair-voice-switching.py
@@ -111,10 +111,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
# Add pattern for voice switching
pattern_aggregator.add_pattern_pair(
- pattern_id="voice_tag",
+ type="voice",
start_pattern="",
end_pattern="",
- type="voice",
action=MatchAction.REMOVE, # Remove tags from final text
)
@@ -130,7 +129,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
else:
logger.warning(f"Unknown voice: {voice_name}")
- pattern_aggregator.on_pattern_match("voice_tag", on_voice_tag)
+ pattern_aggregator.on_pattern_match("voice", on_voice_tag)
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
diff --git a/src/pipecat/extensions/ivr/ivr_navigator.py b/src/pipecat/extensions/ivr/ivr_navigator.py
index 4b46c1ffb..7c0cb87aa 100644
--- a/src/pipecat/extensions/ivr/ivr_navigator.py
+++ b/src/pipecat/extensions/ivr/ivr_navigator.py
@@ -118,21 +118,15 @@ class IVRProcessor(FrameProcessor):
def _setup_xml_patterns(self):
"""Set up XML pattern detection and handlers."""
# Register DTMF pattern
- self._aggregator.add_pattern_pair(
- "dtmf", "", "", type="dtmf", action=MatchAction.REMOVE
- )
+ self._aggregator.add_pattern_pair("dtmf", "", "", action=MatchAction.REMOVE)
self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)
# Register mode pattern
- self._aggregator.add_pattern_pair(
- "mode", "", "", type="mode", action=MatchAction.REMOVE
- )
+ self._aggregator.add_pattern_pair("mode", "", "", action=MatchAction.REMOVE)
self._aggregator.on_pattern_match("mode", self._handle_mode_action)
# Register IVR pattern
- self._aggregator.add_pattern_pair(
- "ivr", "", "", type="ivr", action=MatchAction.REMOVE
- )
+ self._aggregator.add_pattern_pair("ivr", "", "", action=MatchAction.REMOVE)
self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)
async def process_frame(self, frame: Frame, direction: FrameDirection):
diff --git a/src/pipecat/utils/text/pattern_pair_aggregator.py b/src/pipecat/utils/text/pattern_pair_aggregator.py
index 0551f7cc3..a31ad2390 100644
--- a/src/pipecat/utils/text/pattern_pair_aggregator.py
+++ b/src/pipecat/utils/text/pattern_pair_aggregator.py
@@ -44,27 +44,25 @@ class PatternMatch(Aggregation):
content between the patterns.
"""
- def __init__(self, pattern_id: str, full_match: str, content: str, type: str):
+ def __init__(self, content: str, type: str, full_match: str):
"""Initialize a pattern match.
Args:
- pattern_id: The identifier of the matched pattern pair.
+ type: The aggregation type of the match: the type its pattern pair was
+ registered under (e.g., 'code', 'speaker', 'custom') or 'sentence'.
full_match: The complete text including start and end patterns.
content: The text content between the start and end patterns.
- type: The type of aggregation the matched content represents
- (e.g., 'code', 'speaker', 'custom').
"""
super().__init__(text=content, type=type)
- self.pattern_id = pattern_id
self.full_match = full_match
def __str__(self) -> str:
"""Return a string representation of the pattern match.
Returns:
- A descriptive string showing the pattern ID and content.
+ A descriptive string showing the pattern type and content.
"""
- return f"PatternMatch(id={self.pattern_id}, content={self.text}, full_match={self.full_match}, type={self.type})"
+ return f"PatternMatch(type={self.type}, content={self.text}, full_match={self.full_match})"
class PatternPairAggregator(BaseTextAggregator):
@@ -110,10 +108,9 @@ class PatternPairAggregator(BaseTextAggregator):
def add_pattern_pair(
self,
- pattern_id: str,
+ type: str,
start_pattern: str,
end_pattern: str,
- type: str,
action: MatchAction = MatchAction.REMOVE,
) -> "PatternPairAggregator":
"""Add a pattern pair to detect in the text.
@@ -123,11 +120,11 @@ class PatternPairAggregator(BaseTextAggregator):
the end pattern, and treat the content between them as a match.
Args:
- pattern_id: Unique identifier for this pattern pair.
+ type: Identifier for this pattern pair; it is also the key the pair is stored
+ under, so it should be unique and descriptive (e.g., 'code', 'speaker',
+ 'custom'). 'sentence' is reserved for default behavior and raises ValueError.
start_pattern: Pattern that marks the beginning of content.
end_pattern: Pattern that marks the end of content.
- type: The type of aggregation the matched content represents
- (e.g., 'code', 'speaker', 'custom').
action: What to do when a complete pattern is matched:
- MatchAction.REMOVE: Remove the matched pattern from the text.
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as
@@ -139,7 +136,11 @@ class PatternPairAggregator(BaseTextAggregator):
Returns:
Self for method chaining.
"""
- self._patterns[pattern_id] = {
+ if type == "sentence":
+ raise ValueError(
+ "The aggregation type 'sentence' is reserved for default behavior and can not be used for custom patterns."
+ )
+ self._patterns[type] = {
"start": start_pattern,
"end": end_pattern,
"type": type,
@@ -148,22 +149,22 @@ class PatternPairAggregator(BaseTextAggregator):
return self
def on_pattern_match(
- self, pattern_id: str, handler: Callable[[PatternMatch], Awaitable[None]]
+ self, type: str, handler: Callable[[PatternMatch], Awaitable[None]]
) -> "PatternPairAggregator":
"""Register a handler for when a pattern pair is matched.
The handler will be called whenever a complete match for the
- specified pattern ID is found in the text.
+ specified type is found in the text.
Args:
- pattern_id: ID of the pattern pair to match.
+ type: The type the pattern pair was registered under in add_pattern_pair.
handler: Async function to call when pattern is matched.
The function should accept a PatternMatch object.
Returns:
Self for method chaining.
"""
- self._handlers[pattern_id] = handler
+ self._handlers[type] = handler
return self
async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch], str]:
@@ -184,12 +185,11 @@ class PatternPairAggregator(BaseTextAggregator):
all_matches = []
processed_text = text
- for pattern_id, pattern_info in self._patterns.items():
+ for type, pattern_info in self._patterns.items():
# Escape special regex characters in the patterns
start = re.escape(pattern_info["start"])
end = re.escape(pattern_info["end"])
action = pattern_info["action"]
- match_type = pattern_info["type"]
# Create regex to match from start pattern to end pattern
# The .*? is non-greedy to handle nested patterns
@@ -204,16 +204,14 @@ class PatternPairAggregator(BaseTextAggregator):
full_match = match.group(0) # Full match including patterns
# Create pattern match object
- pattern_match = PatternMatch(
- pattern_id=pattern_id, full_match=full_match, content=content, type=match_type
- )
+ pattern_match = PatternMatch(content=content, type=type, full_match=full_match)
# Call the appropriate handler if registered
- if pattern_id in self._handlers:
+ if type in self._handlers:
try:
- await self._handlers[pattern_id](pattern_match)
+ await self._handlers[type](pattern_match)
except Exception as e:
- logger.error(f"Error in pattern handler for {pattern_id}: {e}")
+ logger.error(f"Error in pattern handler for {type}: {e}")
# Remove the pattern from the text if configured
if action == MatchAction.REMOVE:
@@ -233,10 +231,10 @@ class PatternPairAggregator(BaseTextAggregator):
text: The text to check for incomplete patterns.
Returns:
- A tuple of (start_index, type) if an incomplete pattern is found,
+ A tuple of (start_index, pattern_info) if an incomplete pattern is found,
or None if no patterns are found or all patterns are complete.
"""
- for pattern_id, pattern_info in self._patterns.items():
+ for type, pattern_info in self._patterns.items():
start = pattern_info["start"]
end = pattern_info["end"]
@@ -280,10 +278,10 @@ class PatternPairAggregator(BaseTextAggregator):
if len(patterns) > 0:
if len(patterns) > 1:
logger.warning(
- f"Multiple patterns matched: {[p.pattern_id for p in patterns]}. Only the first pattern will be returned."
+ f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
)
# If the pattern found is set to be aggregated, return it
- action = self._patterns[patterns[0].pattern_id].get("action", MatchAction.REMOVE)
+ action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
if action == MatchAction.AGGREGATE:
self._text = ""
return patterns[0]
@@ -300,7 +298,7 @@ class PatternPairAggregator(BaseTextAggregator):
# Otherwise, strip the text up to the start pattern and return it
result = self._text[: pattern_start[0]]
self._text = self._text[pattern_start[0] :]
- return PatternMatch(f"_sentence", result, result, "sentence")
+ return PatternMatch(content=result, type="sentence", full_match=result)
# Find sentence boundary if no incomplete patterns
eos_marker = match_endofsentence(self._text)
@@ -308,7 +306,7 @@ class PatternPairAggregator(BaseTextAggregator):
# Extract text up to the sentence boundary
result = self._text[:eos_marker]
self._text = self._text[eos_marker:]
- return PatternMatch(f"_sentence", result, result, "sentence")
+ return PatternMatch(content=result, type="sentence", full_match=result)
# No complete sentence found yet
return None
diff --git a/tests/test_pattern_pair_aggregator.py b/tests/test_pattern_pair_aggregator.py
index 5887a8197..ffbff4fbb 100644
--- a/tests/test_pattern_pair_aggregator.py
+++ b/tests/test_pattern_pair_aggregator.py
@@ -22,17 +22,15 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Add a test pattern
self.aggregator.add_pattern_pair(
- pattern_id="test_pattern",
+ type="test_pattern",
start_pattern="",
end_pattern="",
- type="test",
action=MatchAction.REMOVE,
)
self.aggregator.add_pattern_pair(
- pattern_id="code_pattern",
+ type="code_pattern",
start_pattern="",
end_pattern="",
- type="code",
action=MatchAction.AGGREGATE,
)
@@ -45,7 +43,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
result = await self.aggregator.aggregate("Hello pattern")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text.text, "Hello pattern")
- self.assertEqual(self.aggregator.text.type, "test")
+ self.assertEqual(self.aggregator.text.type, "test_pattern")
# Second part completes the pattern and includes an exclamation point
result = await self.aggregator.aggregate(" content!")
@@ -54,7 +52,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.test_handler.assert_called_once()
call_args = self.test_handler.call_args[0][0]
self.assertIsInstance(call_args, PatternMatch)
- self.assertEqual(call_args.pattern_id, "test_pattern")
+ self.assertEqual(call_args.type, "test_pattern")
self.assertEqual(call_args.full_match, "pattern content")
self.assertEqual(call_args.text, "pattern content")
@@ -75,7 +73,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
result = await self.aggregator.aggregate("Here is code pattern")
self.assertEqual(result.text, "Here is code ")
self.assertEqual(self.aggregator.text.text, "pattern")
- self.assertEqual(self.aggregator.text.type, "code")
+ self.assertEqual(self.aggregator.text.type, "code_pattern")
# Second part completes the pattern and includes an exclamation point
result = await self.aggregator.aggregate(" content")
@@ -84,11 +82,11 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.code_handler.assert_called_once()
call_args = self.code_handler.call_args[0][0]
self.assertIsInstance(call_args, PatternMatch)
- self.assertEqual(call_args.pattern_id, "code_pattern")
+ self.assertEqual(call_args.type, "code_pattern")
self.assertEqual(call_args.full_match, "pattern content")
self.assertEqual(call_args.text, "pattern content")
self.assertEqual(result.text, "pattern content")
- self.assertEqual(result.type, "code")
+ self.assertEqual(result.type, "code_pattern")
# Next sentence should be processed separately
result = await self.aggregator.aggregate(" This is another sentence.")
@@ -110,7 +108,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Buffer should contain the incomplete text
self.assertEqual(self.aggregator.text.text, "Hello pattern content")
- self.assertEqual(self.aggregator.text.type, "test")
+ self.assertEqual(self.aggregator.text.type, "test_pattern")
# Reset and confirm buffer is cleared
await self.aggregator.reset()
@@ -122,18 +120,16 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
emphasis_handler = AsyncMock()
self.aggregator.add_pattern_pair(
- pattern_id="voice",
+ type="voice",
start_pattern="",
end_pattern="",
- type="voice",
action=MatchAction.REMOVE,
)
self.aggregator.add_pattern_pair(
- pattern_id="emphasis",
+ type="emphasis",
start_pattern="",
end_pattern="",
- type="emphasis",
action=MatchAction.KEEP, # Keep emphasis tags
)
@@ -147,12 +143,12 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Both handlers should be called with correct data
voice_handler.assert_called_once()
voice_match = voice_handler.call_args[0][0]
- self.assertEqual(voice_match.pattern_id, "voice")
+ self.assertEqual(voice_match.type, "voice")
self.assertEqual(voice_match.text, "female")
emphasis_handler.assert_called_once()
emphasis_match = emphasis_handler.call_args[0][0]
- self.assertEqual(emphasis_match.pattern_id, "emphasis")
+ self.assertEqual(emphasis_match.type, "emphasis")
self.assertEqual(emphasis_match.text, "very")
# Voice pattern should be removed, emphasis pattern should remain