Update PatternPairAggregator patterns to replace pattern_id with type to simplify the API

This commit is contained in:
mattie ruth backman
2025-11-04 16:22:34 -05:00
parent 82b9c4f0b6
commit e9de9daf8c
4 changed files with 46 additions and 59 deletions

View File

@@ -111,10 +111,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
# Add pattern for voice switching
pattern_aggregator.add_pattern_pair(
pattern_id="voice_tag",
type="voice",
start_pattern="<voice>",
end_pattern="</voice>",
type="voice",
action=MatchAction.REMOVE, # Remove tags from final text
)
@@ -130,7 +129,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
else:
logger.warning(f"Unknown voice: {voice_name}")
pattern_aggregator.on_pattern_match("voice_tag", on_voice_tag)
pattern_aggregator.on_pattern_match("voice", on_voice_tag)
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

View File

@@ -118,21 +118,15 @@ class IVRProcessor(FrameProcessor):
def _setup_xml_patterns(self):
"""Set up XML pattern detection and handlers."""
# Register DTMF pattern
self._aggregator.add_pattern_pair(
"dtmf", "<dtmf>", "</dtmf>", type="dtmf", action=MatchAction.REMOVE
)
self._aggregator.add_pattern_pair("dtmf", "<dtmf>", "</dtmf>", action=MatchAction.REMOVE)
self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)
# Register mode pattern
self._aggregator.add_pattern_pair(
"mode", "<mode>", "</mode>", type="mode", action=MatchAction.REMOVE
)
self._aggregator.add_pattern_pair("mode", "<mode>", "</mode>", action=MatchAction.REMOVE)
self._aggregator.on_pattern_match("mode", self._handle_mode_action)
# Register IVR pattern
self._aggregator.add_pattern_pair(
"ivr", "<ivr>", "</ivr>", type="ivr", action=MatchAction.REMOVE
)
self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", action=MatchAction.REMOVE)
self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)
async def process_frame(self, frame: Frame, direction: FrameDirection):

View File

