Various fixes:

1. Fixed pattern_pair_aggregator to support various ways of handling
   pattern matches (remove, keep and just trigger a callback, or
   aggregate)
2. Fixed ivr_navigator use of pattern_pair_aggregator
3. Test fixes -- Tests now pass
This commit is contained in:
mattie ruth backman
2025-10-23 12:24:04 -04:00
parent 5c8635570d
commit 69945c5e0d
6 changed files with 95 additions and 48 deletions

View File

@@ -111,7 +111,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
start_pattern="<voice>",
end_pattern="</voice>",
type="voice",
remove_match=True,
action="remove", # Remove tags from final text
)
# Register handler for voice switching

View File

@@ -114,19 +114,15 @@ class IVRProcessor(FrameProcessor):
def _setup_xml_patterns(self):
"""Set up XML pattern detection and handlers."""
# Register DTMF pattern
self._aggregator.add_pattern_pair(
"dtmf", "<dtmf>", "</dtmf>", type="dtmf", remove_match=True
)
self._aggregator.add_pattern_pair("dtmf", "<dtmf>", "</dtmf>", type="dtmf", action="remove")
self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)
# Register mode pattern
self._aggregator.add_pattern_pair(
"mode", "<mode>", "</mode>", type="mode", remove_match=True
)
self._aggregator.add_pattern_pair("mode", "<mode>", "</mode>", type="mode", action="remove")
self._aggregator.on_pattern_match("mode", self._handle_mode_action)
# Register IVR pattern
self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", type="ivr", remove_match=True)
self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", type="ivr", action="remove")
self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -163,7 +159,7 @@ class IVRProcessor(FrameProcessor):
Args:
match: The pattern match containing DTMF content.
"""
value = match.content
value = match.text
logger.debug(f"DTMF detected: {value}")
try:
@@ -184,7 +180,7 @@ class IVRProcessor(FrameProcessor):
Args:
match: The pattern match containing IVR status content.
"""
status = match.content
status = match.text
logger.trace(f"IVR status detected: {status}")
# Convert string to enum, with validation
@@ -215,7 +211,7 @@ class IVRProcessor(FrameProcessor):
Args:
match: The pattern match containing mode content.
"""
mode = match.content
mode = match.text
logger.debug(f"Mode detected: {mode}")
if mode == "conversation":
await self._handle_conversation()

View File

@@ -12,7 +12,7 @@ support for custom handlers and configurable pattern removal.
"""
import re
from typing import Awaitable, Callable, List, Optional, Tuple
from typing import Awaitable, Callable, List, Literal, Optional, Tuple
from loguru import logger
@@ -83,9 +83,9 @@ class PatternPairAggregator(BaseTextAggregator):
Returns:
The text that has been accumulated in the buffer.
"""
start, curtype = self._match_start_of_pattern(self._text)
if curtype:
return Aggregation(self._text, curtype)
pattern_start = self._match_start_of_pattern(self._text)
if pattern_start:
return Aggregation(self._text, pattern_start[1].get("type", "sentence"))
return Aggregation(self._text, "sentence")
def add_pattern_pair(
@@ -94,7 +94,7 @@ class PatternPairAggregator(BaseTextAggregator):
start_pattern: str,
end_pattern: str,
type: str,
remove_match: bool = True,
action: Literal["remove", "keep", "aggregate"] = "remove",
) -> "PatternPairAggregator":
"""Add a pattern pair to detect in the text.
@@ -108,7 +108,12 @@ class PatternPairAggregator(BaseTextAggregator):
end_pattern: Pattern that marks the end of content.
type: The type of aggregation the matched content represents
(e.g., 'code', 'speaker', 'custom').
remove_match: Whether to remove the matched content from the text returned.
action: What to do when a complete pattern is matched:
- "remove": Remove the matched pattern from the text.
- "keep": Keep the matched pattern in the text and treat it as
normal text.
- "aggregate": Return the matched pattern as a separate
aggregation object.
Returns:
Self for method chaining.
@@ -117,7 +122,7 @@ class PatternPairAggregator(BaseTextAggregator):
"start": start_pattern,
"end": end_pattern,
"type": type,
"remove_match": remove_match,
"action": action,
}
return self
@@ -162,7 +167,7 @@ class PatternPairAggregator(BaseTextAggregator):
# Escape special regex characters in the patterns
start = re.escape(pattern_info["start"])
end = re.escape(pattern_info["end"])
remove_match = pattern_info["remove_match"]
action = pattern_info["action"]
match_type = pattern_info["type"]
# Create regex to match from start pattern to end pattern
@@ -190,7 +195,7 @@ class PatternPairAggregator(BaseTextAggregator):
logger.error(f"Error in pattern handler for {pattern_id}: {e}")
# Remove the pattern from the text if configured
if remove_match:
if action == "remove":
processed_text = processed_text.replace(full_match, "", 1)
# modified = True
else:
@@ -198,7 +203,7 @@ class PatternPairAggregator(BaseTextAggregator):
return all_matches, processed_text
def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, str]]:
def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, dict]]:
"""Check if text contains incomplete pattern pairs.
Determines whether the text contains any start patterns without
@@ -225,9 +230,9 @@ class PatternPairAggregator(BaseTextAggregator):
# Which is why we base the return on the first found.
if start_count > end_count:
start_index = text.find(start)
return [start_index, pattern_info["type"]]
return [start_index, pattern_info]
return None, None
return None
async def aggregate(self, text: str) -> Optional[PatternMatch]:
"""Aggregate text and process pattern pairs.
@@ -258,17 +263,22 @@ class PatternPairAggregator(BaseTextAggregator):
logger.warning(
f"Multiple patterns matched: {[p.pattern_id for p in patterns]}. Only the first pattern will be returned."
)
self._text = ""
return patterns[0]
# If the pattern found is set to be aggregated, return it
action = self._patterns[patterns[0].pattern_id].get("action", "remove")
if action == "aggregate":
self._text = ""
print(f"Returning pattern: {patterns[0]}")
return patterns[0]
# Check if we have incomplete patterns
start, curtype = self._match_start_of_pattern(self._text)
if start is not None:
# Still waiting for complete patterns
if start == 0:
pattern_start = self._match_start_of_pattern(self._text)
if pattern_start is not None:
# If the start pattern is at the beginning or should not be separately aggregated, return None
if pattern_start[0] == 0 or pattern_start[1].get("action", "remove") != "aggregate":
return None
result = self._text[:start]
self._text = self._text[start:]
# Otherwise, strip the text up to the start pattern and return it
result = self._text[: pattern_start[0]]
self._text = self._text[pattern_start[0] :]
return PatternMatch(f"_sentence", result, result, "sentence")
# Find sentence boundary if no incomplete patterns

