diff --git a/examples/foundational/35-pattern-pair-voice-switching.py b/examples/foundational/35-pattern-pair-voice-switching.py
index 550520317..871f74f8d 100644
--- a/examples/foundational/35-pattern-pair-voice-switching.py
+++ b/examples/foundational/35-pattern-pair-voice-switching.py
@@ -62,7 +62,11 @@ from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
-from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
+from pipecat.utils.text.pattern_pair_aggregator import (
+ MatchAction,
+ PatternMatch,
+ PatternPairAggregator,
+)
load_dotenv(override=True)
@@ -111,7 +115,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
start_pattern="",
end_pattern="",
type="voice",
- action="remove", # Remove tags from final text
+ action=MatchAction.REMOVE, # Remove tags from final text
)
# Register handler for voice switching
diff --git a/src/pipecat/extensions/ivr/ivr_navigator.py b/src/pipecat/extensions/ivr/ivr_navigator.py
index fac760ab3..4b46c1ffb 100644
--- a/src/pipecat/extensions/ivr/ivr_navigator.py
+++ b/src/pipecat/extensions/ivr/ivr_navigator.py
@@ -31,7 +31,11 @@ from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.llm_service import LLMService
-from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
+from pipecat.utils.text.pattern_pair_aggregator import (
+ MatchAction,
+ PatternMatch,
+ PatternPairAggregator,
+)
class IVRStatus(Enum):
@@ -114,15 +118,21 @@ class IVRProcessor(FrameProcessor):
def _setup_xml_patterns(self):
"""Set up XML pattern detection and handlers."""
# Register DTMF pattern
- self._aggregator.add_pattern_pair("dtmf", "", "", type="dtmf", action="remove")
+ self._aggregator.add_pattern_pair(
+ "dtmf", "", "", type="dtmf", action=MatchAction.REMOVE
+ )
self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)
# Register mode pattern
- self._aggregator.add_pattern_pair("mode", "", "", type="mode", action="remove")
+ self._aggregator.add_pattern_pair(
+ "mode", "", "", type="mode", action=MatchAction.REMOVE
+ )
self._aggregator.on_pattern_match("mode", self._handle_mode_action)
# Register IVR pattern
- self._aggregator.add_pattern_pair("ivr", "", "", type="ivr", action="remove")
+ self._aggregator.add_pattern_pair(
+ "ivr", "", "", type="ivr", action=MatchAction.REMOVE
+ )
self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)
async def process_frame(self, frame: Frame, direction: FrameDirection):
diff --git a/src/pipecat/utils/text/pattern_pair_aggregator.py b/src/pipecat/utils/text/pattern_pair_aggregator.py
index d4712a823..f5e33f5b8 100644
--- a/src/pipecat/utils/text/pattern_pair_aggregator.py
+++ b/src/pipecat/utils/text/pattern_pair_aggregator.py
@@ -12,7 +12,8 @@ support for custom handlers and configurable pattern removal.
"""
import re
-from typing import Awaitable, Callable, List, Literal, Optional, Tuple
+from enum import Enum
+from typing import Awaitable, Callable, List, Optional, Tuple
from loguru import logger
@@ -20,6 +21,20 @@ from pipecat.utils.string import match_endofsentence
from pipecat.utils.text.base_text_aggregator import Aggregation, BaseTextAggregator
+class MatchAction(Enum):
+ """Actions to take when a pattern pair is matched.
+
+ Parameters:
+ REMOVE: Remove the matched pattern from the text.
+ KEEP: Keep the matched pattern in the text as normal text.
+ AGGREGATE: Return the matched pattern as a separate aggregation object.
+ """
+
+ REMOVE = "remove"
+ KEEP = "keep"
+ AGGREGATE = "aggregate"
+
+
class PatternMatch(Aggregation):
"""Represents a matched pattern pair with its content.
@@ -94,7 +109,7 @@ class PatternPairAggregator(BaseTextAggregator):
start_pattern: str,
end_pattern: str,
type: str,
- action: Literal["remove", "keep", "aggregate"] = "remove",
+ action: MatchAction = MatchAction.REMOVE,
) -> "PatternPairAggregator":
"""Add a pattern pair to detect in the text.
@@ -109,12 +124,12 @@ class PatternPairAggregator(BaseTextAggregator):
type: The type of aggregation the matched content represents
(e.g., 'code', 'speaker', 'custom').
action: What to do when a complete pattern is matched:
- - "remove": Remove the matched pattern from the text.
- - "keep": Keep the matched pattern in the text and treat it as
- normal text. This allows you to register handlers for
- the pattern without affecting the aggregation logic.
- - "aggregate": Return the matched pattern as a separate
- aggregation object.
+ - MatchAction.REMOVE: Remove the matched pattern from the text.
+ - MatchAction.KEEP: Keep the matched pattern in the text and treat it as
+ normal text. This allows you to register handlers for
+ the pattern without affecting the aggregation logic.
+ - MatchAction.AGGREGATE: Return the matched pattern as a separate
+ aggregation object.
Returns:
Self for method chaining.
@@ -196,7 +211,7 @@ class PatternPairAggregator(BaseTextAggregator):
logger.error(f"Error in pattern handler for {pattern_id}: {e}")
# Remove the pattern from the text if configured
- if action == "remove":
+ if action == MatchAction.REMOVE:
processed_text = processed_text.replace(full_match, "", 1)
# modified = True
else:
@@ -260,15 +275,13 @@ class PatternPairAggregator(BaseTextAggregator):
#
if len(patterns) > 0:
- print(f"Found patterns: {[str(p) for p in patterns]}")
if len(patterns) > 1:
logger.warning(
f"Multiple patterns matched: {[p.pattern_id for p in patterns]}. Only the first pattern will be returned."
)
# If the pattern found is set to be aggregated, return it
- action = self._patterns[patterns[0].pattern_id].get("action", "remove")
- print(f"Pattern action: {action}")
- if action == "aggregate":
+ action = self._patterns[patterns[0].pattern_id].get("action", MatchAction.REMOVE)
+ if action == MatchAction.AGGREGATE:
self._text = ""
print(f"Returning pattern: {patterns[0]}")
return patterns[0]
@@ -277,7 +290,10 @@ class PatternPairAggregator(BaseTextAggregator):
pattern_start = self._match_start_of_pattern(self._text)
if pattern_start is not None:
# If the start pattern is at the beginning or should not be separately aggregated, return None
- if pattern_start[0] == 0 or pattern_start[1].get("action", "remove") != "aggregate":
+ if (
+ pattern_start[0] == 0
+ or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
+ ):
return None
# Otherwise, strip the text up to the start pattern and return it
result = self._text[: pattern_start[0]]
diff --git a/tests/test_pattern_pair_aggregator.py b/tests/test_pattern_pair_aggregator.py
index 310ccf635..804128eb4 100644
--- a/tests/test_pattern_pair_aggregator.py
+++ b/tests/test_pattern_pair_aggregator.py
@@ -7,7 +7,11 @@
import unittest
from unittest.mock import AsyncMock
-from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
+from pipecat.utils.text.pattern_pair_aggregator import (
+ MatchAction,
+ PatternMatch,
+ PatternPairAggregator,
+)
class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
@@ -22,14 +26,14 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
start_pattern="",
end_pattern="",
type="test",
- action="remove",
+ action=MatchAction.REMOVE,
)
self.aggregator.add_pattern_pair(
pattern_id="code_pattern",
start_pattern="",
end_pattern="",
type="code",
- action="aggregate",
+ action=MatchAction.AGGREGATE,
)
# Register the mock handler
@@ -122,7 +126,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
start_pattern="",
end_pattern="",
type="voice",
- action="remove",
+ action=MatchAction.REMOVE,
)
self.aggregator.add_pattern_pair(
@@ -130,7 +134,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
start_pattern="",
end_pattern="",
type="emphasis",
- action="keep", # Keep emphasis tags
+ action=MatchAction.KEEP, # Keep emphasis tags
)
self.aggregator.on_pattern_match("voice", voice_handler)