Updated the BaseTextAggregator to categorize aggregations

Modified the BaseTextAggregator type so that when text gets aggregated, metadata can
be associated with it. Currently, that just means a `type`, so that the aggregation
can be classified or described. Changes made to support this:
  - **IMPORTANT**: Aggregators are now expected to strip leading/trailing white space
    characters before returning their aggregation from `aggregation()` or `.text`. This
    way all aggregators have a consistent contract allowing downstream use to know how
    to stitch aggregations back together
  - Introduced a new `Aggregation` dataclass to represent both the aggregated `text` and
    a string identifying the `type` of aggregation (ex. "sentence", "word", "my custom
    aggregation")
  - **BREAKING**: `BaseTextAggregator.text` now returns an `Aggregation` (instead of `str`).
    To update: `aggregated_text = myAggregator.text` -> `aggregated_text = myAggregator.text.text`
  - **BREAKING**: `BaseTextAggregator.aggregate()` now returns `Optional[Aggregation]`
    (instead of `Optional[str]`). To update:
      ```
      aggregation = myAggregator.aggregate(text)
      if (aggregation):
        print(f"successfully aggregated text: {aggregation.text}") // instead of {aggregation}
      ```
  - `SimpleTextAggregator`, `SkipTagsAggregator`, `PatternPairAggregator` updated to
     produce/consume `Aggregation` objects.
  - All uses of the above Aggregators have been updated accordingly.
This commit is contained in:
mattie ruth backman
2025-11-14 18:11:55 -05:00
committed by Mattie Ruth
parent 26918728df
commit dcc20f86e1
11 changed files with 166 additions and 75 deletions

View File

@@ -84,6 +84,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
Croatian, Hungarian, Malay, Norwegian, Nynorsk, Slovak, Slovenian, Swedish, and Tamil
-- Added new emotions: calm and fluent
- `BaseTextAggregator` changes:
Modified the BaseTextAggregator type so that when text gets aggregated, metadata can
be associated with it. Currently, that just means a `type`, so that the aggregation
can be classified or described. Changes made to support this:
- **IMPORTANT**: Aggregators are now expected to strip leading/trailing white space
characters before returning their aggregation from `aggregation()` or `.text`. This
way all aggregators have a consistent contract allowing downstream use to know how
to stitch aggregations back together.
- Introduced a new `Aggregation` dataclass to represent both the aggregated `text` and
a string identifying the `type` of aggregation (ex. "sentence", "word", "my custom
aggregation")
- **BREAKING**: `BaseTextAggregator.text` now returns an `Aggregation` (instead of `str`).
To update: `aggregated_text = myAggregator.text` -> `aggregated_text = myAggregator.text.text`
- **BREAKING**: `BaseTextAggregator.aggregate()` now returns `Optional[Aggregation]`
(instead of `Optional[str]`). To update:
```
aggregation = myAggregator.aggregate(text)
if (aggregation):
print(f"successfully aggregated text: {aggregation.text}") // instead of {aggregation}
```
- `SimpleTextAggregator`, `SkipTagsAggregator`, `PatternPairAggregator` updated to
produce/consume `Aggregation` objects.
- All uses of the above Aggregators have been updated accordingly.
### Deprecated
- The `api_key` parameter in `GeminiTTSService` is deprecated. Use

View File

@@ -148,7 +148,7 @@ class IVRProcessor(FrameProcessor):
result = await self._aggregator.aggregate(frame.text)
if result:
# Push aggregated text that doesn't contain XML patterns
await self.push_frame(LLMTextFrame(result), direction)
await self.push_frame(LLMTextFrame(result.text), direction)
else:
await self.push_frame(frame, direction)

View File

@@ -142,7 +142,6 @@ class TTSService(AIService):
self._voice_id: str = ""
self._settings: Dict[str, Any] = {}
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
self._aggregated_text_includes_inter_frame_spaces: bool = False
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
self._transport_destination: Optional[str] = transport_destination
self._tracing_enabled: bool = False
@@ -352,17 +351,14 @@ class TTSService(AIService):
# pause to avoid audio overlapping.
await self._maybe_pause_frame_processing()
sentence = self._text_aggregator.text
includes_inter_frame_spaces = self._aggregated_text_includes_inter_frame_spaces
pending_aggregation = self._text_aggregator.text
# Reset aggregator state
await self._text_aggregator.reset()
self._processing_text = False
self._aggregated_text_includes_inter_frame_spaces = False
await self._push_tts_frames(
sentence, includes_inter_frame_spaces=includes_inter_frame_spaces
)
if pending_aggregation.text:
await self._push_tts_frames(pending_aggregation.text)
if isinstance(frame, LLMFullResponseEndFrame):
if self._push_text_frames:
await self.push_frame(frame, direction)
@@ -372,7 +368,7 @@ class TTSService(AIService):
# Store if we were processing text or not so we can set it back.
processing_text = self._processing_text
# Assumption: text in TTSSpeakFrame does not include inter-frame spaces
await self._push_tts_frames(frame.text, includes_inter_frame_spaces=False)
await self._push_tts_frames(frame.text)
# We pause processing incoming frames because we are sending data to
# the TTS. We pause to avoid audio overlapping.
await self._maybe_pause_frame_processing()
@@ -462,21 +458,20 @@ class TTSService(AIService):
async def _process_text_frame(self, frame: TextFrame):
text: Optional[str] = None
includes_inter_frame_spaces: bool = False
if not self._aggregate_sentences:
text = frame.text
includes_inter_frame_spaces = frame.includes_inter_frame_spaces
else:
text = await self._text_aggregator.aggregate(frame.text)
# Assumption: whether inter-frame spaces are included shouldn't
# change during aggregation, so we can just use the latest frame's
# value
self._aggregated_text_includes_inter_frame_spaces = frame.includes_inter_frame_spaces
aggregation = await self._text_aggregator.aggregate(frame.text)
text = aggregation.text
if text:
await self._push_tts_frames(
text, includes_inter_frame_spaces=frame.includes_inter_frame_spaces
)
await self._push_tts_frames(text, includes_inter_frame_spaces)
async def _push_tts_frames(self, text: str, includes_inter_frame_spaces: bool):
async def _push_tts_frames(
self, text: str, includes_inter_frame_spaces: Optional[bool] = False
):
# Remove leading newlines only
text = text.lstrip("\n")

View File

@@ -203,8 +203,16 @@ async def run_test(
if not isinstance(frame, EndFrame) or not send_end_frame:
received_down_frames.append(frame)
print("received DOWN frames =", received_down_frames)
print("expected DOWN frames =", expected_down_frames)
down_frames_printed = "["
for frame in received_down_frames:
down_frames_printed += f"{frame.__class__.__name__}, "
down_frames_printed += "]"
expected_frames_printed = "["
for frame in expected_down_frames:
expected_frames_printed += f"{frame.__name__}, "
expected_frames_printed += "]"
print("received DOWN frames =", down_frames_printed)
print("expected DOWN frames =", expected_frames_printed)
assert len(received_down_frames) == len(expected_down_frames)

View File

@@ -12,9 +12,47 @@ aggregated text should be sent for speech synthesis.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional
class AggregationType(str, Enum):
"""Built-in aggregation strings."""
SENTENCE = "sentence"
WORD = "word"
def __str__(self):
return self.value
@dataclass
class Aggregation:
"""Data class representing aggregated text and its type.
An Aggregation object is created whenever a stream of text is aggregated by
a text aggregator. It contains the aggregated text and a type indicating
the nature of the aggregation.
Parameters:
text: The aggregated text content.
type: The type of aggregation the text represents (e.g., 'sentence', 'word', 'token',
'my_custom_aggregation').
"""
text: str
type: str
def __str__(self) -> str:
"""Return a string representation of the aggregation.
Returns:
A descriptive string showing the type and text of the aggregation.
"""
return f"Aggregation by {self.type}: {self.text}"
class BaseTextAggregator(ABC):
"""Base class for text aggregators in the Pipecat framework.
@@ -30,7 +68,7 @@ class BaseTextAggregator(ABC):
@property
@abstractmethod
def text(self) -> str:
def text(self) -> Aggregation:
"""Get the currently aggregated text.
Subclasses must implement this property to return the text that has
@@ -42,25 +80,33 @@ class BaseTextAggregator(ABC):
pass
@abstractmethod
async def aggregate(self, text: str) -> Optional[str]:
async def aggregate(self, text: str) -> Optional[Aggregation]:
"""Aggregate the specified text with the currently accumulated text.
This method should be implemented to define how the new text contributes
to the aggregation process. It returns the updated aggregated text if
it's ready to be processed, or None otherwise.
to the aggregation process. It returns the aggregated text and a string
describing how it was aggregated if it's ready to be processed,
or None otherwise.
Subclasses should implement their specific logic for:
- How to combine new text with existing accumulated text
- When to consider the aggregated text ready for processing
- What criteria determine text completion (e.g., sentence boundaries)
- When a completion occurs, the method should return an Aggregation object
containing the aggregated text and its type. The text should be stripped
of leading/trailing whitespace so that consumers can rely on a consistent
format.
Args:
text: The text to be aggregated.
text: The text to be aggregated
Returns:
The updated aggregated text if ready for processing, or None if more
text is needed before the aggregated content is ready.
An Aggregation object if ready for processing, or None if more
text is needed before the aggregated content is ready. If an Aggregation
object is returned, it should consist of the updated aggregated text,
stripped of leading/trailing whitespace, and a string indicating the
type of aggregation (e.g., 'sentence', 'word', 'token', 'my_custom_aggregation').
"""
pass

View File

@@ -17,7 +17,7 @@ from typing import Awaitable, Callable, Optional, Tuple
from loguru import logger
from pipecat.utils.string import match_endofsentence
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
class PatternMatch:
@@ -75,13 +75,13 @@ class PatternPairAggregator(BaseTextAggregator):
self._handlers = {}
@property
def text(self) -> str:
def text(self) -> Aggregation:
"""Get the currently buffered text.
Returns:
The current text buffer content that hasn't been processed yet.
"""
return self._text
return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE)
def add_pattern_pair(
self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True
@@ -208,7 +208,7 @@ class PatternPairAggregator(BaseTextAggregator):
return False
async def aggregate(self, text: str) -> Optional[str]:
async def aggregate(self, text: str) -> Optional[Aggregation]:
"""Aggregate text and process pattern pairs.
This method adds the new text to the buffer, processes any complete pattern
@@ -220,8 +220,9 @@ class PatternPairAggregator(BaseTextAggregator):
text: New text to add to the buffer.
Returns:
Processed text up to a sentence boundary, or None if more
text is needed to form a complete sentence or pattern.
An Aggregation object containing processed text up to a sentence boundary
and marked as SENTENCE type, or None if more text is needed to form a
complete sentence or pattern.
"""
# Add new text to buffer
self._text += text
@@ -244,7 +245,7 @@ class PatternPairAggregator(BaseTextAggregator):
# Extract text up to the sentence boundary
result = self._text[:eos_marker]
self._text = self._text[eos_marker:]
return result
return Aggregation(text=result.strip(), type=AggregationType.SENTENCE)
# No complete sentence found yet
return None

View File

@@ -14,7 +14,7 @@ text processing scenarios.
from typing import Optional
from pipecat.utils.string import match_endofsentence
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
class SimpleTextAggregator(BaseTextAggregator):
@@ -33,15 +33,15 @@ class SimpleTextAggregator(BaseTextAggregator):
self._text = ""
@property
def text(self) -> str:
def text(self) -> Aggregation:
"""Get the currently aggregated text.
Returns:
The text that has been accumulated in the buffer.
"""
return self._text
return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE)
async def aggregate(self, text: str) -> Optional[str]:
async def aggregate(self, text: str) -> Optional[Aggregation]:
"""Aggregate text and return completed sentences.
Adds the new text to the buffer and checks for end-of-sentence markers.
@@ -64,7 +64,9 @@ class SimpleTextAggregator(BaseTextAggregator):
result = self._text[:eos_end_marker]
self._text = self._text[eos_end_marker:]
return result
if result:
return Aggregation(text=result.strip(), type=AggregationType.SENTENCE)
return None
async def handle_interruption(self):
"""Handle interruptions by clearing the text buffer.

View File

@@ -14,7 +14,7 @@ as a unit regardless of internal punctuation.
from typing import Optional, Sequence
from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
class SkipTagsAggregator(BaseTextAggregator):
@@ -43,15 +43,15 @@ class SkipTagsAggregator(BaseTextAggregator):
self._current_tag_index: int = 0
@property
def text(self) -> str:
def text(self) -> Aggregation:
"""Get the currently buffered text.
Returns:
The current text buffer content that hasn't been processed yet.
"""
return self._text
return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE)
async def aggregate(self, text: str) -> Optional[str]:
async def aggregate(self, text: str) -> Optional[Aggregation]:
"""Aggregate text while respecting tag boundaries.
This method adds the new text to the buffer, processes any complete
@@ -63,8 +63,9 @@ class SkipTagsAggregator(BaseTextAggregator):
text: New text to add to the buffer.
Returns:
Processed text up to a sentence boundary (when not within tags),
or None if more text is needed to complete a sentence or close tags.
An Aggregation object containing text up to a sentence boundary and
marked as SENTENCE type or None if more text is needed to complete a
sentence or close tags.
"""
# Add new text to buffer
self._text += text
@@ -80,7 +81,7 @@ class SkipTagsAggregator(BaseTextAggregator):
# Extract text up to the sentence boundary
result = self._text[:eos_marker]
self._text = self._text[eos_marker:]
return result
return Aggregation(text=result.strip(), type=AggregationType.SENTENCE)
# No complete sentence found yet
return None

View File

@@ -30,7 +30,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# First part doesn't complete the pattern
result = await self.aggregator.aggregate("Hello <test>pattern")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "Hello <test>pattern")
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
# Second part completes the pattern and includes an exclamation point
result = await self.aggregator.aggregate(" content</test>!")
@@ -45,14 +45,16 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# The exclamation point should be treated as a sentence boundary,
# so the result should include just text up to and including "!"
self.assertEqual(result, "Hello !")
self.assertEqual(result.text, "Hello !")
self.assertEqual(result.type, "sentence")
# Next sentence should be processed separately
# Next sentence should be processed separately. Spaces around the sentence
# should be stripped in the returned Aggregation.
result = await self.aggregator.aggregate(" This is another sentence.")
self.assertEqual(result, " This is another sentence.")
self.assertEqual(result.text, "This is another sentence.")
self.assertEqual(result.type, "sentence")
# Buffer should be empty after returning a complete sentence
self.assertEqual(self.aggregator.text, "")
self.assertEqual(self.aggregator.text.text, "")
async def test_incomplete_pattern(self):
# Add text with incomplete pattern
@@ -65,11 +67,11 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.test_handler.assert_not_called()
# Buffer should contain the incomplete text
self.assertEqual(self.aggregator.text, "Hello <test>pattern content")
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern content")
# Reset and confirm buffer is cleared
await self.aggregator.reset()
self.assertEqual(self.aggregator.text, "")
self.assertEqual(self.aggregator.text.text, "")
async def test_multiple_patterns(self):
# Set up multiple patterns and handlers
@@ -106,10 +108,11 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.assertEqual(emphasis_match.content, "very")
# Voice pattern should be removed, emphasis pattern should remain
self.assertEqual(result, "Hello I am <em>very</em> excited to meet you!")
self.assertEqual(result.text, "Hello I am <em>very</em> excited to meet you!")
self.assertEqual(result.type, "sentence")
# Buffer should be empty
self.assertEqual(self.aggregator.text, "")
self.assertEqual(self.aggregator.text.text, "")
async def test_handle_interruption(self):
# Start with incomplete pattern
@@ -120,7 +123,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
await self.aggregator.handle_interruption()
# Buffer should be cleared
self.assertEqual(self.aggregator.text, "")
self.assertEqual(self.aggregator.text.text, "")
# Handler should not have been called
self.test_handler.assert_not_called()
@@ -141,7 +144,8 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.assertEqual(call_args.content, "This is sentence one. This is sentence two.")
# Pattern should be removed, resulting in text with sentences merged
self.assertEqual(result, "Hello Final sentence.")
self.assertEqual(result.text, "Hello Final sentence.")
self.assertEqual(result.type, "sentence")
# Buffer should be empty
self.assertEqual(self.aggregator.text, "")
self.assertEqual(self.aggregator.text.text, "")

