Make the PatternPair action an Enum
This commit is contained in:
@@ -62,7 +62,11 @@ from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
|
||||
from pipecat.utils.text.pattern_pair_aggregator import (
|
||||
MatchAction,
|
||||
PatternMatch,
|
||||
PatternPairAggregator,
|
||||
)
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -111,7 +115,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
start_pattern="<voice>",
|
||||
end_pattern="</voice>",
|
||||
type="voice",
|
||||
action="remove", # Remove tags from final text
|
||||
action=MatchAction.REMOVE, # Remove tags from final text
|
||||
)
|
||||
|
||||
# Register handler for voice switching
|
||||
|
||||
@@ -31,7 +31,11 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
|
||||
from pipecat.utils.text.pattern_pair_aggregator import (
|
||||
MatchAction,
|
||||
PatternMatch,
|
||||
PatternPairAggregator,
|
||||
)
|
||||
|
||||
|
||||
class IVRStatus(Enum):
|
||||
@@ -114,15 +118,21 @@ class IVRProcessor(FrameProcessor):
|
||||
def _setup_xml_patterns(self):
|
||||
"""Set up XML pattern detection and handlers."""
|
||||
# Register DTMF pattern
|
||||
self._aggregator.add_pattern_pair("dtmf", "<dtmf>", "</dtmf>", type="dtmf", action="remove")
|
||||
self._aggregator.add_pattern_pair(
|
||||
"dtmf", "<dtmf>", "</dtmf>", type="dtmf", action=MatchAction.REMOVE
|
||||
)
|
||||
self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)
|
||||
|
||||
# Register mode pattern
|
||||
self._aggregator.add_pattern_pair("mode", "<mode>", "</mode>", type="mode", action="remove")
|
||||
self._aggregator.add_pattern_pair(
|
||||
"mode", "<mode>", "</mode>", type="mode", action=MatchAction.REMOVE
|
||||
)
|
||||
self._aggregator.on_pattern_match("mode", self._handle_mode_action)
|
||||
|
||||
# Register IVR pattern
|
||||
self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", type="ivr", action="remove")
|
||||
self._aggregator.add_pattern_pair(
|
||||
"ivr", "<ivr>", "</ivr>", type="ivr", action=MatchAction.REMOVE
|
||||
)
|
||||
self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
|
||||
@@ -12,7 +12,8 @@ support for custom handlers and configurable pattern removal.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Awaitable, Callable, List, Literal, Optional, Tuple
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -20,6 +21,20 @@ from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, BaseTextAggregator
|
||||
|
||||
|
||||
class MatchAction(str, Enum):
    """Actions to take when a pattern pair is matched.

    Members also subclass ``str``, so each member compares equal to its plain
    string value (for example ``MatchAction.REMOVE == "remove"``). This keeps
    callers that still pass the earlier string literals working with
    equality checks written against the enum members.

    Parameters:
        REMOVE: Remove the matched pattern from the text.
        KEEP: Keep the matched pattern in the text as normal text.
        AGGREGATE: Return the matched pattern as a separate aggregation object.
    """

    REMOVE = "remove"
    KEEP = "keep"
    AGGREGATE = "aggregate"
