Make the PatternPair action an Enum

This commit is contained in:
mattie ruth backman
2025-10-28 16:43:06 -04:00
parent e6dc1a510d
commit ccca6e8d81
4 changed files with 59 additions and 25 deletions

View File

@@ -62,7 +62,11 @@ from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
from pipecat.utils.text.pattern_pair_aggregator import (
MatchAction,
PatternMatch,
PatternPairAggregator,
)
load_dotenv(override=True)
@@ -111,7 +115,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
start_pattern="<voice>",
end_pattern="</voice>",
type="voice",
action="remove", # Remove tags from final text
action=MatchAction.REMOVE, # Remove tags from final text
)
# Register handler for voice switching

View File

@@ -31,7 +31,11 @@ from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.llm_service import LLMService
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
from pipecat.utils.text.pattern_pair_aggregator import (
MatchAction,
PatternMatch,
PatternPairAggregator,
)
class IVRStatus(Enum):
@@ -114,15 +118,21 @@ class IVRProcessor(FrameProcessor):
def _setup_xml_patterns(self):
"""Set up XML pattern detection and handlers."""
# Register DTMF pattern
self._aggregator.add_pattern_pair("dtmf", "<dtmf>", "</dtmf>", type="dtmf", action="remove")
self._aggregator.add_pattern_pair(
"dtmf", "<dtmf>", "</dtmf>", type="dtmf", action=MatchAction.REMOVE
)
self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)
# Register mode pattern
self._aggregator.add_pattern_pair("mode", "<mode>", "</mode>", type="mode", action="remove")
self._aggregator.add_pattern_pair(
"mode", "<mode>", "</mode>", type="mode", action=MatchAction.REMOVE
)
self._aggregator.on_pattern_match("mode", self._handle_mode_action)
# Register IVR pattern
self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", type="ivr", action="remove")
self._aggregator.add_pattern_pair(
"ivr", "<ivr>", "</ivr>", type="ivr", action=MatchAction.REMOVE
)
self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)
async def process_frame(self, frame: Frame, direction: FrameDirection):

View File

@@ -12,7 +12,8 @@ support for custom handlers and configurable pattern removal.
"""
import re
from typing import Awaitable, Callable, List, Literal, Optional, Tuple
from enum import Enum
from typing import Awaitable, Callable, List, Optional, Tuple
from loguru import logger
@@ -20,6 +21,20 @@ from pipecat.utils.string import match_endofsentence
from pipecat.utils.text.base_text_aggregator import Aggregation, BaseTextAggregator
class MatchAction(Enum):
"""Actions to take when a pattern pair is matched.
Parameters:
REMOVE: Remove the matched pattern from the text.
KEEP: Keep the matched pattern in the text as normal text.
AGGREGATE: Return the matched pattern as a separate aggregation object.
"""
REMOVE = "remove"
KEEP = "keep"
AGGREGATE = "aggregate"
class PatternMatch(Aggregation):
"""Represents a matched pattern pair with its content.
@@ -94,7 +109,7 @@ class PatternPairAggregator(BaseTextAggregator):
start_pattern: str,
end_pattern: str,
type: str,
action: Literal["remove", "keep", "aggregate"] = "remove",
action: MatchAction = MatchAction.REMOVE,
) -> "PatternPairAggregator":
"""Add a pattern pair to detect in the text.
@@ -109,12 +124,12 @@ class PatternPairAggregator(BaseTextAggregator):
type: The type of aggregation the matched content represents
(e.g., 'code', 'speaker', 'custom').
action: What to do when a complete pattern is matched:
- "remove": Remove the matched pattern from the text.
- "keep": Keep the matched pattern in the text and treat it as
normal text. This allows you to register handlers for
the pattern without affecting the aggregation logic.
- "aggregate": Return the matched pattern as a separate
aggregation object.
- MatchAction.REMOVE: Remove the matched pattern from the text.
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as
normal text. This allows you to register handlers for
the pattern without affecting the aggregation logic.
- MatchAction.AGGREGATE: Return the matched pattern as a separate
aggregation object.
Returns:
Self for method chaining.
@@ -196,7 +211,7 @@ class PatternPairAggregator(BaseTextAggregator):
logger.error(f"Error in pattern handler for {pattern_id}: {e}")
# Remove the pattern from the text if configured
if action == "remove":
if action == MatchAction.REMOVE:
processed_text = processed_text.replace(full_match, "", 1)
# modified = True
else:
@@ -260,15 +275,13 @@ class PatternPairAggregator(BaseTextAggregator):
#
if len(patterns) > 0:
print(f"Found patterns: {[str(p) for p in patterns]}")
if len(patterns) > 1:
logger.warning(
f"Multiple patterns matched: {[p.pattern_id for p in patterns]}. Only the first pattern will be returned."
)
# If the pattern found is set to be aggregated, return it
action = self._patterns[patterns[0].pattern_id].get("action", "remove")
print(f"Pattern action: {action}")
if action == "aggregate":
action = self._patterns[patterns[0].pattern_id].get("action", MatchAction.REMOVE)
if action == MatchAction.AGGREGATE:
self._text = ""
print(f"Returning pattern: {patterns[0]}")
return patterns[0]
@@ -277,7 +290,10 @@ class PatternPairAggregator(BaseTextAggregator):
pattern_start = self._match_start_of_pattern(self._text)
if pattern_start is not None:
# If the start pattern is at the beginning or should not be separately aggregated, return None
if pattern_start[0] == 0 or pattern_start[1].get("action", "remove") != "aggregate":
if (
pattern_start[0] == 0
or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
):
return None
# Otherwise, strip the text up to the start pattern and return it
result = self._text[: pattern_start[0]]

View File

@@ -7,7 +7,11 @@
import unittest
from unittest.mock import AsyncMock
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
from pipecat.utils.text.pattern_pair_aggregator import (
MatchAction,
PatternMatch,
PatternPairAggregator,
)
class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
@@ -22,14 +26,14 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
start_pattern="<test>",
end_pattern="</test>",
type="test",
action="remove",
action=MatchAction.REMOVE,
)
self.aggregator.add_pattern_pair(
pattern_id="code_pattern",
start_pattern="<code>",
end_pattern="</code>",
type="code",
action="aggregate",
action=MatchAction.AGGREGATE,
)
# Register the mock handler
@@ -122,7 +126,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
start_pattern="<voice>",
end_pattern="</voice>",
type="voice",
action="remove",
action=MatchAction.REMOVE,
)
self.aggregator.add_pattern_pair(
@@ -130,7 +134,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
start_pattern="<em>",
end_pattern="</em>",
type="emphasis",
action="keep", # Keep emphasis tags
action=MatchAction.KEEP, # Keep emphasis tags
)
self.aggregator.on_pattern_match("voice", voice_handler)