View File

@@ -14,6 +14,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.aggregator = PatternPairAggregator()
self.test_handler = AsyncMock()
self.code_handler = AsyncMock()
# Add a test pattern
self.aggregator.add_pattern_pair(
@@ -21,22 +22,24 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
start_pattern="<test>",
end_pattern="</test>",
type="test",
remove_match=True,
action="remove",
)
self.aggregator.add_pattern_pair(
pattern_id="code_pattern",
start_pattern="<code>",
end_pattern="</code>",
type="code",
remove_match=False,
action="aggregate",
)
# Register the mock handler
self.aggregator.on_pattern_match("test_pattern", self.test_handler)
self.aggregator.on_pattern_match("code_pattern", self.code_handler)
async def test_pattern_match_and_removal(self):
# First part doesn't complete the pattern
result = await self.aggregator.aggregate("Hello <test>pattern")
print(f"result: {result}")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
self.assertEqual(self.aggregator.text.type, "test")
@@ -50,7 +53,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.assertIsInstance(call_args, PatternMatch)
self.assertEqual(call_args.pattern_id, "test_pattern")
self.assertEqual(call_args.full_match, "<test>pattern content</test>")
self.assertEqual(call_args.content, "pattern content")
self.assertEqual(call_args.text, "pattern content")
# The exclamation point should be treated as a sentence boundary,
# so the result should include just text up to and including "!"
@@ -64,6 +67,33 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Buffer should be empty after returning a complete sentence
self.assertEqual(self.aggregator.text.text, "")
async def test_pattern_match_and_aggregate(self):
# First part doesn't complete the pattern
result = await self.aggregator.aggregate("Here is code <code>pattern")
print(f"result: {result}")
self.assertEqual(result.text, "Here is code ")
self.assertEqual(self.aggregator.text.text, "<code>pattern")
self.assertEqual(self.aggregator.text.type, "code")
# Second part completes the pattern
result = await self.aggregator.aggregate(" content</code>")
# Verify the handler was called with correct PatternMatch object
self.code_handler.assert_called_once()
call_args = self.code_handler.call_args[0][0]
self.assertIsInstance(call_args, PatternMatch)
self.assertEqual(call_args.pattern_id, "code_pattern")
self.assertEqual(call_args.full_match, "<code>pattern content</code>")
self.assertEqual(call_args.text, "pattern content")
# Next sentence should be processed separately
result = await self.aggregator.aggregate(" This is another sentence.")
self.assertEqual(result.text, " This is another sentence.")
self.assertEqual(result.type, "sentence")
# Buffer should be empty after returning a complete sentence
self.assertEqual(self.aggregator.text.text, "")
async def test_incomplete_pattern(self):
# Add text with incomplete pattern
result = await self.aggregator.aggregate("Hello <test>pattern content")
@@ -88,14 +118,19 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
emphasis_handler = AsyncMock()
self.aggregator.add_pattern_pair(
pattern_id="voice", start_pattern="<voice>", end_pattern="</voice>", remove_match=True
pattern_id="voice",
start_pattern="<voice>",
end_pattern="</voice>",
type="voice",
action="remove",
)
self.aggregator.add_pattern_pair(
pattern_id="emphasis",
start_pattern="<em>",
end_pattern="</em>",
remove_match=False, # Keep emphasis tags
type="emphasis",
action="keep", # Keep emphasis tags
)
self.aggregator.on_pattern_match("voice", voice_handler)
@@ -109,15 +144,15 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
voice_handler.assert_called_once()
voice_match = voice_handler.call_args[0][0]
self.assertEqual(voice_match.pattern_id, "voice")
self.assertEqual(voice_match.content, "female")
self.assertEqual(voice_match.text, "female")
emphasis_handler.assert_called_once()
emphasis_match = emphasis_handler.call_args[0][0]
self.assertEqual(emphasis_match.pattern_id, "emphasis")
self.assertEqual(emphasis_match.content, "very")
self.assertEqual(emphasis_match.text, "very")
# Voice pattern should be removed, emphasis pattern should remain
self.assertEqual(result, "Hello I am <em>very</em> excited to meet you!")
self.assertEqual(result.text, "Hello I am <em>very</em> excited to meet you!")
# Buffer should be empty
self.assertEqual(self.aggregator.text.text, "")
@@ -149,10 +184,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Handler should be called with entire content
self.test_handler.assert_called_once()
call_args = self.test_handler.call_args[0][0]
self.assertEqual(call_args.content, "This is sentence one. This is sentence two.")
self.assertEqual(call_args.text, "This is sentence one. This is sentence two.")
# Pattern should be removed, resulting in text with sentences merged
self.assertEqual(result, "Hello Final sentence.")
self.assertEqual(result.text, "Hello Final sentence.")
# Buffer should be empty
self.assertEqual(self.aggregator.text.text, "")

