From dcc20f86e184696bb90823be5277d6f2c4c6f736 Mon Sep 17 00:00:00 2001 From: mattie ruth backman Date: Fri, 14 Nov 2025 18:11:55 -0500 Subject: [PATCH] Updated the BaseTextAggregator to categorize aggregations Modified the BaseTextAggregator type so that when text gets aggregated, metadata can be associated with it. Currently, that just means a `type`, so that the aggregation can be classified or described. Changes made to support this: - **IMPORTANT**: Aggregators are now expected to strip leading/trailing white space characters before returning their aggregation from `aggregation()` or `.text`. This way all aggregators have a consistent contract allowing downstream use to know how to stitch aggregations back together - Introduced a new `Aggregation` dataclass to represent both the aggregated `text` and a string identifying the `type` of aggregation (ex. "sentence", "word", "my custom aggregation") - **BREAKING**: `BaseTextAggregator.text` now returns an `Aggregation` (instead of `str`). To update: `aggregated_text = myAggregator.text` -> `aggregated_text = myAggregator.text.text` - **BREAKING**: `BaseTextAggregator.aggregate()` now returns `Optional[Aggregation]` (instead of `Optional[str]`). To update: ``` aggregation = myAggregator.aggregate(text) if (aggregation): print(f"successfully aggregated text: {aggregation.text}") // instead of {aggregation} ``` - `SimpleTextAggregator`, `SkipTagsAggregator`, `PatternPairAggregator` updated to produce/consume `Aggregation` objects. - All uses of the above Aggregators have been updated accordingly. --- CHANGELOG.md | 24 ++++++++ src/pipecat/extensions/ivr/ivr_navigator.py | 2 +- src/pipecat/services/tts_service.py | 29 ++++----- src/pipecat/tests/utils.py | 12 +++- .../utils/text/base_text_aggregator.py | 60 ++++++++++++++++--- .../utils/text/pattern_pair_aggregator.py | 15 ++--- .../utils/text/simple_text_aggregator.py | 12 ++-- .../utils/text/skip_tags_aggregator.py | 15 ++--- tests/test_pattern_pair_aggregator.py | 30 ++++++---- tests/test_simple_text_aggregator.py | 18 ++++-- tests/test_skip_tags_aggregator.py | 24 ++++---- 11 files changed, 166 insertions(+), 75 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 843fd98c4..497e97f2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -84,6 +84,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Croatian, Hungarian, Malay, Norwegian, Nynorsk, Slovak, Slovenian, Swedish, and Tamil -- Added new emotions: calm and fluent +- `BaseTextAggregator` changes: + Modified the BaseTextAggregator type so that when text gets aggregated, metadata can + be associated with it. Currently, that just means a `type`, so that the aggregation + can be classified or described. Changes made to support this: + - **IMPORTANT**: Aggregators are now expected to strip leading/trailing white space + characters before returning their aggregation from `aggregation()` or `.text`. This + way all aggregators have a consistent contract allowing downstream use to know how + to stitch aggregations back together. + - Introduced a new `Aggregation` dataclass to represent both the aggregated `text` and + a string identifying the `type` of aggregation (ex. "sentence", "word", "my custom + aggregation") + - **BREAKING**: `BaseTextAggregator.text` now returns an `Aggregation` (instead of `str`). + To update: `aggregated_text = myAggregator.text` -> `aggregated_text = myAggregator.text.text` + - **BREAKING**: `BaseTextAggregator.aggregate()` now returns `Optional[Aggregation]` + (instead of `Optional[str]`). To update: + ``` + aggregation = myAggregator.aggregate(text) + if (aggregation): + print(f"successfully aggregated text: {aggregation.text}") // instead of {aggregation} + ``` + - `SimpleTextAggregator`, `SkipTagsAggregator`, `PatternPairAggregator` updated to + produce/consume `Aggregation` objects. + - All uses of the above Aggregators have been updated accordingly. + ### Deprecated - The `api_key` parameter in `GeminiTTSService` is deprecated. Use diff --git a/src/pipecat/extensions/ivr/ivr_navigator.py b/src/pipecat/extensions/ivr/ivr_navigator.py index 05748d94f..9c59772ad 100644 --- a/src/pipecat/extensions/ivr/ivr_navigator.py +++ b/src/pipecat/extensions/ivr/ivr_navigator.py @@ -148,7 +148,7 @@ class IVRProcessor(FrameProcessor): result = await self._aggregator.aggregate(frame.text) if result: # Push aggregated text that doesn't contain XML patterns - await self.push_frame(LLMTextFrame(result), direction) + await self.push_frame(LLMTextFrame(result.text), direction) else: await self.push_frame(frame, direction) diff --git a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index f0d602a40..33cf7d103 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -142,7 +142,6 @@ class TTSService(AIService): self._voice_id: str = "" self._settings: Dict[str, Any] = {} self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator() - self._aggregated_text_includes_inter_frame_spaces: bool = False self._text_filters: Sequence[BaseTextFilter] = text_filters or [] self._transport_destination: Optional[str] = transport_destination self._tracing_enabled: bool = False @@ -352,17 +351,14 @@ class TTSService(AIService): # pause to avoid audio overlapping. await self._maybe_pause_frame_processing() - sentence = self._text_aggregator.text - includes_inter_frame_spaces = self._aggregated_text_includes_inter_frame_spaces + pending_aggregation = self._text_aggregator.text # Reset aggregator state await self._text_aggregator.reset() self._processing_text = False - self._aggregated_text_includes_inter_frame_spaces = False - await self._push_tts_frames( - sentence, includes_inter_frame_spaces=includes_inter_frame_spaces - ) + if pending_aggregation.text: + await self._push_tts_frames(pending_aggregation.text) if isinstance(frame, LLMFullResponseEndFrame): if self._push_text_frames: await self.push_frame(frame, direction) @@ -372,7 +368,7 @@ class TTSService(AIService): # Store if we were processing text or not so we can set it back. processing_text = self._processing_text # Assumption: text in TTSSpeakFrame does not include inter-frame spaces - await self._push_tts_frames(frame.text, includes_inter_frame_spaces=False) + await self._push_tts_frames(frame.text) # We pause processing incoming frames because we are sending data to # the TTS. We pause to avoid audio overlapping. await self._maybe_pause_frame_processing() @@ -462,21 +458,20 @@ class TTSService(AIService): async def _process_text_frame(self, frame: TextFrame): text: Optional[str] = None + includes_inter_frame_spaces: bool = False if not self._aggregate_sentences: text = frame.text + includes_inter_frame_spaces = frame.includes_inter_frame_spaces else: - text = await self._text_aggregator.aggregate(frame.text) - # Assumption: whether inter-frame spaces are included shouldn't - # change during aggregation, so we can just use the latest frame's - # value - self._aggregated_text_includes_inter_frame_spaces = frame.includes_inter_frame_spaces + aggregation = await self._text_aggregator.aggregate(frame.text) + text = aggregation.text if text: - await self._push_tts_frames( - text, includes_inter_frame_spaces=frame.includes_inter_frame_spaces - ) + await self._push_tts_frames(text, includes_inter_frame_spaces) - async def _push_tts_frames(self, text: str, includes_inter_frame_spaces: bool): + async def _push_tts_frames( + self, text: str, includes_inter_frame_spaces: Optional[bool] = False + ): # Remove leading newlines only text = text.lstrip("\n") diff --git a/src/pipecat/tests/utils.py b/src/pipecat/tests/utils.py index 6ccce4b31..94b8cb1a4 100644 --- a/src/pipecat/tests/utils.py +++ b/src/pipecat/tests/utils.py @@ -203,8 +203,16 @@ async def run_test( if not isinstance(frame, EndFrame) or not send_end_frame: received_down_frames.append(frame) - print("received DOWN frames =", received_down_frames) - print("expected DOWN frames =", expected_down_frames) + down_frames_printed = "[" + for frame in received_down_frames: + down_frames_printed += f"{frame.__class__.__name__}, " + down_frames_printed += "]" + expected_frames_printed = "[" + for frame in expected_down_frames: + expected_frames_printed += f"{frame.__name__}, " + expected_frames_printed += "]" + print("received DOWN frames =", down_frames_printed) + print("expected DOWN frames =", expected_frames_printed) assert len(received_down_frames) == len(expected_down_frames) diff --git a/src/pipecat/utils/text/base_text_aggregator.py b/src/pipecat/utils/text/base_text_aggregator.py index 27e50fff5..07cb6c097 100644 --- a/src/pipecat/utils/text/base_text_aggregator.py +++ b/src/pipecat/utils/text/base_text_aggregator.py @@ -12,9 +12,47 @@ aggregated text should be sent for speech synthesis. """ from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum from typing import Optional +class AggregationType(str, Enum): + """Built-in aggregation strings.""" + + SENTENCE = "sentence" + WORD = "word" + + def __str__(self): + return self.value + + +@dataclass +class Aggregation: + """Data class representing aggregated text and its type. + + An Aggregation object is created whenever a stream of text is aggregated by + a text aggregator. It contains the aggregated text and a type indicating + the nature of the aggregation. + + Parameters: + text: The aggregated text content. + type: The type of aggregation the text represents (e.g., 'sentence', 'word', 'token', + 'my_custom_aggregation'). + """ + + text: str + type: str + + def __str__(self) -> str: + """Return a string representation of the aggregation. + + Returns: + A descriptive string showing the type and text of the aggregation. + """ + return f"Aggregation by {self.type}: {self.text}" + + class BaseTextAggregator(ABC): """Base class for text aggregators in the Pipecat framework. @@ -30,7 +68,7 @@ class BaseTextAggregator(ABC): @property @abstractmethod - def text(self) -> str: + def text(self) -> Aggregation: """Get the currently aggregated text. Subclasses must implement this property to return the text that has @@ -42,25 +80,33 @@ class BaseTextAggregator(ABC): pass @abstractmethod - async def aggregate(self, text: str) -> Optional[str]: + async def aggregate(self, text: str) -> Optional[Aggregation]: """Aggregate the specified text with the currently accumulated text. This method should be implemented to define how the new text contributes - to the aggregation process. It returns the updated aggregated text if - it's ready to be processed, or None otherwise. + to the aggregation process. It returns the aggregated text and a string + describing how it was aggregated if it's ready to be processed, + or None otherwise. Subclasses should implement their specific logic for: - How to combine new text with existing accumulated text - When to consider the aggregated text ready for processing - What criteria determine text completion (e.g., sentence boundaries) + - When a completion occurs, the method should return an Aggregation object + containing the aggregated text and its type. The text should be stripped + of leading/trailing whitespace so that consumers can rely on a consistent + format. Args: - text: The text to be aggregated. + text: The text to be aggregated Returns: - The updated aggregated text if ready for processing, or None if more - text is needed before the aggregated content is ready. + An Aggregation object if ready for processing, or None if more + text is needed before the aggregated content is ready. If an Aggregation + object is returned, it should consist of the updated aggregated text, + stripped of leading/trailing whitespace, and a string indicating the + type of aggregation (e.g., 'sentence', 'word', 'token', 'my_custom_aggregation'). """ pass diff --git a/src/pipecat/utils/text/pattern_pair_aggregator.py b/src/pipecat/utils/text/pattern_pair_aggregator.py index ac074f2de..5c4cafd53 100644 --- a/src/pipecat/utils/text/pattern_pair_aggregator.py +++ b/src/pipecat/utils/text/pattern_pair_aggregator.py @@ -17,7 +17,7 @@ from typing import Awaitable, Callable, Optional, Tuple from loguru import logger from pipecat.utils.string import match_endofsentence -from pipecat.utils.text.base_text_aggregator import BaseTextAggregator +from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator class PatternMatch: @@ -75,13 +75,13 @@ class PatternPairAggregator(BaseTextAggregator): self._handlers = {} @property - def text(self) -> str: + def text(self) -> Aggregation: """Get the currently buffered text. Returns: The current text buffer content that hasn't been processed yet. """ - return self._text + return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE) def add_pattern_pair( self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True @@ -208,7 +208,7 @@ class PatternPairAggregator(BaseTextAggregator): return False - async def aggregate(self, text: str) -> Optional[str]: + async def aggregate(self, text: str) -> Optional[Aggregation]: """Aggregate text and process pattern pairs. This method adds the new text to the buffer, processes any complete pattern @@ -220,8 +220,9 @@ class PatternPairAggregator(BaseTextAggregator): text: New text to add to the buffer. Returns: - Processed text up to a sentence boundary, or None if more - text is needed to form a complete sentence or pattern. + An Aggregation object containing processed text up to a sentence boundary + and marked as SENTENCE type, or None if more text is needed to form a + complete sentence or pattern. """ # Add new text to buffer self._text += text @@ -244,7 +245,7 @@ class PatternPairAggregator(BaseTextAggregator): # Extract text up to the sentence boundary result = self._text[:eos_marker] self._text = self._text[eos_marker:] - return result + return Aggregation(text=result.strip(), type=AggregationType.SENTENCE) # No complete sentence found yet return None diff --git a/src/pipecat/utils/text/simple_text_aggregator.py b/src/pipecat/utils/text/simple_text_aggregator.py index f9eb7d83a..56eab7032 100644 --- a/src/pipecat/utils/text/simple_text_aggregator.py +++ b/src/pipecat/utils/text/simple_text_aggregator.py @@ -14,7 +14,7 @@ text processing scenarios. from typing import Optional from pipecat.utils.string import match_endofsentence -from pipecat.utils.text.base_text_aggregator import BaseTextAggregator +from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator class SimpleTextAggregator(BaseTextAggregator): @@ -33,15 +33,15 @@ class SimpleTextAggregator(BaseTextAggregator): self._text = "" @property - def text(self) -> str: + def text(self) -> Aggregation: """Get the currently aggregated text. Returns: The text that has been accumulated in the buffer. """ - return self._text + return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE) - async def aggregate(self, text: str) -> Optional[str]: + async def aggregate(self, text: str) -> Optional[Aggregation]: """Aggregate text and return completed sentences. Adds the new text to the buffer and checks for end-of-sentence markers. @@ -64,7 +64,9 @@ class SimpleTextAggregator(BaseTextAggregator): result = self._text[:eos_end_marker] self._text = self._text[eos_end_marker:] - return result + if result: + return Aggregation(text=result.strip(), type=AggregationType.SENTENCE) + return None async def handle_interruption(self): """Handle interruptions by clearing the text buffer. diff --git a/src/pipecat/utils/text/skip_tags_aggregator.py b/src/pipecat/utils/text/skip_tags_aggregator.py index 6f6f8455c..3c8b95aab 100644 --- a/src/pipecat/utils/text/skip_tags_aggregator.py +++ b/src/pipecat/utils/text/skip_tags_aggregator.py @@ -14,7 +14,7 @@ as a unit regardless of internal punctuation. from typing import Optional, Sequence from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags -from pipecat.utils.text.base_text_aggregator import BaseTextAggregator +from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator class SkipTagsAggregator(BaseTextAggregator): @@ -43,15 +43,15 @@ class SkipTagsAggregator(BaseTextAggregator): self._current_tag_index: int = 0 @property - def text(self) -> str: + def text(self) -> Aggregation: """Get the currently buffered text. Returns: The current text buffer content that hasn't been processed yet. """ - return self._text + return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE) - async def aggregate(self, text: str) -> Optional[str]: + async def aggregate(self, text: str) -> Optional[Aggregation]: """Aggregate text while respecting tag boundaries. This method adds the new text to the buffer, processes any complete @@ -63,8 +63,9 @@ class SkipTagsAggregator(BaseTextAggregator): text: New text to add to the buffer. Returns: - Processed text up to a sentence boundary (when not within tags), - or None if more text is needed to complete a sentence or close tags. + An Aggregation object containing text up to a sentence boundary and + marked as SENTENCE type or None if more text is needed to complete a + sentence or close tags. """ # Add new text to buffer self._text += text @@ -80,7 +81,7 @@ class SkipTagsAggregator(BaseTextAggregator): # Extract text up to the sentence boundary result = self._text[:eos_marker] self._text = self._text[eos_marker:] - return result + return Aggregation(text=result.strip(), type=AggregationType.SENTENCE) # No complete sentence found yet return None diff --git a/tests/test_pattern_pair_aggregator.py b/tests/test_pattern_pair_aggregator.py index 8426dcf39..01e45f9bd 100644 --- a/tests/test_pattern_pair_aggregator.py +++ b/tests/test_pattern_pair_aggregator.py @@ -30,7 +30,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): # First part doesn't complete the pattern result = await self.aggregator.aggregate("Hello pattern") self.assertIsNone(result) - self.assertEqual(self.aggregator.text, "Hello pattern") + self.assertEqual(self.aggregator.text.text, "Hello pattern") # Second part completes the pattern and includes an exclamation point result = await self.aggregator.aggregate(" content!") @@ -45,14 +45,16 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): # The exclamation point should be treated as a sentence boundary, # so the result should include just text up to and including "!" - self.assertEqual(result, "Hello !") + self.assertEqual(result.text, "Hello !") + self.assertEqual(result.type, "sentence") - # Next sentence should be processed separately + # Next sentence should be processed separately. Spaces around the sentence + # should be stripped in the returned Aggregation. result = await self.aggregator.aggregate(" This is another sentence.") - self.assertEqual(result, " This is another sentence.") - + self.assertEqual(result.text, "This is another sentence.") + self.assertEqual(result.type, "sentence") # Buffer should be empty after returning a complete sentence - self.assertEqual(self.aggregator.text, "") + self.assertEqual(self.aggregator.text.text, "") async def test_incomplete_pattern(self): # Add text with incomplete pattern @@ -65,11 +67,11 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): self.test_handler.assert_not_called() # Buffer should contain the incomplete text - self.assertEqual(self.aggregator.text, "Hello pattern content") + self.assertEqual(self.aggregator.text.text, "Hello pattern content") # Reset and confirm buffer is cleared await self.aggregator.reset() - self.assertEqual(self.aggregator.text, "") + self.assertEqual(self.aggregator.text.text, "") async def test_multiple_patterns(self): # Set up multiple patterns and handlers @@ -106,10 +108,11 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): self.assertEqual(emphasis_match.content, "very") # Voice pattern should be removed, emphasis pattern should remain - self.assertEqual(result, "Hello I am very excited to meet you!") + self.assertEqual(result.text, "Hello I am very excited to meet you!") + self.assertEqual(result.type, "sentence") # Buffer should be empty - self.assertEqual(self.aggregator.text, "") + self.assertEqual(self.aggregator.text.text, "") async def test_handle_interruption(self): # Start with incomplete pattern @@ -120,7 +123,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): await self.aggregator.handle_interruption() # Buffer should be cleared - self.assertEqual(self.aggregator.text, "") + self.assertEqual(self.aggregator.text.text, "") # Handler should not have been called self.test_handler.assert_not_called() @@ -141,7 +144,8 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): self.assertEqual(call_args.content, "This is sentence one. This is sentence two.") # Pattern should be removed, resulting in text with sentences merged - self.assertEqual(result, "Hello Final sentence.") + self.assertEqual(result.text, "Hello Final sentence.") + self.assertEqual(result.type, "sentence") # Buffer should be empty - self.assertEqual(self.aggregator.text, "") + self.assertEqual(self.aggregator.text.text, "") diff --git a/tests/test_simple_text_aggregator.py b/tests/test_simple_text_aggregator.py index ff6dd1847..f8e2ee553 100644 --- a/tests/test_simple_text_aggregator.py +++ b/tests/test_simple_text_aggregator.py @@ -15,15 +15,21 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase): async def test_reset_aggregations(self): assert await self.aggregator.aggregate("Hello ") == None - assert self.aggregator.text == "Hello " + assert self.aggregator.text.text == "Hello" await self.aggregator.reset() - assert self.aggregator.text == "" + assert self.aggregator.text.text == "" async def test_simple_sentence(self): assert await self.aggregator.aggregate("Hello ") == None - assert await self.aggregator.aggregate("Pipecat!") == "Hello Pipecat!" - assert self.aggregator.text == "" + aggregate = await self.aggregator.aggregate("Pipecat!") + assert aggregate.text == "Hello Pipecat!" + assert aggregate.type == "sentence" + assert self.aggregator.text.text == "" async def test_multiple_sentences(self): - assert await self.aggregator.aggregate("Hello Pipecat! How are ") == "Hello Pipecat!" - assert await self.aggregator.aggregate("you?") == " How are you?" + aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ") + assert aggregate.text == "Hello Pipecat!" + # Aggregators should strip leading/trailing spaces when returning text + assert self.aggregator.text.text == "How are" + aggregate = await self.aggregator.aggregate("you?") + assert aggregate.text == "How are you?" diff --git a/tests/test_skip_tags_aggregator.py b/tests/test_skip_tags_aggregator.py index f6cbb7b93..702b991ce 100644 --- a/tests/test_skip_tags_aggregator.py +++ b/tests/test_skip_tags_aggregator.py @@ -18,16 +18,18 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase): # No tags involved, aggregate at end of sentence. result = await self.aggregator.aggregate("Hello Pipecat!") - self.assertEqual(result, "Hello Pipecat!") - self.assertEqual(self.aggregator.text, "") + self.assertEqual(result.text, "Hello Pipecat!") + self.assertEqual(result.type, "sentence") + self.assertEqual(self.aggregator.text.text, "") async def test_basic_tags(self): await self.aggregator.reset() # Tags involved, avoid aggregation during tags. result = await self.aggregator.aggregate("My email is foo@pipecat.ai.") - self.assertEqual(result, "My email is foo@pipecat.ai.") - self.assertEqual(self.aggregator.text, "") + self.assertEqual(result.text, "My email is foo@pipecat.ai.") + self.assertEqual(result.type, "sentence") + self.assertEqual(self.aggregator.text.text, "") async def test_streaming_tags(self): await self.aggregator.reset() @@ -35,20 +37,22 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase): # Tags involved, stream small chunk of texts. result = await self.aggregator.aggregate("My email is foo.") self.assertIsNone(result) - self.assertEqual(self.aggregator.text, "My email is foo.") + self.assertEqual(self.aggregator.text.text, "My email is foo.") result = await self.aggregator.aggregate("bar@pipecat.") self.assertIsNone(result) - self.assertEqual(self.aggregator.text, "My email is foo.bar@pipecat.") + self.assertEqual(self.aggregator.text.text, "My email is foo.bar@pipecat.") result = await self.aggregator.aggregate("aifoo.bar@pipecat.aifoo.bar@pipecat.ai.") - self.assertEqual(result, "My email is foo.bar@pipecat.ai.") - self.assertEqual(self.aggregator.text, "") + self.assertEqual(result.text, "My email is foo.bar@pipecat.ai.") + self.assertEqual(self.aggregator.text.text, "") + self.assertEqual(self.aggregator.text.type, "sentence")