Updated the BaseTextAggregator to categorize aggregations
Modified the BaseTextAggregator type so that when text gets aggregated, metadata can
be associated with it. Currently, that just means a `type`, so that the aggregation
can be classified or described. Changes made to support this:
- **IMPORTANT**: Aggregators are now expected to strip leading/trailing whitespace
characters before returning their aggregation from `aggregate()` or `.text`. This
way all aggregators have a consistent contract allowing downstream use to know how
to stitch aggregations back together.
- Introduced a new `Aggregation` dataclass to represent both the aggregated `text` and
a string identifying the `type` of aggregation (e.g. "sentence", "word", "my custom
aggregation").
- **BREAKING**: `BaseTextAggregator.text` now returns an `Aggregation` (instead of `str`).
To update: `aggregated_text = myAggregator.text` -> `aggregated_text = myAggregator.text.text`
- **BREAKING**: `BaseTextAggregator.aggregate()` now returns `Optional[Aggregation]`
(instead of `Optional[str]`). To update:
```python
aggregation = myAggregator.aggregate(text)
if aggregation:
    print(f"successfully aggregated text: {aggregation.text}")  # instead of {aggregation}
```
- `SimpleTextAggregator`, `SkipTagsAggregator`, `PatternPairAggregator` updated to
produce/consume `Aggregation` objects.
- All uses of the above Aggregators have been updated accordingly.
This commit is contained in:
committed by
Mattie Ruth
parent
26918728df
commit
dcc20f86e1
24
CHANGELOG.md
24
CHANGELOG.md
@@ -84,6 +84,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
Croatian, Hungarian, Malay, Norwegian, Nynorsk, Slovak, Slovenian, Swedish, and Tamil
|
||||
-- Added new emotions: calm and fluent
|
||||
|
||||
- `BaseTextAggregator` changes:
|
||||
Modified the BaseTextAggregator type so that when text gets aggregated, metadata can
|
||||
be associated with it. Currently, that just means a `type`, so that the aggregation
|
||||
can be classified or described. Changes made to support this:
|
||||
- **IMPORTANT**: Aggregators are now expected to strip leading/trailing whitespace
|
||||
characters before returning their aggregation from `aggregate()` or `.text`. This
|
||||
way all aggregators have a consistent contract allowing downstream use to know how
|
||||
to stitch aggregations back together.
|
||||
- Introduced a new `Aggregation` dataclass to represent both the aggregated `text` and
|
||||
a string identifying the `type` of aggregation (e.g. "sentence", "word", "my custom
|
||||
aggregation")
|
||||
- **BREAKING**: `BaseTextAggregator.text` now returns an `Aggregation` (instead of `str`).
|
||||
To update: `aggregated_text = myAggregator.text` -> `aggregated_text = myAggregator.text.text`
|
||||
- **BREAKING**: `BaseTextAggregator.aggregate()` now returns `Optional[Aggregation]`
|
||||
(instead of `Optional[str]`). To update:
|
||||
```
|
||||
aggregation = myAggregator.aggregate(text)
|
||||
if aggregation:
|
||||
    print(f"successfully aggregated text: {aggregation.text}")  # instead of {aggregation}
|
||||
```
|
||||
- `SimpleTextAggregator`, `SkipTagsAggregator`, `PatternPairAggregator` updated to
|
||||
produce/consume `Aggregation` objects.
|
||||
- All uses of the above Aggregators have been updated accordingly.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- The `api_key` parameter in `GeminiTTSService` is deprecated. Use
|
||||
|
||||
@@ -148,7 +148,7 @@ class IVRProcessor(FrameProcessor):
|
||||
result = await self._aggregator.aggregate(frame.text)
|
||||
if result:
|
||||
# Push aggregated text that doesn't contain XML patterns
|
||||
await self.push_frame(LLMTextFrame(result), direction)
|
||||
await self.push_frame(LLMTextFrame(result.text), direction)
|
||||
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -142,7 +142,6 @@ class TTSService(AIService):
|
||||
self._voice_id: str = ""
|
||||
self._settings: Dict[str, Any] = {}
|
||||
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
|
||||
self._aggregated_text_includes_inter_frame_spaces: bool = False
|
||||
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
|
||||
self._transport_destination: Optional[str] = transport_destination
|
||||
self._tracing_enabled: bool = False
|
||||
@@ -352,17 +351,14 @@ class TTSService(AIService):
|
||||
# pause to avoid audio overlapping.
|
||||
await self._maybe_pause_frame_processing()
|
||||
|
||||
sentence = self._text_aggregator.text
|
||||
includes_inter_frame_spaces = self._aggregated_text_includes_inter_frame_spaces
|
||||
pending_aggregation = self._text_aggregator.text
|
||||
|
||||
# Reset aggregator state
|
||||
await self._text_aggregator.reset()
|
||||
self._processing_text = False
|
||||
self._aggregated_text_includes_inter_frame_spaces = False
|
||||
|
||||
await self._push_tts_frames(
|
||||
sentence, includes_inter_frame_spaces=includes_inter_frame_spaces
|
||||
)
|
||||
if pending_aggregation.text:
|
||||
await self._push_tts_frames(pending_aggregation.text)
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._push_text_frames:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -372,7 +368,7 @@ class TTSService(AIService):
|
||||
# Store if we were processing text or not so we can set it back.
|
||||
processing_text = self._processing_text
|
||||
# Assumption: text in TTSSpeakFrame does not include inter-frame spaces
|
||||
await self._push_tts_frames(frame.text, includes_inter_frame_spaces=False)
|
||||
await self._push_tts_frames(frame.text)
|
||||
# We pause processing incoming frames because we are sending data to
|
||||
# the TTS. We pause to avoid audio overlapping.
|
||||
await self._maybe_pause_frame_processing()
|
||||
@@ -462,21 +458,20 @@ class TTSService(AIService):
|
||||
|
||||
async def _process_text_frame(self, frame: TextFrame):
    """Forward the text of an incoming frame to speech synthesis.

    When sentence aggregation is disabled, the frame's text is pushed as-is
    together with the frame's own ``includes_inter_frame_spaces`` flag.
    Otherwise the text is handed to the configured text aggregator and is
    only pushed once the aggregator reports a completed aggregation.

    Args:
        frame: The text frame whose ``text`` should be synthesized.
    """
    text: Optional[str] = None
    includes_inter_frame_spaces: bool = False
    if not self._aggregate_sentences:
        text = frame.text
        includes_inter_frame_spaces = frame.includes_inter_frame_spaces
    else:
        # aggregate() is declared Optional[Aggregation]: it yields None while
        # more text is still needed, so the result must be checked before its
        # ``text`` attribute is read (otherwise this raises AttributeError).
        aggregation = await self._text_aggregator.aggregate(frame.text)
        if aggregation is not None:
            text = aggregation.text
        # The flag is left at False here: aggregators are expected to return
        # text stripped of leading/trailing whitespace.

    if text:
        await self._push_tts_frames(text, includes_inter_frame_spaces)
|
||||
|
||||
async def _push_tts_frames(self, text: str, includes_inter_frame_spaces: bool):
|
||||
async def _push_tts_frames(
|
||||
self, text: str, includes_inter_frame_spaces: Optional[bool] = False
|
||||
):
|
||||
# Remove leading newlines only
|
||||
text = text.lstrip("\n")
|
||||
|
||||
|
||||
@@ -203,8 +203,16 @@ async def run_test(
|
||||
if not isinstance(frame, EndFrame) or not send_end_frame:
|
||||
received_down_frames.append(frame)
|
||||
|
||||
print("received DOWN frames =", received_down_frames)
|
||||
print("expected DOWN frames =", expected_down_frames)
|
||||
down_frames_printed = "["
|
||||
for frame in received_down_frames:
|
||||
down_frames_printed += f"{frame.__class__.__name__}, "
|
||||
down_frames_printed += "]"
|
||||
expected_frames_printed = "["
|
||||
for frame in expected_down_frames:
|
||||
expected_frames_printed += f"{frame.__name__}, "
|
||||
expected_frames_printed += "]"
|
||||
print("received DOWN frames =", down_frames_printed)
|
||||
print("expected DOWN frames =", expected_frames_printed)
|
||||
|
||||
assert len(received_down_frames) == len(expected_down_frames)
|
||||
|
||||
|
||||
@@ -12,9 +12,47 @@ aggregated text should be sent for speech synthesis.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class AggregationType(str, Enum):
|
||||
"""Built-in aggregation strings."""
|
||||
|
||||
SENTENCE = "sentence"
|
||||
WORD = "word"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class Aggregation:
|
||||
"""Data class representing aggregated text and its type.
|
||||
|
||||
An Aggregation object is created whenever a stream of text is aggregated by
|
||||
a text aggregator. It contains the aggregated text and a type indicating
|
||||
the nature of the aggregation.
|
||||
|
||||
Parameters:
|
||||
text: The aggregated text content.
|
||||
type: The type of aggregation the text represents (e.g., 'sentence', 'word', 'token',
|
||||
'my_custom_aggregation').
|
||||
"""
|
||||
|
||||
text: str
|
||||
type: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the aggregation.
|
||||
|
||||
Returns:
|
||||
A descriptive string showing the type and text of the aggregation.
|
||||
"""
|
||||
return f"Aggregation by {self.type}: {self.text}"
|
||||
|
||||
|
||||
class BaseTextAggregator(ABC):
|
||||
"""Base class for text aggregators in the Pipecat framework.
|
||||
|
||||
@@ -30,7 +68,7 @@ class BaseTextAggregator(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def text(self) -> str:
|
||||
def text(self) -> Aggregation:
|
||||
"""Get the currently aggregated text.
|
||||
|
||||
Subclasses must implement this property to return the text that has
|
||||
@@ -42,25 +80,33 @@ class BaseTextAggregator(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
"""Aggregate the specified text with the currently accumulated text.
|
||||
|
||||
This method should be implemented to define how the new text contributes
|
||||
to the aggregation process. It returns the updated aggregated text if
|
||||
it's ready to be processed, or None otherwise.
|
||||
to the aggregation process. It returns the aggregated text and a string
|
||||
describing how it was aggregated if it's ready to be processed,
|
||||
or None otherwise.
|
||||
|
||||
Subclasses should implement their specific logic for:
|
||||
|
||||
- How to combine new text with existing accumulated text
|
||||
- When to consider the aggregated text ready for processing
|
||||
- What criteria determine text completion (e.g., sentence boundaries)
|
||||
- When a completion occurs, the method should return an Aggregation object
|
||||
containing the aggregated text and its type. The text should be stripped
|
||||
of leading/trailing whitespace so that consumers can rely on a consistent
|
||||
format.
|
||||
|
||||
Args:
|
||||
text: The text to be aggregated.
|
||||
text: The text to be aggregated
|
||||
|
||||
Returns:
|
||||
The updated aggregated text if ready for processing, or None if more
|
||||
text is needed before the aggregated content is ready.
|
||||
An Aggregation object if ready for processing, or None if more
|
||||
text is needed before the aggregated content is ready. If an Aggregation
|
||||
object is returned, it should consist of the updated aggregated text,
|
||||
stripped of leading/trailing whitespace, and a string indicating the
|
||||
type of aggregation (e.g., 'sentence', 'word', 'token', 'my_custom_aggregation').
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Awaitable, Callable, Optional, Tuple
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
|
||||
|
||||
class PatternMatch:
|
||||
@@ -75,13 +75,13 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
self._handlers = {}
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
def text(self) -> Aggregation:
|
||||
"""Get the currently buffered text.
|
||||
|
||||
Returns:
|
||||
The current text buffer content that hasn't been processed yet.
|
||||
"""
|
||||
return self._text
|
||||
return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE)
|
||||
|
||||
def add_pattern_pair(
|
||||
self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True
|
||||
@@ -208,7 +208,7 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
|
||||
return False
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
"""Aggregate text and process pattern pairs.
|
||||
|
||||
This method adds the new text to the buffer, processes any complete pattern
|
||||
@@ -220,8 +220,9 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
text: New text to add to the buffer.
|
||||
|
||||
Returns:
|
||||
Processed text up to a sentence boundary, or None if more
|
||||
text is needed to form a complete sentence or pattern.
|
||||
An Aggregation object containing processed text up to a sentence boundary
|
||||
and marked as SENTENCE type, or None if more text is needed to form a
|
||||
complete sentence or pattern.
|
||||
"""
|
||||
# Add new text to buffer
|
||||
self._text += text
|
||||
@@ -244,7 +245,7 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
# Extract text up to the sentence boundary
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return result
|
||||
return Aggregation(text=result.strip(), type=AggregationType.SENTENCE)
|
||||
|
||||
# No complete sentence found yet
|
||||
return None
|
||||
|
||||
@@ -14,7 +14,7 @@ text processing scenarios.
|
||||
from typing import Optional
|
||||
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
|
||||
|
||||
class SimpleTextAggregator(BaseTextAggregator):
|
||||
@@ -33,15 +33,15 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
self._text = ""
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
def text(self) -> Aggregation:
|
||||
"""Get the currently aggregated text.
|
||||
|
||||
Returns:
|
||||
The text that has been accumulated in the buffer.
|
||||
"""
|
||||
return self._text
|
||||
return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE)
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
"""Aggregate text and return completed sentences.
|
||||
|
||||
Adds the new text to the buffer and checks for end-of-sentence markers.
|
||||
@@ -64,7 +64,9 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
result = self._text[:eos_end_marker]
|
||||
self._text = self._text[eos_end_marker:]
|
||||
|
||||
return result
|
||||
if result:
|
||||
return Aggregation(text=result.strip(), type=AggregationType.SENTENCE)
|
||||
return None
|
||||
|
||||
async def handle_interruption(self):
|
||||
"""Handle interruptions by clearing the text buffer.
|
||||
|
||||
@@ -14,7 +14,7 @@ as a unit regardless of internal punctuation.
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
|
||||
|
||||
class SkipTagsAggregator(BaseTextAggregator):
|
||||
@@ -43,15 +43,15 @@ class SkipTagsAggregator(BaseTextAggregator):
|
||||
self._current_tag_index: int = 0
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
def text(self) -> Aggregation:
|
||||
"""Get the currently buffered text.
|
||||
|
||||
Returns:
|
||||
The current text buffer content that hasn't been processed yet.
|
||||
"""
|
||||
return self._text
|
||||
return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE)
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[str]:
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
"""Aggregate text while respecting tag boundaries.
|
||||
|
||||
This method adds the new text to the buffer, processes any complete
|
||||
@@ -63,8 +63,9 @@ class SkipTagsAggregator(BaseTextAggregator):
|
||||
text: New text to add to the buffer.
|
||||
|
||||
Returns:
|
||||
Processed text up to a sentence boundary (when not within tags),
|
||||
or None if more text is needed to complete a sentence or close tags.
|
||||
An Aggregation object containing text up to a sentence boundary and
|
||||
marked as SENTENCE type or None if more text is needed to complete a
|
||||
sentence or close tags.
|
||||
"""
|
||||
# Add new text to buffer
|
||||
self._text += text
|
||||
@@ -80,7 +81,7 @@ class SkipTagsAggregator(BaseTextAggregator):
|
||||
# Extract text up to the sentence boundary
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return result
|
||||
return Aggregation(text=result.strip(), type=AggregationType.SENTENCE)
|
||||
|
||||
# No complete sentence found yet
|
||||
return None
|
||||
|
||||
@@ -30,7 +30,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
# First part doesn't complete the pattern
|
||||
result = await self.aggregator.aggregate("Hello <test>pattern")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text, "Hello <test>pattern")
|
||||
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
|
||||
|
||||
# Second part completes the pattern and includes an exclamation point
|
||||
result = await self.aggregator.aggregate(" content</test>!")
|
||||
@@ -45,14 +45,16 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# The exclamation point should be treated as a sentence boundary,
|
||||
# so the result should include just text up to and including "!"
|
||||
self.assertEqual(result, "Hello !")
|
||||
self.assertEqual(result.text, "Hello !")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
|
||||
# Next sentence should be processed separately
|
||||
# Next sentence should be processed separately. Spaces around the sentence
|
||||
# should be stripped in the returned Aggregation.
|
||||
result = await self.aggregator.aggregate(" This is another sentence.")
|
||||
self.assertEqual(result, " This is another sentence.")
|
||||
|
||||
self.assertEqual(result.text, "This is another sentence.")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
# Buffer should be empty after returning a complete sentence
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_incomplete_pattern(self):
|
||||
# Add text with incomplete pattern
|
||||
@@ -65,11 +67,11 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.test_handler.assert_not_called()
|
||||
|
||||
# Buffer should contain the incomplete text
|
||||
self.assertEqual(self.aggregator.text, "Hello <test>pattern content")
|
||||
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern content")
|
||||
|
||||
# Reset and confirm buffer is cleared
|
||||
await self.aggregator.reset()
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_multiple_patterns(self):
|
||||
# Set up multiple patterns and handlers
|
||||
@@ -106,10 +108,11 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(emphasis_match.content, "very")
|
||||
|
||||
# Voice pattern should be removed, emphasis pattern should remain
|
||||
self.assertEqual(result, "Hello I am <em>very</em> excited to meet you!")
|
||||
self.assertEqual(result.text, "Hello I am <em>very</em> excited to meet you!")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
|
||||
# Buffer should be empty
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_handle_interruption(self):
|
||||
# Start with incomplete pattern
|
||||
@@ -120,7 +123,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
await self.aggregator.handle_interruption()
|
||||
|
||||
# Buffer should be cleared
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
# Handler should not have been called
|
||||
self.test_handler.assert_not_called()
|
||||
@@ -141,7 +144,8 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(call_args.content, "This is sentence one. This is sentence two.")
|
||||
|
||||
# Pattern should be removed, resulting in text with sentences merged
|
||||
self.assertEqual(result, "Hello Final sentence.")
|
||||
self.assertEqual(result.text, "Hello Final sentence.")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
|
||||
# Buffer should be empty
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
@@ -15,15 +15,21 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_reset_aggregations(self):
|
||||
assert await self.aggregator.aggregate("Hello ") == None
|
||||
assert self.aggregator.text == "Hello "
|
||||
assert self.aggregator.text.text == "Hello"
|
||||
await self.aggregator.reset()
|
||||
assert self.aggregator.text == ""
|
||||
assert self.aggregator.text.text == ""
|
||||
|
||||
async def test_simple_sentence(self):
|
||||
assert await self.aggregator.aggregate("Hello ") == None
|
||||
assert await self.aggregator.aggregate("Pipecat!") == "Hello Pipecat!"
|
||||
assert self.aggregator.text == ""
|
||||
aggregate = await self.aggregator.aggregate("Pipecat!")
|
||||
assert aggregate.text == "Hello Pipecat!"
|
||||
assert aggregate.type == "sentence"
|
||||
assert self.aggregator.text.text == ""
|
||||
|
||||
async def test_multiple_sentences(self):
|
||||
assert await self.aggregator.aggregate("Hello Pipecat! How are ") == "Hello Pipecat!"
|
||||
assert await self.aggregator.aggregate("you?") == " How are you?"
|
||||
aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ")
|
||||
assert aggregate.text == "Hello Pipecat!"
|
||||
# Aggregators should strip leading/trailing spaces when returning text
|
||||
assert self.aggregator.text.text == "How are"
|
||||
aggregate = await self.aggregator.aggregate("you?")
|
||||
assert aggregate.text == "How are you?"
|
||||
|
||||
@@ -18,16 +18,18 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# No tags involved, aggregate at end of sentence.
|
||||
result = await self.aggregator.aggregate("Hello Pipecat!")
|
||||
self.assertEqual(result, "Hello Pipecat!")
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(result.text, "Hello Pipecat!")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_basic_tags(self):
|
||||
await self.aggregator.reset()
|
||||
|
||||
# Tags involved, avoid aggregation during tags.
|
||||
result = await self.aggregator.aggregate("My email is <spell>foo@pipecat.ai</spell>.")
|
||||
self.assertEqual(result, "My email is <spell>foo@pipecat.ai</spell>.")
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(result.text, "My email is <spell>foo@pipecat.ai</spell>.")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_streaming_tags(self):
|
||||
await self.aggregator.reset()
|
||||
@@ -35,20 +37,22 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
# Tags involved, stream small chunk of texts.
|
||||
result = await self.aggregator.aggregate("My email is <sp")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text, "My email is <sp")
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <sp")
|
||||
|
||||
result = await self.aggregator.aggregate("ell>foo.")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text, "My email is <spell>foo.")
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.")
|
||||
|
||||
result = await self.aggregator.aggregate("bar@pipecat.")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.")
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.")
|
||||
|
||||
result = await self.aggregator.aggregate("ai</spe")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.ai</spe")
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.ai</spe")
|
||||
self.assertEqual(self.aggregator.text.type, "sentence")
|
||||
|
||||
result = await self.aggregator.aggregate("ll>.")
|
||||
self.assertEqual(result, "My email is <spell>foo.bar@pipecat.ai</spell>.")
|
||||
self.assertEqual(self.aggregator.text, "")
|
||||
self.assertEqual(result.text, "My email is <spell>foo.bar@pipecat.ai</spell>.")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
self.assertEqual(self.aggregator.text.type, "sentence")
|
||||
|
||||
Reference in New Issue
Block a user