View File

@@ -74,6 +74,7 @@ async def test_run_piper_tts_success(aiohttp_client):
]
expected_returned_frames = [
TTSTextFrame,
TTSStartedFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
@@ -121,7 +122,7 @@ async def test_run_piper_tts_error(aiohttp_client):
TTSSpeakFrame(text="Error case."),
]
expected_down_frames = [TTSStoppedFrame, TTSTextFrame]
expected_down_frames = [TTSTextFrame, TTSStoppedFrame, TTSTextFrame]
expected_up_frames = [ErrorFrame]

View File

@@ -15,15 +15,20 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):
async def test_reset_aggregations(self):
assert await self.aggregator.aggregate("Hello ") == None
assert self.aggregator.text == "Hello "
assert self.aggregator.text.text == "Hello "
await self.aggregator.reset()
assert self.aggregator.text == ""
assert self.aggregator.text.text == ""
async def test_simple_sentence(self):
assert await self.aggregator.aggregate("Hello ") == None
assert await self.aggregator.aggregate("Pipecat!") == "Hello Pipecat!"
assert self.aggregator.text == ""
aggregate = await self.aggregator.aggregate("Pipecat!")
assert aggregate.text == "Hello Pipecat!"
assert aggregate.type == "sentence"
assert self.aggregator.text.text == ""
async def test_multiple_sentences(self):
assert await self.aggregator.aggregate("Hello Pipecat! How are ") == "Hello Pipecat!"
assert await self.aggregator.aggregate("you?") == " How are you?"
aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ")
assert aggregate.text == "Hello Pipecat!"
assert self.aggregator.text.text == " How are "
aggregate = await self.aggregator.aggregate("you?")
assert aggregate.text == " How are you?"