@@ -44,27 +44,25 @@ class PatternMatch(Aggregation):
content between the patterns.
"""
def __init__(self, pattern_id: str, full_match: str, content: str, type: str):
def __init__(self, content: str, type: str, full_match: str):
"""Initialize a pattern match.
Args:
pattern_id: The identifier of the matched pattern pair.
type: The type of the matched pattern pair. It should be representative
of the content type (e.g., 'sentence', 'code', 'speaker', 'custom').
full_match: The complete text including start and end patterns.
content: The text content between the start and end patterns.
type: The type of aggregation the matched content represents
(e.g., 'code', 'speaker', 'custom').
"""
super().__init__(text=content, type=type)
self.pattern_id = pattern_id
self.full_match = full_match
def __str__(self) -> str:
"""Return a string representation of the pattern match.
Returns:
A descriptive string showing the pattern ID and content.
A descriptive string showing the pattern type and content.
"""
return f"PatternMatch(id={self.pattern_id}, content={self.text}, full_match={self.full_match}, type={self.type})"
return f"PatternMatch(type={self.type}, content={self.text}, full_match={self.full_match})"
class PatternPairAggregator(BaseTextAggregator):
@@ -110,10 +108,9 @@ class PatternPairAggregator(BaseTextAggregator):
def add_pattern_pair(
self,
pattern_id: str,
type: str,
start_pattern: str,
end_pattern: str,
type: str,
action: MatchAction = MatchAction.REMOVE,
) -> "PatternPairAggregator":
"""Add a pattern pair to detect in the text.
@@ -123,11 +120,11 @@ class PatternPairAggregator(BaseTextAggregator):
the end pattern, and treat the content between them as a match.
Args:
pattern_id: Unique identifier for this pattern pair.
type: Identifier for this pattern pair. Should be unique and ideally descriptive.
(e.g., 'code', 'speaker', 'custom'). type can not be 'sentence' as that is
reserved for the default behavior.
start_pattern: Pattern that marks the beginning of content.
end_pattern: Pattern that marks the end of content.
type: The type of aggregation the matched content represents
(e.g., 'code', 'speaker', 'custom').
action: What to do when a complete pattern is matched:
- MatchAction.REMOVE: Remove the matched pattern from the text.
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as
@@ -139,7 +136,11 @@ class PatternPairAggregator(BaseTextAggregator):
Returns:
Self for method chaining.
"""
self._patterns[pattern_id] = {
if type == "sentence":
raise ValueError(
"The aggregation type 'sentence' is reserved for default behavior and can not be used for custom patterns."
)
self._patterns[type] = {
"start": start_pattern,
"end": end_pattern,
"type": type,
@@ -148,22 +149,22 @@ class PatternPairAggregator(BaseTextAggregator):
return self
def on_pattern_match(
self, pattern_id: str, handler: Callable[[PatternMatch], Awaitable[None]]
self, type: str, handler: Callable[[PatternMatch], Awaitable[None]]
) -> "PatternPairAggregator":
"""Register a handler for when a pattern pair is matched.
The handler will be called whenever a complete match for the
specified pattern ID is found in the text.
specified type is found in the text.
Args:
pattern_id: ID of the pattern pair to match.
type: The type of the pattern pair to trigger the handler.
handler: Async function to call when pattern is matched.
The function should accept a PatternMatch object.
Returns:
Self for method chaining.
"""
self._handlers[pattern_id] = handler
self._handlers[type] = handler
return self
async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch], str]:
@@ -184,12 +185,11 @@ class PatternPairAggregator(BaseTextAggregator):
all_matches = []
processed_text = text
for pattern_id, pattern_info in self._patterns.items():
for type, pattern_info in self._patterns.items():
# Escape special regex characters in the patterns
start = re.escape(pattern_info["start"])
end = re.escape(pattern_info["end"])
action = pattern_info["action"]
match_type = pattern_info["type"]
# Create regex to match from start pattern to end pattern
# The .*? is non-greedy to handle nested patterns
@@ -204,16 +204,14 @@ class PatternPairAggregator(BaseTextAggregator):
full_match = match.group(0) # Full match including patterns
# Create pattern match object
pattern_match = PatternMatch(
pattern_id=pattern_id, full_match=full_match, content=content, type=match_type
)
pattern_match = PatternMatch(content=content, type=type, full_match=full_match)
# Call the appropriate handler if registered
if pattern_id in self._handlers:
if type in self._handlers:
try:
await self._handlers[pattern_id](pattern_match)
await self._handlers[type](pattern_match)
except Exception as e:
logger.error(f"Error in pattern handler for {pattern_id}: {e}")
logger.error(f"Error in pattern handler for {type}: {e}")
# Remove the pattern from the text if configured
if action == MatchAction.REMOVE:
@@ -233,10 +231,10 @@ class PatternPairAggregator(BaseTextAggregator):
text: The text to check for incomplete patterns.
Returns:
A tuple of (start_index, type) if an incomplete pattern is found,
A tuple of (start_index, pattern_info) if an incomplete pattern is found,
or None if no patterns are found or all patterns are complete.
"""
for pattern_id, pattern_info in self._patterns.items():
for type, pattern_info in self._patterns.items():
start = pattern_info["start"]
end = pattern_info["end"]
@@ -280,10 +278,10 @@ class PatternPairAggregator(BaseTextAggregator):
if len(patterns) > 0:
if len(patterns) > 1:
logger.warning(
f"Multiple patterns matched: {[p.pattern_id for p in patterns]}. Only the first pattern will be returned."
f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
)
# If the pattern found is set to be aggregated, return it
action = self._patterns[patterns[0].pattern_id].get("action", MatchAction.REMOVE)
action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
if action == MatchAction.AGGREGATE:
self._text = ""
return patterns[0]
@@ -300,7 +298,7 @@ class PatternPairAggregator(BaseTextAggregator):
# Otherwise, strip the text up to the start pattern and return it
result = self._text[: pattern_start[0]]
self._text = self._text[pattern_start[0] :]
return PatternMatch(f"_sentence", result, result, "sentence")
return PatternMatch(content=result, type="sentence", full_match=result)
# Find sentence boundary if no incomplete patterns
eos_marker = match_endofsentence(self._text)
@@ -308,7 +306,7 @@ class PatternPairAggregator(BaseTextAggregator):
# Extract text up to the sentence boundary
result = self._text[:eos_marker]
self._text = self._text[eos_marker:]
return PatternMatch(f"_sentence", result, result, "sentence")
return PatternMatch(content=result, type="sentence", full_match=result)
# No complete sentence found yet
return None