|
||||
|
||||
|
||||
class PatternMatch(Aggregation):
|
||||
"""Represents a matched pattern pair with its content.
|
||||
|
||||
@@ -94,7 +109,7 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
start_pattern: str,
|
||||
end_pattern: str,
|
||||
type: str,
|
||||
action: Literal["remove", "keep", "aggregate"] = "remove",
|
||||
action: MatchAction = MatchAction.REMOVE,
|
||||
) -> "PatternPairAggregator":
|
||||
"""Add a pattern pair to detect in the text.
|
||||
|
||||
@@ -109,12 +124,12 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
type: The type of aggregation the matched content represents
|
||||
(e.g., 'code', 'speaker', 'custom').
|
||||
action: What to do when a complete pattern is matched:
|
||||
- "remove": Remove the matched pattern from the text.
|
||||
- "keep": Keep the matched pattern in the text and treat it as
|
||||
normal text. This allows you to register handlers for
|
||||
the pattern without affecting the aggregation logic.
|
||||
- "aggregate": Return the matched pattern as a separate
|
||||
aggregation object.
|
||||
- MatchAction.REMOVE: Remove the matched pattern from the text.
|
||||
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as
|
||||
normal text. This allows you to register handlers for
|
||||
the pattern without affecting the aggregation logic.
|
||||
- MatchAction.AGGREGATE: Return the matched pattern as a separate
|
||||
aggregation object.
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
@@ -196,7 +211,7 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
logger.error(f"Error in pattern handler for {pattern_id}: {e}")
|
||||
|
||||
# Remove the pattern from the text if configured
|
||||
if action == "remove":
|
||||
if action == MatchAction.REMOVE:
|
||||
processed_text = processed_text.replace(full_match, "", 1)
|
||||
# modified = True
|
||||
else:
|
||||
@@ -260,15 +275,13 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
|
||||
#
|
||||
if len(patterns) > 0:
|
||||
print(f"Found patterns: {[str(p) for p in patterns]}")
|
||||
if len(patterns) > 1:
|
||||
logger.warning(
|
||||
f"Multiple patterns matched: {[p.pattern_id for p in patterns]}. Only the first pattern will be returned."
|
||||
)
|
||||
# If the pattern found is set to be aggregated, return it
|
||||
action = self._patterns[patterns[0].pattern_id].get("action", "remove")
|
||||
print(f"Pattern action: {action}")
|
||||
if action == "aggregate":
|
||||
action = self._patterns[patterns[0].pattern_id].get("action", MatchAction.REMOVE)
|
||||
if action == MatchAction.AGGREGATE:
|
||||
self._text = ""
|
||||
print(f"Returning pattern: {patterns[0]}")
|
||||
return patterns[0]
|
||||
@@ -277,7 +290,10 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
pattern_start = self._match_start_of_pattern(self._text)
|
||||
if pattern_start is not None:
|
||||
# If the start pattern is at the beginning or should not be separately aggregated, return None
|
||||
if pattern_start[0] == 0 or pattern_start[1].get("action", "remove") != "aggregate":
|
||||
if (
|
||||
pattern_start[0] == 0
|
||||
or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
|
||||
):
|
||||
return None
|
||||
# Otherwise, strip the text up to the start pattern and return it
|
||||
result = self._text[: pattern_start[0]]
|
||||
|
||||
@@ -7,7 +7,11 @@
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
|
||||
from pipecat.utils.text.pattern_pair_aggregator import (
|
||||
MatchAction,
|
||||
PatternMatch,
|
||||
PatternPairAggregator,
|
||||
)
|
||||
|
||||
|
||||
class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
@@ -22,14 +26,14 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
start_pattern="<test>",
|
||||
end_pattern="</test>",
|
||||
type="test",
|
||||
action="remove",
|
||||
action=MatchAction.REMOVE,
|
||||
)
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="code_pattern",
|
||||
start_pattern="<code>",
|
||||
end_pattern="</code>",
|
||||
type="code",
|
||||
action="aggregate",
|
||||
action=MatchAction.AGGREGATE,
|
||||
)
|
||||
|
||||
# Register the mock handler
|
||||
@@ -122,7 +126,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
start_pattern="<voice>",
|
||||
end_pattern="</voice>",
|
||||
type="voice",
|
||||
action="remove",
|
||||
action=MatchAction.REMOVE,
|
||||
)
|
||||
|
||||
self.aggregator.add_pattern_pair(
|
||||
@@ -130,7 +134,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
start_pattern="<em>",
|
||||
end_pattern="</em>",
|
||||
type="emphasis",
|
||||
action="keep", # Keep emphasis tags
|
||||
action=MatchAction.KEEP, # Keep emphasis tags
|
||||
)
|
||||
|
||||
self.aggregator.on_pattern_match("voice", voice_handler)
|
||||
|
||||
Reference in New Issue
Block a user