Update PatternPairAggregator patterns to replace pattern_id with type to simplify the API
This commit is contained in:
@@ -111,10 +111,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# Add pattern for voice switching
|
||||
pattern_aggregator.add_pattern_pair(
|
||||
pattern_id="voice_tag",
|
||||
type="voice",
|
||||
start_pattern="<voice>",
|
||||
end_pattern="</voice>",
|
||||
type="voice",
|
||||
action=MatchAction.REMOVE, # Remove tags from final text
|
||||
)
|
||||
|
||||
@@ -130,7 +129,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
else:
|
||||
logger.warning(f"Unknown voice: {voice_name}")
|
||||
|
||||
pattern_aggregator.on_pattern_match("voice_tag", on_voice_tag)
|
||||
pattern_aggregator.on_pattern_match("voice", on_voice_tag)
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
|
||||
@@ -118,21 +118,15 @@ class IVRProcessor(FrameProcessor):
|
||||
def _setup_xml_patterns(self):
|
||||
"""Set up XML pattern detection and handlers."""
|
||||
# Register DTMF pattern
|
||||
self._aggregator.add_pattern_pair(
|
||||
"dtmf", "<dtmf>", "</dtmf>", type="dtmf", action=MatchAction.REMOVE
|
||||
)
|
||||
self._aggregator.add_pattern_pair("dtmf", "<dtmf>", "</dtmf>", action=MatchAction.REMOVE)
|
||||
self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)
|
||||
|
||||
# Register mode pattern
|
||||
self._aggregator.add_pattern_pair(
|
||||
"mode", "<mode>", "</mode>", type="mode", action=MatchAction.REMOVE
|
||||
)
|
||||
self._aggregator.add_pattern_pair("mode", "<mode>", "</mode>", action=MatchAction.REMOVE)
|
||||
self._aggregator.on_pattern_match("mode", self._handle_mode_action)
|
||||
|
||||
# Register IVR pattern
|
||||
self._aggregator.add_pattern_pair(
|
||||
"ivr", "<ivr>", "</ivr>", type="ivr", action=MatchAction.REMOVE
|
||||
)
|
||||
self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", action=MatchAction.REMOVE)
|
||||
self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
|
||||
@@ -44,27 +44,25 @@ class PatternMatch(Aggregation):
|
||||
content between the patterns.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern_id: str, full_match: str, content: str, type: str):
|
||||
def __init__(self, content: str, type: str, full_match: str):
|
||||
"""Initialize a pattern match.
|
||||
|
||||
Args:
|
||||
pattern_id: The identifier of the matched pattern pair.
|
||||
type: The type of the matched pattern pair. It should be representative
|
||||
of the content type (e.g., 'sentence', 'code', 'speaker', 'custom').
|
||||
full_match: The complete text including start and end patterns.
|
||||
content: The text content between the start and end patterns.
|
||||
type: The type of aggregation the matched content represents
|
||||
(e.g., 'code', 'speaker', 'custom').
|
||||
"""
|
||||
super().__init__(text=content, type=type)
|
||||
self.pattern_id = pattern_id
|
||||
self.full_match = full_match
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the pattern match.
|
||||
|
||||
Returns:
|
||||
A descriptive string showing the pattern ID and content.
|
||||
A descriptive string showing the pattern type and content.
|
||||
"""
|
||||
return f"PatternMatch(id={self.pattern_id}, content={self.text}, full_match={self.full_match}, type={self.type})"
|
||||
return f"PatternMatch(type={self.type}, content={self.text}, full_match={self.full_match})"
|
||||
|
||||
|
||||
class PatternPairAggregator(BaseTextAggregator):
|
||||
@@ -110,10 +108,9 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
|
||||
def add_pattern_pair(
|
||||
self,
|
||||
pattern_id: str,
|
||||
type: str,
|
||||
start_pattern: str,
|
||||
end_pattern: str,
|
||||
type: str,
|
||||
action: MatchAction = MatchAction.REMOVE,
|
||||
) -> "PatternPairAggregator":
|
||||
"""Add a pattern pair to detect in the text.
|
||||
@@ -123,11 +120,11 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
the end pattern, and treat the content between them as a match.
|
||||
|
||||
Args:
|
||||
pattern_id: Unique identifier for this pattern pair.
|
||||
type: Identifier for this pattern pair. Should be unique and ideally descriptive
|
||||
(e.g., 'code', 'speaker', 'custom'). Must not be 'sentence', as that is
|
||||
reserved for the default behavior.
|
||||
start_pattern: Pattern that marks the beginning of content.
|
||||
end_pattern: Pattern that marks the end of content.
|
||||
type: The type of aggregation the matched content represents
|
||||
(e.g., 'code', 'speaker', 'custom').
|
||||
action: What to do when a complete pattern is matched:
|
||||
- MatchAction.REMOVE: Remove the matched pattern from the text.
|
||||
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as
|
||||
@@ -139,7 +136,11 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
"""
|
||||
self._patterns[pattern_id] = {
|
||||
if type == "sentence":
|
||||
raise ValueError(
|
||||
"The aggregation type 'sentence' is reserved for default behavior and can not be used for custom patterns."
|
||||
)
|
||||
self._patterns[type] = {
|
||||
"start": start_pattern,
|
||||
"end": end_pattern,
|
||||
"type": type,
|
||||
@@ -148,22 +149,22 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
return self
|
||||
|
||||
def on_pattern_match(
|
||||
self, pattern_id: str, handler: Callable[[PatternMatch], Awaitable[None]]
|
||||
self, type: str, handler: Callable[[PatternMatch], Awaitable[None]]
|
||||
) -> "PatternPairAggregator":
|
||||
"""Register a handler for when a pattern pair is matched.
|
||||
|
||||
The handler will be called whenever a complete match for the
|
||||
specified pattern ID is found in the text.
|
||||
specified type is found in the text.
|
||||
|
||||
Args:
|
||||
pattern_id: ID of the pattern pair to match.
|
||||
type: The type of the pattern pair to trigger the handler.
|
||||
handler: Async function to call when pattern is matched.
|
||||
The function should accept a PatternMatch object.
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
"""
|
||||
self._handlers[pattern_id] = handler
|
||||
self._handlers[type] = handler
|
||||
return self
|
||||
|
||||
async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch], str]:
|
||||
@@ -184,12 +185,11 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
all_matches = []
|
||||
processed_text = text
|
||||
|
||||
for pattern_id, pattern_info in self._patterns.items():
|
||||
for type, pattern_info in self._patterns.items():
|
||||
# Escape special regex characters in the patterns
|
||||
start = re.escape(pattern_info["start"])
|
||||
end = re.escape(pattern_info["end"])
|
||||
action = pattern_info["action"]
|
||||
match_type = pattern_info["type"]
|
||||
|
||||
# Create regex to match from start pattern to end pattern
|
||||
# The .*? is non-greedy to handle nested patterns
|
||||
@@ -204,16 +204,14 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
full_match = match.group(0) # Full match including patterns
|
||||
|
||||
# Create pattern match object
|
||||
pattern_match = PatternMatch(
|
||||
pattern_id=pattern_id, full_match=full_match, content=content, type=match_type
|
||||
)
|
||||
pattern_match = PatternMatch(content=content, type=type, full_match=full_match)
|
||||
|
||||
# Call the appropriate handler if registered
|
||||
if pattern_id in self._handlers:
|
||||
if type in self._handlers:
|
||||
try:
|
||||
await self._handlers[pattern_id](pattern_match)
|
||||
await self._handlers[type](pattern_match)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in pattern handler for {pattern_id}: {e}")
|
||||
logger.error(f"Error in pattern handler for {type}: {e}")
|
||||
|
||||
# Remove the pattern from the text if configured
|
||||
if action == MatchAction.REMOVE:
|
||||
@@ -233,10 +231,10 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
text: The text to check for incomplete patterns.
|
||||
|
||||
Returns:
|
||||
A tuple of (start_index, type) if an incomplete pattern is found,
|
||||
A tuple of (start_index, pattern_info) if an incomplete pattern is found,
|
||||
or None if no patterns are found or all patterns are complete.
|
||||
"""
|
||||
for pattern_id, pattern_info in self._patterns.items():
|
||||
for type, pattern_info in self._patterns.items():
|
||||
start = pattern_info["start"]
|
||||
end = pattern_info["end"]
|
||||
|
||||
@@ -280,10 +278,10 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
if len(patterns) > 0:
|
||||
if len(patterns) > 1:
|
||||
logger.warning(
|
||||
f"Multiple patterns matched: {[p.pattern_id for p in patterns]}. Only the first pattern will be returned."
|
||||
f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
|
||||
)
|
||||
# If the pattern found is set to be aggregated, return it
|
||||
action = self._patterns[patterns[0].pattern_id].get("action", MatchAction.REMOVE)
|
||||
action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
|
||||
if action == MatchAction.AGGREGATE:
|
||||
self._text = ""
|
||||
return patterns[0]
|
||||
@@ -300,7 +298,7 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
# Otherwise, strip the text up to the start pattern and return it
|
||||
result = self._text[: pattern_start[0]]
|
||||
self._text = self._text[pattern_start[0] :]
|
||||
return PatternMatch(f"_sentence", result, result, "sentence")
|
||||
return PatternMatch(content=result, type="sentence", full_match=result)
|
||||
|
||||
# Find sentence boundary if no incomplete patterns
|
||||
eos_marker = match_endofsentence(self._text)
|
||||
@@ -308,7 +306,7 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
# Extract text up to the sentence boundary
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return PatternMatch(f"_sentence", result, result, "sentence")
|
||||
return PatternMatch(content=result, type="sentence", full_match=result)
|
||||
|
||||
# No complete sentence found yet
|
||||
return None
|
||||
|
||||
@@ -22,17 +22,15 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# Add a test pattern
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="test_pattern",
|
||||
type="test_pattern",
|
||||
start_pattern="<test>",
|
||||
end_pattern="</test>",
|
||||
type="test",
|
||||
action=MatchAction.REMOVE,
|
||||
)
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="code_pattern",
|
||||
type="code_pattern",
|
||||
start_pattern="<code>",
|
||||
end_pattern="</code>",
|
||||
type="code",
|
||||
action=MatchAction.AGGREGATE,
|
||||
)
|
||||
|
||||
@@ -45,7 +43,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
result = await self.aggregator.aggregate("Hello <test>pattern")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
|
||||
self.assertEqual(self.aggregator.text.type, "test")
|
||||
self.assertEqual(self.aggregator.text.type, "test_pattern")
|
||||
|
||||
# Second part completes the pattern and includes an exclamation point
|
||||
result = await self.aggregator.aggregate(" content</test>!")
|
||||
@@ -54,7 +52,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.test_handler.assert_called_once()
|
||||
call_args = self.test_handler.call_args[0][0]
|
||||
self.assertIsInstance(call_args, PatternMatch)
|
||||
self.assertEqual(call_args.pattern_id, "test_pattern")
|
||||
self.assertEqual(call_args.type, "test_pattern")
|
||||
self.assertEqual(call_args.full_match, "<test>pattern content</test>")
|
||||
self.assertEqual(call_args.text, "pattern content")
|
||||
|
||||
@@ -75,7 +73,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
result = await self.aggregator.aggregate("Here is code <code>pattern")
|
||||
self.assertEqual(result.text, "Here is code ")
|
||||
self.assertEqual(self.aggregator.text.text, "<code>pattern")
|
||||
self.assertEqual(self.aggregator.text.type, "code")
|
||||
self.assertEqual(self.aggregator.text.type, "code_pattern")
|
||||
|
||||
# Second part completes the pattern
|
||||
result = await self.aggregator.aggregate(" content</code>")
|
||||
@@ -84,11 +82,11 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.code_handler.assert_called_once()
|
||||
call_args = self.code_handler.call_args[0][0]
|
||||
self.assertIsInstance(call_args, PatternMatch)
|
||||
self.assertEqual(call_args.pattern_id, "code_pattern")
|
||||
self.assertEqual(call_args.type, "code_pattern")
|
||||
self.assertEqual(call_args.full_match, "<code>pattern content</code>")
|
||||
self.assertEqual(call_args.text, "pattern content")
|
||||
self.assertEqual(result.text, "pattern content")
|
||||
self.assertEqual(result.type, "code")
|
||||
self.assertEqual(result.type, "code_pattern")
|
||||
|
||||
# Next sentence should be processed separately
|
||||
result = await self.aggregator.aggregate(" This is another sentence.")
|
||||
@@ -110,7 +108,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# Buffer should contain the incomplete text
|
||||
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern content")
|
||||
self.assertEqual(self.aggregator.text.type, "test")
|
||||
self.assertEqual(self.aggregator.text.type, "test_pattern")
|
||||
|
||||
# Reset and confirm buffer is cleared
|
||||
await self.aggregator.reset()
|
||||
@@ -122,18 +120,16 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
emphasis_handler = AsyncMock()
|
||||
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="voice",
|
||||
type="voice",
|
||||
start_pattern="<voice>",
|
||||
end_pattern="</voice>",
|
||||
type="voice",
|
||||
action=MatchAction.REMOVE,
|
||||
)
|
||||
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="emphasis",
|
||||
type="emphasis",
|
||||
start_pattern="<em>",
|
||||
end_pattern="</em>",
|
||||
type="emphasis",
|
||||
action=MatchAction.KEEP, # Keep emphasis tags
|
||||
)
|
||||
|
||||
@@ -147,12 +143,12 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
# Both handlers should be called with correct data
|
||||
voice_handler.assert_called_once()
|
||||
voice_match = voice_handler.call_args[0][0]
|
||||
self.assertEqual(voice_match.pattern_id, "voice")
|
||||
self.assertEqual(voice_match.type, "voice")
|
||||
self.assertEqual(voice_match.text, "female")
|
||||
|
||||
emphasis_handler.assert_called_once()
|
||||
emphasis_match = emphasis_handler.call_args[0][0]
|
||||
self.assertEqual(emphasis_match.pattern_id, "emphasis")
|
||||
self.assertEqual(emphasis_match.type, "emphasis")
|
||||
self.assertEqual(emphasis_match.text, "very")
|
||||
|
||||
# Voice pattern should be removed, emphasis pattern should remain
|
||||
|
||||
Reference in New Issue
Block a user