View File

@@ -15,15 +15,21 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):
async def test_reset_aggregations(self):
assert await self.aggregator.aggregate("Hello ") == None
assert self.aggregator.text == "Hello "
assert self.aggregator.text.text == "Hello"
await self.aggregator.reset()
assert self.aggregator.text == ""
assert self.aggregator.text.text == ""
async def test_simple_sentence(self):
assert await self.aggregator.aggregate("Hello ") == None
assert await self.aggregator.aggregate("Pipecat!") == "Hello Pipecat!"
assert self.aggregator.text == ""
aggregate = await self.aggregator.aggregate("Pipecat!")
assert aggregate.text == "Hello Pipecat!"
assert aggregate.type == "sentence"
assert self.aggregator.text.text == ""
async def test_multiple_sentences(self):
assert await self.aggregator.aggregate("Hello Pipecat! How are ") == "Hello Pipecat!"
assert await self.aggregator.aggregate("you?") == " How are you?"
aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ")
assert aggregate.text == "Hello Pipecat!"
# Aggregators should strip leading/trailing spaces when returning text
assert self.aggregator.text.text == "How are"
aggregate = await self.aggregator.aggregate("you?")
assert aggregate.text == "How are you?"

View File

@@ -18,16 +18,18 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
# No tags involved, aggregate at end of sentence.
result = await self.aggregator.aggregate("Hello Pipecat!")
self.assertEqual(result, "Hello Pipecat!")
self.assertEqual(self.aggregator.text, "")
self.assertEqual(result.text, "Hello Pipecat!")
self.assertEqual(result.type, "sentence")
self.assertEqual(self.aggregator.text.text, "")
async def test_basic_tags(self):
await self.aggregator.reset()
# Tags involved, avoid aggregation during tags.
result = await self.aggregator.aggregate("My email is <spell>foo@pipecat.ai</spell>.")
self.assertEqual(result, "My email is <spell>foo@pipecat.ai</spell>.")
self.assertEqual(self.aggregator.text, "")
self.assertEqual(result.text, "My email is <spell>foo@pipecat.ai</spell>.")
self.assertEqual(result.type, "sentence")
self.assertEqual(self.aggregator.text.text, "")
async def test_streaming_tags(self):
await self.aggregator.reset()
@@ -35,20 +37,22 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
# Tags involved, stream small chunk of texts.
result = await self.aggregator.aggregate("My email is <sp")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "My email is <sp")
self.assertEqual(self.aggregator.text.text, "My email is <sp")
result = await self.aggregator.aggregate("ell>foo.")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "My email is <spell>foo.")
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.")
result = await self.aggregator.aggregate("bar@pipecat.")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.")
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.")
result = await self.aggregator.aggregate("ai</spe")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.ai</spe")
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.ai</spe")
self.assertEqual(self.aggregator.text.type, "sentence")
result = await self.aggregator.aggregate("ll>.")
self.assertEqual(result, "My email is <spell>foo.bar@pipecat.ai</spell>.")
self.assertEqual(self.aggregator.text, "")
self.assertEqual(result.text, "My email is <spell>foo.bar@pipecat.ai</spell>.")
self.assertEqual(self.aggregator.text.text, "")
self.assertEqual(self.aggregator.text.type, "sentence")