View File

@@ -22,17 +22,15 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Add a test pattern
self.aggregator.add_pattern_pair(
pattern_id="test_pattern",
type="test_pattern",
start_pattern="<test>",
end_pattern="</test>",
type="test",
action=MatchAction.REMOVE,
)
self.aggregator.add_pattern_pair(
pattern_id="code_pattern",
type="code_pattern",
start_pattern="<code>",
end_pattern="</code>",
type="code",
action=MatchAction.AGGREGATE,
)
@@ -45,7 +43,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
result = await self.aggregator.aggregate("Hello <test>pattern")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
self.assertEqual(self.aggregator.text.type, "test")
self.assertEqual(self.aggregator.text.type, "test_pattern")
# Second part completes the pattern and includes an exclamation point
result = await self.aggregator.aggregate(" content</test>!")
@@ -54,7 +52,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.test_handler.assert_called_once()
call_args = self.test_handler.call_args[0][0]
self.assertIsInstance(call_args, PatternMatch)
self.assertEqual(call_args.pattern_id, "test_pattern")
self.assertEqual(call_args.type, "test_pattern")
self.assertEqual(call_args.full_match, "<test>pattern content</test>")
self.assertEqual(call_args.text, "pattern content")
@@ -75,7 +73,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
result = await self.aggregator.aggregate("Here is code <code>pattern")
self.assertEqual(result.text, "Here is code ")
self.assertEqual(self.aggregator.text.text, "<code>pattern")
self.assertEqual(self.aggregator.text.type, "code")
self.assertEqual(self.aggregator.text.type, "code_pattern")
# Second part completes the pattern and includes an exclamation point
result = await self.aggregator.aggregate(" content</code>")
@@ -84,11 +82,11 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.code_handler.assert_called_once()
call_args = self.code_handler.call_args[0][0]
self.assertIsInstance(call_args, PatternMatch)
self.assertEqual(call_args.pattern_id, "code_pattern")
self.assertEqual(call_args.type, "code_pattern")
self.assertEqual(call_args.full_match, "<code>pattern content</code>")
self.assertEqual(call_args.text, "pattern content")
self.assertEqual(result.text, "pattern content")
self.assertEqual(result.type, "code")
self.assertEqual(result.type, "code_pattern")
# Next sentence should be processed separately
result = await self.aggregator.aggregate(" This is another sentence.")
@@ -110,7 +108,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Buffer should contain the incomplete text
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern content")
self.assertEqual(self.aggregator.text.type, "test")
self.assertEqual(self.aggregator.text.type, "test_pattern")
# Reset and confirm buffer is cleared
await self.aggregator.reset()
@@ -122,18 +120,16 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
emphasis_handler = AsyncMock()
self.aggregator.add_pattern_pair(
pattern_id="voice",
type="voice",
start_pattern="<voice>",
end_pattern="</voice>",
type="voice",
action=MatchAction.REMOVE,
)
self.aggregator.add_pattern_pair(
pattern_id="emphasis",
type="emphasis",
start_pattern="<em>",
end_pattern="</em>",
type="emphasis",
action=MatchAction.KEEP, # Keep emphasis tags
)
@@ -147,12 +143,12 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Both handlers should be called with correct data
voice_handler.assert_called_once()
voice_match = voice_handler.call_args[0][0]
self.assertEqual(voice_match.pattern_id, "voice")
self.assertEqual(voice_match.type, "voice")
self.assertEqual(voice_match.text, "female")
emphasis_handler.assert_called_once()
emphasis_match = emphasis_handler.call_args[0][0]
self.assertEqual(emphasis_match.pattern_id, "emphasis")
self.assertEqual(emphasis_match.type, "emphasis")
self.assertEqual(emphasis_match.text, "very")
# Voice pattern should be removed, emphasis pattern should remain