Move the sentence-vs-token aggregation concern into the text aggregators so that all text flows through them regardless of mode. This enables pattern detection and tag handling to work in TOKEN mode.

- Add a TextAggregationMode enum (SENTENCE, TOKEN) as the user-facing TTS setting, separate from the internal AggregationType.
- Add TOKEN mode support to the Simple, SkipTags, and PatternPair aggregators.
- Add a text_aggregation_mode parameter to TTSService and all TTS subclasses.
- Deprecate aggregate_sentences in favor of text_aggregation_mode.
- Merge TTSService._process_text_frame() into a single code path.
260 lines
10 KiB
Python
260 lines
10 KiB
Python
#
|
|
# Copyright (c) 2024-2026, Daily
|
|
#
|
|
# SPDX-License-Identifier: BSD 2-Clause License
|
|
#
|
|
|
|
import unittest
from unittest.mock import AsyncMock

from pipecat.utils.text.base_text_aggregator import AggregationType
from pipecat.utils.text.pattern_pair_aggregator import (
    MatchAction,
    PatternMatch,
    PatternPairAggregator,
)
|
|
|
|
|
|
class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
|
def setUp(self):
|
|
self.aggregator = PatternPairAggregator()
|
|
self.test_handler = AsyncMock()
|
|
self.code_handler = AsyncMock()
|
|
|
|
# Add a test pattern
|
|
self.aggregator.add_pattern_pair(
|
|
pattern_id="test_pattern",
|
|
start_pattern="<test>",
|
|
end_pattern="</test>",
|
|
)
|
|
self.aggregator.add_pattern(
|
|
type="code_pattern",
|
|
start_pattern="<code>",
|
|
end_pattern="</code>",
|
|
action=MatchAction.AGGREGATE,
|
|
)
|
|
|
|
# Register the mock handler
|
|
self.aggregator.on_pattern_match("test_pattern", self.test_handler)
|
|
self.aggregator.on_pattern_match("code_pattern", self.code_handler)
|
|
|
|
async def test_pattern_match_and_removal(self):
|
|
text = "Hello <test>pattern content</test>!"
|
|
results = [result async for result in self.aggregator.aggregate(text)]
|
|
|
|
# Verify the handler was called with correct PatternMatch object
|
|
self.test_handler.assert_called_once()
|
|
call_args = self.test_handler.call_args[0][0]
|
|
self.assertIsInstance(call_args, PatternMatch)
|
|
self.assertEqual(call_args.type, "test_pattern")
|
|
self.assertEqual(call_args.full_match, "<test>pattern content</test>")
|
|
self.assertEqual(call_args.text, "pattern content")
|
|
|
|
# No results yet (waiting for lookahead after "!")
|
|
self.assertEqual(len(results), 0)
|
|
|
|
# Next sentence should provide the lookahead and trigger the previous sentence
|
|
async for result in self.aggregator.aggregate(" This is another sentence."):
|
|
results.append(result)
|
|
|
|
# First result should be "Hello !" triggered by the space lookahead
|
|
self.assertEqual(len(results), 1)
|
|
self.assertEqual(results[0].text, "Hello !")
|
|
self.assertEqual(results[0].type, "sentence")
|
|
|
|
# Now flush to get the remaining sentence
|
|
result = await self.aggregator.flush()
|
|
self.assertEqual(result.text, "This is another sentence.")
|
|
|
|
# Buffer should be empty after returning a complete sentence
|
|
self.assertEqual(self.aggregator.text.text, "")
|
|
|
|
async def test_pattern_match_and_aggregate(self):
|
|
text = "Here is code <code>pattern content</code> This is another sentence."
|
|
|
|
results = [result async for result in self.aggregator.aggregate(text)]
|
|
|
|
# First result should be "Here is code" when pattern starts
|
|
self.assertEqual(results[0].text, "Here is code")
|
|
self.assertEqual(results[0].type, "sentence")
|
|
|
|
# Second result should be the code pattern content
|
|
self.assertEqual(results[1].text, "pattern content")
|
|
self.assertEqual(results[1].type, "code_pattern")
|
|
|
|
# Verify the handler was called with correct PatternMatch object
|
|
self.code_handler.assert_called_once()
|
|
call_args = self.code_handler.call_args[0][0]
|
|
self.assertIsInstance(call_args, PatternMatch)
|
|
self.assertEqual(call_args.type, "code_pattern")
|
|
self.assertEqual(call_args.full_match, "<code>pattern content</code>")
|
|
self.assertEqual(call_args.text, "pattern content")
|
|
|
|
# Last sentence needs flush (waiting for lookahead after ".")
|
|
result = await self.aggregator.flush()
|
|
self.assertEqual(result.text, "This is another sentence.")
|
|
self.assertEqual(result.type, "sentence")
|
|
|
|
# Buffer should be empty after returning a complete sentence
|
|
self.assertEqual(self.aggregator.text.text, "")
|
|
|
|
async def test_incomplete_pattern(self):
|
|
text = "Hello <test>pattern content"
|
|
results = [result async for result in self.aggregator.aggregate(text)]
|
|
# No complete pattern yet, so nothing should be returned
|
|
self.assertEqual(len(results), 0)
|
|
|
|
# The handler should not be called yet
|
|
self.test_handler.assert_not_called()
|
|
|
|
# Buffer should contain the incomplete text
|
|
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern content")
|
|
self.assertEqual(self.aggregator.text.type, "test_pattern")
|
|
|
|
# Reset and confirm buffer is cleared
|
|
await self.aggregator.reset()
|
|
self.assertEqual(self.aggregator.text.text, "")
|
|
|
|
async def test_multiple_patterns(self):
|
|
# Set up multiple patterns and handlers
|
|
voice_handler = AsyncMock()
|
|
emphasis_handler = AsyncMock()
|
|
|
|
self.aggregator.add_pattern(
|
|
type="voice",
|
|
start_pattern="<voice>",
|
|
end_pattern="</voice>",
|
|
action=MatchAction.REMOVE,
|
|
)
|
|
|
|
self.aggregator.add_pattern(
|
|
type="emphasis",
|
|
start_pattern="<em>",
|
|
end_pattern="</em>",
|
|
action=MatchAction.KEEP, # Keep emphasis tags
|
|
)
|
|
|
|
self.aggregator.on_pattern_match("voice", voice_handler)
|
|
self.aggregator.on_pattern_match("emphasis", emphasis_handler)
|
|
|
|
text = "Hello <voice>female</voice> I am <em>very</em> excited to meet you!"
|
|
results = [result async for result in self.aggregator.aggregate(text)]
|
|
|
|
# Both handlers should be called with correct data
|
|
voice_handler.assert_called_once()
|
|
voice_match = voice_handler.call_args[0][0]
|
|
self.assertEqual(voice_match.type, "voice")
|
|
self.assertEqual(voice_match.text, "female")
|
|
|
|
emphasis_handler.assert_called_once()
|
|
emphasis_match = emphasis_handler.call_args[0][0]
|
|
self.assertEqual(emphasis_match.type, "emphasis")
|
|
self.assertEqual(emphasis_match.text, "very")
|
|
|
|
# With lookahead, we need to flush to get the final sentence
|
|
self.assertEqual(len(results), 0) # Waiting for lookahead after "!"
|
|
|
|
result = await self.aggregator.flush()
|
|
# Voice pattern should be removed, emphasis pattern should remain
|
|
self.assertEqual(result.text, "Hello I am <em>very</em> excited to meet you!")
|
|
|
|
# Buffer should be empty
|
|
self.assertEqual(self.aggregator.text.text, "")
|
|
|
|
async def test_handle_interruption(self):
|
|
text = "Hello <test>pattern"
|
|
results = [result async for result in self.aggregator.aggregate(text)]
|
|
self.assertEqual(len(results), 0)
|
|
|
|
# Simulate interruption
|
|
await self.aggregator.handle_interruption()
|
|
|
|
# Buffer should be cleared
|
|
self.assertEqual(self.aggregator.text.text, "")
|
|
|
|
# Handler should not have been called
|
|
self.test_handler.assert_not_called()
|
|
|
|
async def test_pattern_across_sentences(self):
|
|
text = "Hello <test>This is sentence one. This is sentence two.</test> Final sentence."
|
|
results = [result async for result in self.aggregator.aggregate(text)]
|
|
|
|
# Handler should be called with entire content
|
|
self.test_handler.assert_called_once()
|
|
call_args = self.test_handler.call_args[0][0]
|
|
self.assertEqual(call_args.text, "This is sentence one. This is sentence two.")
|
|
|
|
# With lookahead, we need to flush to get the final sentence
|
|
self.assertEqual(len(results), 0) # Waiting for lookahead after "."
|
|
|
|
result = await self.aggregator.flush()
|
|
# Pattern should be removed, resulting in text with sentences merged
|
|
self.assertEqual(result.text, "Hello Final sentence.")
|
|
|
|
# Buffer should be empty
|
|
self.assertEqual(self.aggregator.text.text, "")
|
|
|
|
|
|
class TestPatternPairAggregatorTokenMode(unittest.IsolatedAsyncioTestCase):
    """Tests for ``PatternPairAggregator`` built with ``AggregationType.TOKEN``.

    The fixture registers a single ``<think>`` pair with ``MatchAction.REMOVE``
    and an ``AsyncMock`` handler. Input is fed one token per ``aggregate()``
    call.
    """

    def setUp(self):
        """Create a TOKEN-mode aggregator with one removable ``<think>`` pattern."""
        # AggregationType comes from the module-level imports rather than a
        # function-scope import that would be re-executed before every test.
        self.aggregator = PatternPairAggregator(aggregation_type=AggregationType.TOKEN)
        self.handler = AsyncMock()
        self.aggregator.add_pattern(
            type="think",
            start_pattern="<think>",
            end_pattern="</think>",
            action=MatchAction.REMOVE,
        )
        self.aggregator.on_pattern_match("think", self.handler)

    async def _feed(self, tokens):
        """Pass each token to ``aggregate()`` in order and return all yielded results."""
        results = []
        for token in tokens:
            async for result in self.aggregator.aggregate(token):
                results.append(result)
        return results

    async def test_token_no_patterns(self):
        """Non-pattern text passes through as TOKEN, one per aggregate call."""
        results = await self._feed(["Hello", " world", "."])

        self.assertEqual(len(results), 3)
        self.assertEqual(results[0].text, "Hello")
        self.assertEqual(results[1].text, " world")
        self.assertEqual(results[2].text, ".")
        for result in results:
            self.assertEqual(result.type, "token")

    async def test_token_pattern_detection(self):
        """Pattern detection still works with word-by-word token delivery."""
        results = await self._feed(["Hi ", "<think>", "secret", "</think>", " bye"])

        # The handler fires once, when the pattern completes.
        self.handler.assert_called_once()
        match = self.handler.call_args[0][0]
        self.assertEqual(match.text, "secret")

        # "Hi " is yielded before the pattern starts, the pattern itself is
        # removed, and " bye" is yielded after it.
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].text, "Hi ")
        self.assertEqual(results[0].type, "token")
        self.assertEqual(results[1].text, " bye")
        self.assertEqual(results[1].type, "token")

    async def test_token_incomplete_pattern_buffers(self):
        """Incomplete pattern is buffered across calls, not leaked to output."""
        results = await self._feed(["Hi ", "<think>", "partial"])

        # Only "Hi " is yielded; "<think>partial" stays buffered.
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].text, "Hi ")
        self.assertEqual(results[0].type, "token")
        self.handler.assert_not_called()
|
|
|
|
|
|
# Allow the module to be executed directly as a script; unittest.main()
# then discovers and runs the test cases defined above.
if __name__ == "__main__":
    unittest.main()
|