diff --git a/examples/foundational/35-pattern-pair-voice-switching.py b/examples/foundational/35-pattern-pair-voice-switching.py
index 013c3eecd..550520317 100644
--- a/examples/foundational/35-pattern-pair-voice-switching.py
+++ b/examples/foundational/35-pattern-pair-voice-switching.py
@@ -111,7 +111,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
start_pattern="",
end_pattern="",
type="voice",
- remove_match=True,
+ action="remove", # Remove tags from final text
)
# Register handler for voice switching
diff --git a/src/pipecat/extensions/ivr/ivr_navigator.py b/src/pipecat/extensions/ivr/ivr_navigator.py
index c4e31e12a..fac760ab3 100644
--- a/src/pipecat/extensions/ivr/ivr_navigator.py
+++ b/src/pipecat/extensions/ivr/ivr_navigator.py
@@ -114,19 +114,15 @@ class IVRProcessor(FrameProcessor):
def _setup_xml_patterns(self):
"""Set up XML pattern detection and handlers."""
# Register DTMF pattern
- self._aggregator.add_pattern_pair(
- "dtmf", "", "", type="dtmf", remove_match=True
- )
+ self._aggregator.add_pattern_pair("dtmf", "", "", type="dtmf", action="remove")
self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)
# Register mode pattern
- self._aggregator.add_pattern_pair(
- "mode", "", "", type="mode", remove_match=True
- )
+ self._aggregator.add_pattern_pair("mode", "", "", type="mode", action="remove")
self._aggregator.on_pattern_match("mode", self._handle_mode_action)
# Register IVR pattern
- self._aggregator.add_pattern_pair("ivr", "", "", type="ivr", remove_match=True)
+ self._aggregator.add_pattern_pair("ivr", "", "", type="ivr", action="remove")
self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -163,7 +159,7 @@ class IVRProcessor(FrameProcessor):
Args:
match: The pattern match containing DTMF content.
"""
- value = match.content
+ value = match.text
logger.debug(f"DTMF detected: {value}")
try:
@@ -184,7 +180,7 @@ class IVRProcessor(FrameProcessor):
Args:
match: The pattern match containing IVR status content.
"""
- status = match.content
+ status = match.text
logger.trace(f"IVR status detected: {status}")
# Convert string to enum, with validation
@@ -215,7 +211,7 @@ class IVRProcessor(FrameProcessor):
Args:
match: The pattern match containing mode content.
"""
- mode = match.content
+ mode = match.text
logger.debug(f"Mode detected: {mode}")
if mode == "conversation":
await self._handle_conversation()
diff --git a/src/pipecat/utils/text/pattern_pair_aggregator.py b/src/pipecat/utils/text/pattern_pair_aggregator.py
index fe32f5b51..c188c3974 100644
--- a/src/pipecat/utils/text/pattern_pair_aggregator.py
+++ b/src/pipecat/utils/text/pattern_pair_aggregator.py
@@ -12,7 +12,7 @@ support for custom handlers and configurable pattern removal.
"""
import re
-from typing import Awaitable, Callable, List, Optional, Tuple
+from typing import Awaitable, Callable, List, Literal, Optional, Tuple
from loguru import logger
@@ -83,9 +83,9 @@ class PatternPairAggregator(BaseTextAggregator):
Returns:
The text that has been accumulated in the buffer.
"""
- start, curtype = self._match_start_of_pattern(self._text)
- if curtype:
- return Aggregation(self._text, curtype)
+ pattern_start = self._match_start_of_pattern(self._text)
+ if pattern_start:
+ return Aggregation(self._text, pattern_start[1].get("type", "sentence"))
return Aggregation(self._text, "sentence")
def add_pattern_pair(
@@ -94,7 +94,7 @@ class PatternPairAggregator(BaseTextAggregator):
start_pattern: str,
end_pattern: str,
type: str,
- remove_match: bool = True,
+ action: Literal["remove", "keep", "aggregate"] = "remove",
) -> "PatternPairAggregator":
"""Add a pattern pair to detect in the text.
@@ -108,7 +108,12 @@ class PatternPairAggregator(BaseTextAggregator):
end_pattern: Pattern that marks the end of content.
type: The type of aggregation the matched content represents
(e.g., 'code', 'speaker', 'custom').
- remove_match: Whether to remove the matched content from the text returned.
+ action: What to do when a complete pattern is matched:
+ - "remove": Remove the matched pattern from the text.
+ - "keep": Keep the matched pattern in the text and treat it as
+ normal text.
+ - "aggregate": Return the matched pattern as a separate
+ aggregation object.
Returns:
Self for method chaining.
@@ -117,7 +122,7 @@ class PatternPairAggregator(BaseTextAggregator):
"start": start_pattern,
"end": end_pattern,
"type": type,
- "remove_match": remove_match,
+ "action": action,
}
return self
@@ -162,7 +167,7 @@ class PatternPairAggregator(BaseTextAggregator):
# Escape special regex characters in the patterns
start = re.escape(pattern_info["start"])
end = re.escape(pattern_info["end"])
- remove_match = pattern_info["remove_match"]
+ action = pattern_info["action"]
match_type = pattern_info["type"]
# Create regex to match from start pattern to end pattern
@@ -190,7 +195,7 @@ class PatternPairAggregator(BaseTextAggregator):
logger.error(f"Error in pattern handler for {pattern_id}: {e}")
# Remove the pattern from the text if configured
- if remove_match:
+ if action == "remove":
processed_text = processed_text.replace(full_match, "", 1)
# modified = True
else:
@@ -198,7 +203,7 @@ class PatternPairAggregator(BaseTextAggregator):
return all_matches, processed_text
- def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, str]]:
+ def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, dict]]:
"""Check if text contains incomplete pattern pairs.
Determines whether the text contains any start patterns without
@@ -225,9 +230,9 @@ class PatternPairAggregator(BaseTextAggregator):
# Which is why we base the return on the first found.
if start_count > end_count:
start_index = text.find(start)
- return [start_index, pattern_info["type"]]
+ return [start_index, pattern_info]
- return None, None
+ return None
async def aggregate(self, text: str) -> Optional[PatternMatch]:
"""Aggregate text and process pattern pairs.
@@ -258,17 +263,22 @@ class PatternPairAggregator(BaseTextAggregator):
logger.warning(
f"Multiple patterns matched: {[p.pattern_id for p in patterns]}. Only the first pattern will be returned."
)
- self._text = ""
- return patterns[0]
+ # If the pattern found is set to be aggregated, return it
+ action = self._patterns[patterns[0].pattern_id].get("action", "remove")
+ if action == "aggregate":
+ self._text = ""
+ print(f"Returning pattern: {patterns[0]}")
+ return patterns[0]
# Check if we have incomplete patterns
- start, curtype = self._match_start_of_pattern(self._text)
- if start is not None:
- # Still waiting for complete patterns
- if start == 0:
+ pattern_start = self._match_start_of_pattern(self._text)
+ if pattern_start is not None:
+ # If the start pattern is at the beginning or should not be separately aggregated, return None
+ if pattern_start[0] == 0 or pattern_start[1].get("action", "remove") != "aggregate":
return None
- result = self._text[:start]
- self._text = self._text[start:]
+ # Otherwise, strip the text up to the start pattern and return it
+ result = self._text[: pattern_start[0]]
+ self._text = self._text[pattern_start[0] :]
return PatternMatch(f"_sentence", result, result, "sentence")
# Find sentence boundary if no incomplete patterns
diff --git a/tests/test_pattern_pair_aggregator.py b/tests/test_pattern_pair_aggregator.py
index d0e13ffc6..310ccf635 100644
--- a/tests/test_pattern_pair_aggregator.py
+++ b/tests/test_pattern_pair_aggregator.py
@@ -14,6 +14,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.aggregator = PatternPairAggregator()
self.test_handler = AsyncMock()
+ self.code_handler = AsyncMock()
# Add a test pattern
self.aggregator.add_pattern_pair(
@@ -21,22 +22,24 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
start_pattern="",
end_pattern="",
type="test",
- remove_match=True,
+ action="remove",
)
self.aggregator.add_pattern_pair(
pattern_id="code_pattern",
start_pattern="",
end_pattern="",
type="code",
- remove_match=False,
+ action="aggregate",
)
# Register the mock handler
self.aggregator.on_pattern_match("test_pattern", self.test_handler)
+ self.aggregator.on_pattern_match("code_pattern", self.code_handler)
async def test_pattern_match_and_removal(self):
# First part doesn't complete the pattern
result = await self.aggregator.aggregate("Hello pattern")
+ print(f"result: {result}")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text.text, "Hello pattern")
self.assertEqual(self.aggregator.text.type, "test")
@@ -50,7 +53,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.assertIsInstance(call_args, PatternMatch)
self.assertEqual(call_args.pattern_id, "test_pattern")
self.assertEqual(call_args.full_match, "pattern content")
- self.assertEqual(call_args.content, "pattern content")
+ self.assertEqual(call_args.text, "pattern content")
# The exclamation point should be treated as a sentence boundary,
# so the result should include just text up to and including "!"
@@ -64,6 +67,33 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Buffer should be empty after returning a complete sentence
self.assertEqual(self.aggregator.text.text, "")
+ async def test_pattern_match_and_aggregate(self):
+ # First part doesn't complete the pattern
+ result = await self.aggregator.aggregate("Here is code pattern")
+ print(f"result: {result}")
+ self.assertEqual(result.text, "Here is code ")
+ self.assertEqual(self.aggregator.text.text, "pattern")
+ self.assertEqual(self.aggregator.text.type, "code")
+
+ # Second part completes the pattern and includes an exclamation point
+ result = await self.aggregator.aggregate(" content")
+
+ # Verify the handler was called with correct PatternMatch object
+ self.code_handler.assert_called_once()
+ call_args = self.code_handler.call_args[0][0]
+ self.assertIsInstance(call_args, PatternMatch)
+ self.assertEqual(call_args.pattern_id, "code_pattern")
+ self.assertEqual(call_args.full_match, "pattern content")
+ self.assertEqual(call_args.text, "pattern content")
+
+ # Next sentence should be processed separately
+ result = await self.aggregator.aggregate(" This is another sentence.")
+ self.assertEqual(result.text, " This is another sentence.")
+ self.assertEqual(result.type, "sentence")
+
+ # Buffer should be empty after returning a complete sentence
+ self.assertEqual(self.aggregator.text.text, "")
+
async def test_incomplete_pattern(self):
# Add text with incomplete pattern
result = await self.aggregator.aggregate("Hello pattern content")
@@ -88,14 +118,19 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
emphasis_handler = AsyncMock()
self.aggregator.add_pattern_pair(
- pattern_id="voice", start_pattern="", end_pattern="", remove_match=True
+ pattern_id="voice",
+ start_pattern="",
+ end_pattern="",
+ type="voice",
+ action="remove",
)
self.aggregator.add_pattern_pair(
pattern_id="emphasis",
start_pattern="",
end_pattern="",
- remove_match=False, # Keep emphasis tags
+ type="emphasis",
+ action="keep", # Keep emphasis tags
)
self.aggregator.on_pattern_match("voice", voice_handler)
@@ -109,15 +144,15 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
voice_handler.assert_called_once()
voice_match = voice_handler.call_args[0][0]
self.assertEqual(voice_match.pattern_id, "voice")
- self.assertEqual(voice_match.content, "female")
+ self.assertEqual(voice_match.text, "female")
emphasis_handler.assert_called_once()
emphasis_match = emphasis_handler.call_args[0][0]
self.assertEqual(emphasis_match.pattern_id, "emphasis")
- self.assertEqual(emphasis_match.content, "very")
+ self.assertEqual(emphasis_match.text, "very")
# Voice pattern should be removed, emphasis pattern should remain
- self.assertEqual(result, "Hello I am very excited to meet you!")
+ self.assertEqual(result.text, "Hello I am very excited to meet you!")
# Buffer should be empty
self.assertEqual(self.aggregator.text.text, "")
@@ -149,10 +184,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Handler should be called with entire content
self.test_handler.assert_called_once()
call_args = self.test_handler.call_args[0][0]
- self.assertEqual(call_args.content, "This is sentence one. This is sentence two.")
+ self.assertEqual(call_args.text, "This is sentence one. This is sentence two.")
# Pattern should be removed, resulting in text with sentences merged
- self.assertEqual(result, "Hello Final sentence.")
+ self.assertEqual(result.text, "Hello Final sentence.")
# Buffer should be empty
self.assertEqual(self.aggregator.text.text, "")
diff --git a/tests/test_piper_tts.py b/tests/test_piper_tts.py
index 75893f93f..209e6e76c 100644
--- a/tests/test_piper_tts.py
+++ b/tests/test_piper_tts.py
@@ -74,6 +74,7 @@ async def test_run_piper_tts_success(aiohttp_client):
]
expected_returned_frames = [
+ TTSTextFrame,
TTSStartedFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
@@ -121,7 +122,7 @@ async def test_run_piper_tts_error(aiohttp_client):
TTSSpeakFrame(text="Error case."),
]
- expected_down_frames = [TTSStoppedFrame, TTSTextFrame]
+ expected_down_frames = [TTSTextFrame, TTSStoppedFrame, TTSTextFrame]
expected_up_frames = [ErrorFrame]
diff --git a/tests/test_simple_text_aggregator.py b/tests/test_simple_text_aggregator.py
index ff6dd1847..007549458 100644
--- a/tests/test_simple_text_aggregator.py
+++ b/tests/test_simple_text_aggregator.py
@@ -15,15 +15,20 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):
async def test_reset_aggregations(self):
assert await self.aggregator.aggregate("Hello ") == None
- assert self.aggregator.text == "Hello "
+ assert self.aggregator.text.text == "Hello "
await self.aggregator.reset()
- assert self.aggregator.text == ""
+ assert self.aggregator.text.text == ""
async def test_simple_sentence(self):
assert await self.aggregator.aggregate("Hello ") == None
- assert await self.aggregator.aggregate("Pipecat!") == "Hello Pipecat!"
- assert self.aggregator.text == ""
+ aggregate = await self.aggregator.aggregate("Pipecat!")
+ assert aggregate.text == "Hello Pipecat!"
+ assert aggregate.type == "sentence"
+ assert self.aggregator.text.text == ""
async def test_multiple_sentences(self):
- assert await self.aggregator.aggregate("Hello Pipecat! How are ") == "Hello Pipecat!"
- assert await self.aggregator.aggregate("you?") == " How are you?"
+ aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ")
+ assert aggregate.text == "Hello Pipecat!"
+ assert self.aggregator.text.text == " How are "
+ aggregate = await self.aggregator.aggregate("you?")
+ assert aggregate.text == " How are you?"