diff --git a/examples/foundational/35-pattern-pair-voice-switching.py b/examples/foundational/35-pattern-pair-voice-switching.py index 550520317..871f74f8d 100644 --- a/examples/foundational/35-pattern-pair-voice-switching.py +++ b/examples/foundational/35-pattern-pair-voice-switching.py @@ -62,7 +62,11 @@ from pipecat.services.openai.llm import OpenAILLMService from pipecat.transports.base_transport import BaseTransport, TransportParams from pipecat.transports.daily.transport import DailyParams from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams -from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator +from pipecat.utils.text.pattern_pair_aggregator import ( + MatchAction, + PatternMatch, + PatternPairAggregator, +) load_dotenv(override=True) @@ -111,7 +115,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): start_pattern="", end_pattern="", type="voice", - action="remove", # Remove tags from final text + action=MatchAction.REMOVE, # Remove tags from final text ) # Register handler for voice switching diff --git a/src/pipecat/extensions/ivr/ivr_navigator.py b/src/pipecat/extensions/ivr/ivr_navigator.py index fac760ab3..4b46c1ffb 100644 --- a/src/pipecat/extensions/ivr/ivr_navigator.py +++ b/src/pipecat/extensions/ivr/ivr_navigator.py @@ -31,7 +31,11 @@ from pipecat.pipeline.pipeline import Pipeline from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.services.llm_service import LLMService -from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator +from pipecat.utils.text.pattern_pair_aggregator import ( + MatchAction, + PatternMatch, + PatternPairAggregator, +) class IVRStatus(Enum): @@ -114,15 +118,21 @@ class IVRProcessor(FrameProcessor): def _setup_xml_patterns(self): """Set up XML pattern detection and handlers.""" # Register DTMF pattern - self._aggregator.add_pattern_pair("dtmf", "", "", type="dtmf", action="remove") + self._aggregator.add_pattern_pair( + "dtmf", "", "", type="dtmf", action=MatchAction.REMOVE + ) self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action) # Register mode pattern - self._aggregator.add_pattern_pair("mode", "", "", type="mode", action="remove") + self._aggregator.add_pattern_pair( + "mode", "", "", type="mode", action=MatchAction.REMOVE + ) self._aggregator.on_pattern_match("mode", self._handle_mode_action) # Register IVR pattern - self._aggregator.add_pattern_pair("ivr", "", "", type="ivr", action="remove") + self._aggregator.add_pattern_pair( + "ivr", "", "", type="ivr", action=MatchAction.REMOVE + ) self._aggregator.on_pattern_match("ivr", self._handle_ivr_action) async def process_frame(self, frame: Frame, direction: FrameDirection): diff --git a/src/pipecat/utils/text/pattern_pair_aggregator.py b/src/pipecat/utils/text/pattern_pair_aggregator.py index d4712a823..f5e33f5b8 100644 --- a/src/pipecat/utils/text/pattern_pair_aggregator.py +++ b/src/pipecat/utils/text/pattern_pair_aggregator.py @@ -12,7 +12,8 @@ support for custom handlers and configurable pattern removal. """ import re -from typing import Awaitable, Callable, List, Literal, Optional, Tuple +from enum import Enum +from typing import Awaitable, Callable, List, Optional, Tuple from loguru import logger @@ -20,6 +21,20 @@ from pipecat.utils.string import match_endofsentence from pipecat.utils.text.base_text_aggregator import Aggregation, BaseTextAggregator +class MatchAction(Enum): + """Actions to take when a pattern pair is matched. + + Parameters: + REMOVE: Remove the matched pattern from the text. + KEEP: Keep the matched pattern in the text as normal text. + AGGREGATE: Return the matched pattern as a separate aggregation object. + """ + + REMOVE = "remove" + KEEP = "keep" + AGGREGATE = "aggregate" + + class PatternMatch(Aggregation): """Represents a matched pattern pair with its content. @@ -94,7 +109,7 @@ class PatternPairAggregator(BaseTextAggregator): start_pattern: str, end_pattern: str, type: str, - action: Literal["remove", "keep", "aggregate"] = "remove", + action: MatchAction = MatchAction.REMOVE, ) -> "PatternPairAggregator": """Add a pattern pair to detect in the text. @@ -109,12 +124,12 @@ class PatternPairAggregator(BaseTextAggregator): type: The type of aggregation the matched content represents (e.g., 'code', 'speaker', 'custom'). action: What to do when a complete pattern is matched: - - "remove": Remove the matched pattern from the text. - - "keep": Keep the matched pattern in the text and treat it as - normal text. This allows you to register handlers for - the pattern without affecting the aggregation logic. - - "aggregate": Return the matched pattern as a separate - aggregation object. + - MatchAction.REMOVE: Remove the matched pattern from the text. + - MatchAction.KEEP: Keep the matched pattern in the text and treat it as + normal text. This allows you to register handlers for + the pattern without affecting the aggregation logic. + - MatchAction.AGGREGATE: Return the matched pattern as a separate + aggregation object. Returns: Self for method chaining. @@ -196,7 +211,7 @@ class PatternPairAggregator(BaseTextAggregator): logger.error(f"Error in pattern handler for {pattern_id}: {e}") # Remove the pattern from the text if configured - if action == "remove": + if action == MatchAction.REMOVE: processed_text = processed_text.replace(full_match, "", 1) # modified = True else: @@ -260,15 +275,13 @@ class PatternPairAggregator(BaseTextAggregator): # if len(patterns) > 0: - print(f"Found patterns: {[str(p) for p in patterns]}") if len(patterns) > 1: logger.warning( f"Multiple patterns matched: {[p.pattern_id for p in patterns]}. Only the first pattern will be returned." ) # If the pattern found is set to be aggregated, return it - action = self._patterns[patterns[0].pattern_id].get("action", "remove") - print(f"Pattern action: {action}") - if action == "aggregate": + action = self._patterns[patterns[0].pattern_id].get("action", MatchAction.REMOVE) + if action == MatchAction.AGGREGATE: self._text = "" print(f"Returning pattern: {patterns[0]}") return patterns[0] @@ -277,7 +290,10 @@ class PatternPairAggregator(BaseTextAggregator): pattern_start = self._match_start_of_pattern(self._text) if pattern_start is not None: # If the start pattern is at the beginning or should not be separately aggregated, return None - if pattern_start[0] == 0 or pattern_start[1].get("action", "remove") != "aggregate": + if ( + pattern_start[0] == 0 + or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE + ): return None # Otherwise, strip the text up to the start pattern and return it result = self._text[: pattern_start[0]] diff --git a/tests/test_pattern_pair_aggregator.py b/tests/test_pattern_pair_aggregator.py index 310ccf635..804128eb4 100644 --- a/tests/test_pattern_pair_aggregator.py +++ b/tests/test_pattern_pair_aggregator.py @@ -7,7 +7,11 @@ import unittest from unittest.mock import AsyncMock -from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator +from pipecat.utils.text.pattern_pair_aggregator import ( + MatchAction, + PatternMatch, + PatternPairAggregator, +) class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): @@ -22,14 +26,14 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): start_pattern="", end_pattern="", type="test", - action="remove", + action=MatchAction.REMOVE, ) self.aggregator.add_pattern_pair( pattern_id="code_pattern", start_pattern="", end_pattern="", type="code", - action="aggregate", + action=MatchAction.AGGREGATE, ) # Register the mock handler @@ -122,7 +126,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): start_pattern="", end_pattern="", type="voice", - action="remove", + action=MatchAction.REMOVE, ) self.aggregator.add_pattern_pair( @@ -130,7 +134,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): start_pattern="", end_pattern="", type="emphasis", - action="keep", # Keep emphasis tags + action=MatchAction.KEEP, # Keep emphasis tags ) self.aggregator.on_pattern_match("voice", voice_handler)