Various fixes:
1. Fixed pattern_pair_aggregator to support various ways of handling pattern matches (remove, keep and just trigger a callback, or aggregate 2. Fixed ivr_navigator use of pattern_pair_aggregator 3. Test fixes -- Tests now pass
This commit is contained in:
@@ -111,7 +111,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
start_pattern="<voice>",
|
||||
end_pattern="</voice>",
|
||||
type="voice",
|
||||
remove_match=True,
|
||||
action="remove", # Remove tags from final text
|
||||
)
|
||||
|
||||
# Register handler for voice switching
|
||||
|
||||
@@ -114,19 +114,15 @@ class IVRProcessor(FrameProcessor):
|
||||
def _setup_xml_patterns(self):
|
||||
"""Set up XML pattern detection and handlers."""
|
||||
# Register DTMF pattern
|
||||
self._aggregator.add_pattern_pair(
|
||||
"dtmf", "<dtmf>", "</dtmf>", type="dtmf", remove_match=True
|
||||
)
|
||||
self._aggregator.add_pattern_pair("dtmf", "<dtmf>", "</dtmf>", type="dtmf", action="remove")
|
||||
self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)
|
||||
|
||||
# Register mode pattern
|
||||
self._aggregator.add_pattern_pair(
|
||||
"mode", "<mode>", "</mode>", type="mode", remove_match=True
|
||||
)
|
||||
self._aggregator.add_pattern_pair("mode", "<mode>", "</mode>", type="mode", action="remove")
|
||||
self._aggregator.on_pattern_match("mode", self._handle_mode_action)
|
||||
|
||||
# Register IVR pattern
|
||||
self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", type="ivr", remove_match=True)
|
||||
self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", type="ivr", action="remove")
|
||||
self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -163,7 +159,7 @@ class IVRProcessor(FrameProcessor):
|
||||
Args:
|
||||
match: The pattern match containing DTMF content.
|
||||
"""
|
||||
value = match.content
|
||||
value = match.text
|
||||
logger.debug(f"DTMF detected: {value}")
|
||||
|
||||
try:
|
||||
@@ -184,7 +180,7 @@ class IVRProcessor(FrameProcessor):
|
||||
Args:
|
||||
match: The pattern match containing IVR status content.
|
||||
"""
|
||||
status = match.content
|
||||
status = match.text
|
||||
logger.trace(f"IVR status detected: {status}")
|
||||
|
||||
# Convert string to enum, with validation
|
||||
@@ -215,7 +211,7 @@ class IVRProcessor(FrameProcessor):
|
||||
Args:
|
||||
match: The pattern match containing mode content.
|
||||
"""
|
||||
mode = match.content
|
||||
mode = match.text
|
||||
logger.debug(f"Mode detected: {mode}")
|
||||
if mode == "conversation":
|
||||
await self._handle_conversation()
|
||||
|
||||
@@ -12,7 +12,7 @@ support for custom handlers and configurable pattern removal.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Awaitable, Callable, List, Optional, Tuple
|
||||
from typing import Awaitable, Callable, List, Literal, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -83,9 +83,9 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
Returns:
|
||||
The text that has been accumulated in the buffer.
|
||||
"""
|
||||
start, curtype = self._match_start_of_pattern(self._text)
|
||||
if curtype:
|
||||
return Aggregation(self._text, curtype)
|
||||
pattern_start = self._match_start_of_pattern(self._text)
|
||||
if pattern_start:
|
||||
return Aggregation(self._text, pattern_start[1].get("type", "sentence"))
|
||||
return Aggregation(self._text, "sentence")
|
||||
|
||||
def add_pattern_pair(
|
||||
@@ -94,7 +94,7 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
start_pattern: str,
|
||||
end_pattern: str,
|
||||
type: str,
|
||||
remove_match: bool = True,
|
||||
action: Literal["remove", "keep", "aggregate"] = "remove",
|
||||
) -> "PatternPairAggregator":
|
||||
"""Add a pattern pair to detect in the text.
|
||||
|
||||
@@ -108,7 +108,12 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
end_pattern: Pattern that marks the end of content.
|
||||
type: The type of aggregation the matched content represents
|
||||
(e.g., 'code', 'speaker', 'custom').
|
||||
remove_match: Whether to remove the matched content from the text returned.
|
||||
action: What to do when a complete pattern is matched:
|
||||
- "remove": Remove the matched pattern from the text.
|
||||
- "keep": Keep the matched pattern in the text and treat it as
|
||||
normal text.
|
||||
- "aggregate": Return the matched pattern as a separate
|
||||
aggregation object.
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
@@ -117,7 +122,7 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
"start": start_pattern,
|
||||
"end": end_pattern,
|
||||
"type": type,
|
||||
"remove_match": remove_match,
|
||||
"action": action,
|
||||
}
|
||||
return self
|
||||
|
||||
@@ -162,7 +167,7 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
# Escape special regex characters in the patterns
|
||||
start = re.escape(pattern_info["start"])
|
||||
end = re.escape(pattern_info["end"])
|
||||
remove_match = pattern_info["remove_match"]
|
||||
action = pattern_info["action"]
|
||||
match_type = pattern_info["type"]
|
||||
|
||||
# Create regex to match from start pattern to end pattern
|
||||
@@ -190,7 +195,7 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
logger.error(f"Error in pattern handler for {pattern_id}: {e}")
|
||||
|
||||
# Remove the pattern from the text if configured
|
||||
if remove_match:
|
||||
if action == "remove":
|
||||
processed_text = processed_text.replace(full_match, "", 1)
|
||||
# modified = True
|
||||
else:
|
||||
@@ -198,7 +203,7 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
|
||||
return all_matches, processed_text
|
||||
|
||||
def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, str]]:
|
||||
def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, dict]]:
|
||||
"""Check if text contains incomplete pattern pairs.
|
||||
|
||||
Determines whether the text contains any start patterns without
|
||||
@@ -225,9 +230,9 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
# Which is why we base the return on the first found.
|
||||
if start_count > end_count:
|
||||
start_index = text.find(start)
|
||||
return [start_index, pattern_info["type"]]
|
||||
return [start_index, pattern_info]
|
||||
|
||||
return None, None
|
||||
return None
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[PatternMatch]:
|
||||
"""Aggregate text and process pattern pairs.
|
||||
@@ -258,17 +263,22 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
logger.warning(
|
||||
f"Multiple patterns matched: {[p.pattern_id for p in patterns]}. Only the first pattern will be returned."
|
||||
)
|
||||
self._text = ""
|
||||
return patterns[0]
|
||||
# If the pattern found is set to be aggregated, return it
|
||||
action = self._patterns[patterns[0].pattern_id].get("action", "remove")
|
||||
if action == "aggregate":
|
||||
self._text = ""
|
||||
print(f"Returning pattern: {patterns[0]}")
|
||||
return patterns[0]
|
||||
|
||||
# Check if we have incomplete patterns
|
||||
start, curtype = self._match_start_of_pattern(self._text)
|
||||
if start is not None:
|
||||
# Still waiting for complete patterns
|
||||
if start == 0:
|
||||
pattern_start = self._match_start_of_pattern(self._text)
|
||||
if pattern_start is not None:
|
||||
# If the start pattern is at the beginning or should not be separately aggregated, return None
|
||||
if pattern_start[0] == 0 or pattern_start[1].get("action", "remove") != "aggregate":
|
||||
return None
|
||||
result = self._text[:start]
|
||||
self._text = self._text[start:]
|
||||
# Otherwise, strip the text up to the start pattern and return it
|
||||
result = self._text[: pattern_start[0]]
|
||||
self._text = self._text[pattern_start[0] :]
|
||||
return PatternMatch(f"_sentence", result, result, "sentence")
|
||||
|
||||
# Find sentence boundary if no incomplete patterns
|
||||
|
||||
@@ -14,6 +14,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.aggregator = PatternPairAggregator()
|
||||
self.test_handler = AsyncMock()
|
||||
self.code_handler = AsyncMock()
|
||||
|
||||
# Add a test pattern
|
||||
self.aggregator.add_pattern_pair(
|
||||
@@ -21,22 +22,24 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
start_pattern="<test>",
|
||||
end_pattern="</test>",
|
||||
type="test",
|
||||
remove_match=True,
|
||||
action="remove",
|
||||
)
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="code_pattern",
|
||||
start_pattern="<code>",
|
||||
end_pattern="</code>",
|
||||
type="code",
|
||||
remove_match=False,
|
||||
action="aggregate",
|
||||
)
|
||||
|
||||
# Register the mock handler
|
||||
self.aggregator.on_pattern_match("test_pattern", self.test_handler)
|
||||
self.aggregator.on_pattern_match("code_pattern", self.code_handler)
|
||||
|
||||
async def test_pattern_match_and_removal(self):
|
||||
# First part doesn't complete the pattern
|
||||
result = await self.aggregator.aggregate("Hello <test>pattern")
|
||||
print(f"result: {result}")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
|
||||
self.assertEqual(self.aggregator.text.type, "test")
|
||||
@@ -50,7 +53,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertIsInstance(call_args, PatternMatch)
|
||||
self.assertEqual(call_args.pattern_id, "test_pattern")
|
||||
self.assertEqual(call_args.full_match, "<test>pattern content</test>")
|
||||
self.assertEqual(call_args.content, "pattern content")
|
||||
self.assertEqual(call_args.text, "pattern content")
|
||||
|
||||
# The exclamation point should be treated as a sentence boundary,
|
||||
# so the result should include just text up to and including "!"
|
||||
@@ -64,6 +67,33 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
# Buffer should be empty after returning a complete sentence
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_pattern_match_and_aggregate(self):
|
||||
# First part doesn't complete the pattern
|
||||
result = await self.aggregator.aggregate("Here is code <code>pattern")
|
||||
print(f"result: {result}")
|
||||
self.assertEqual(result.text, "Here is code ")
|
||||
self.assertEqual(self.aggregator.text.text, "<code>pattern")
|
||||
self.assertEqual(self.aggregator.text.type, "code")
|
||||
|
||||
# Second part completes the pattern and includes an exclamation point
|
||||
result = await self.aggregator.aggregate(" content</code>")
|
||||
|
||||
# Verify the handler was called with correct PatternMatch object
|
||||
self.code_handler.assert_called_once()
|
||||
call_args = self.code_handler.call_args[0][0]
|
||||
self.assertIsInstance(call_args, PatternMatch)
|
||||
self.assertEqual(call_args.pattern_id, "code_pattern")
|
||||
self.assertEqual(call_args.full_match, "<code>pattern content</code>")
|
||||
self.assertEqual(call_args.text, "pattern content")
|
||||
|
||||
# Next sentence should be processed separately
|
||||
result = await self.aggregator.aggregate(" This is another sentence.")
|
||||
self.assertEqual(result.text, " This is another sentence.")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
|
||||
# Buffer should be empty after returning a complete sentence
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_incomplete_pattern(self):
|
||||
# Add text with incomplete pattern
|
||||
result = await self.aggregator.aggregate("Hello <test>pattern content")
|
||||
@@ -88,14 +118,19 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
emphasis_handler = AsyncMock()
|
||||
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="voice", start_pattern="<voice>", end_pattern="</voice>", remove_match=True
|
||||
pattern_id="voice",
|
||||
start_pattern="<voice>",
|
||||
end_pattern="</voice>",
|
||||
type="voice",
|
||||
action="remove",
|
||||
)
|
||||
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="emphasis",
|
||||
start_pattern="<em>",
|
||||
end_pattern="</em>",
|
||||
remove_match=False, # Keep emphasis tags
|
||||
type="emphasis",
|
||||
action="keep", # Keep emphasis tags
|
||||
)
|
||||
|
||||
self.aggregator.on_pattern_match("voice", voice_handler)
|
||||
@@ -109,15 +144,15 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
voice_handler.assert_called_once()
|
||||
voice_match = voice_handler.call_args[0][0]
|
||||
self.assertEqual(voice_match.pattern_id, "voice")
|
||||
self.assertEqual(voice_match.content, "female")
|
||||
self.assertEqual(voice_match.text, "female")
|
||||
|
||||
emphasis_handler.assert_called_once()
|
||||
emphasis_match = emphasis_handler.call_args[0][0]
|
||||
self.assertEqual(emphasis_match.pattern_id, "emphasis")
|
||||
self.assertEqual(emphasis_match.content, "very")
|
||||
self.assertEqual(emphasis_match.text, "very")
|
||||
|
||||
# Voice pattern should be removed, emphasis pattern should remain
|
||||
self.assertEqual(result, "Hello I am <em>very</em> excited to meet you!")
|
||||
self.assertEqual(result.text, "Hello I am <em>very</em> excited to meet you!")
|
||||
|
||||
# Buffer should be empty
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
@@ -149,10 +184,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
# Handler should be called with entire content
|
||||
self.test_handler.assert_called_once()
|
||||
call_args = self.test_handler.call_args[0][0]
|
||||
self.assertEqual(call_args.content, "This is sentence one. This is sentence two.")
|
||||
self.assertEqual(call_args.text, "This is sentence one. This is sentence two.")
|
||||
|
||||
# Pattern should be removed, resulting in text with sentences merged
|
||||
self.assertEqual(result, "Hello Final sentence.")
|
||||
self.assertEqual(result.text, "Hello Final sentence.")
|
||||
|
||||
# Buffer should be empty
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
@@ -74,6 +74,7 @@ async def test_run_piper_tts_success(aiohttp_client):
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
TTSTextFrame,
|
||||
TTSStartedFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSAudioRawFrame,
|
||||
@@ -121,7 +122,7 @@ async def test_run_piper_tts_error(aiohttp_client):
|
||||
TTSSpeakFrame(text="Error case."),
|
||||
]
|
||||
|
||||
expected_down_frames = [TTSStoppedFrame, TTSTextFrame]
|
||||
expected_down_frames = [TTSTextFrame, TTSStoppedFrame, TTSTextFrame]
|
||||
|
||||
expected_up_frames = [ErrorFrame]
|
||||
|
||||
|
||||
@@ -15,15 +15,20 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_reset_aggregations(self):
|
||||
assert await self.aggregator.aggregate("Hello ") == None
|
||||
assert self.aggregator.text == "Hello "
|
||||
assert self.aggregator.text.text == "Hello "
|
||||
await self.aggregator.reset()
|
||||
assert self.aggregator.text == ""
|
||||
assert self.aggregator.text.text == ""
|
||||
|
||||
async def test_simple_sentence(self):
|
||||
assert await self.aggregator.aggregate("Hello ") == None
|
||||
assert await self.aggregator.aggregate("Pipecat!") == "Hello Pipecat!"
|
||||
assert self.aggregator.text == ""
|
||||
aggregate = await self.aggregator.aggregate("Pipecat!")
|
||||
assert aggregate.text == "Hello Pipecat!"
|
||||
assert aggregate.type == "sentence"
|
||||
assert self.aggregator.text.text == ""
|
||||
|
||||
async def test_multiple_sentences(self):
|
||||
assert await self.aggregator.aggregate("Hello Pipecat! How are ") == "Hello Pipecat!"
|
||||
assert await self.aggregator.aggregate("you?") == " How are you?"
|
||||
aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ")
|
||||
assert aggregate.text == "Hello Pipecat!"
|
||||
assert self.aggregator.text.text == " How are "
|
||||
aggregate = await self.aggregator.aggregate("you?")
|
||||
assert aggregate.text == " How are you?"
|
||||
|
||||
Reference in New Issue
Block a user