Merge pull request #1395 from pipecat-ai/aleix/multiple-text-filters-and-aggregators

TTSService: allow passing multiple text filters and aggregators
This commit is contained in:
Aleix Conchillo Flaqué
2025-03-18 21:25:29 -07:00
committed by GitHub
5 changed files with 61 additions and 20 deletions

View File

@@ -16,8 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new `BaseTextAggregator`. Text aggregators are used by the TTS service
to aggregate LLM tokens and decide when the aggregated text should be pushed
to the TTS service. It also allows for the text to be manipulated while it's
being aggregated.
to the TTS service. They also allow for the text to be manipulated while it's
being aggregated. Multiple text aggregators can be passed with
`text_aggregators` to the TTS service.
- Added new `UltravoxSTTService`.
(see https://github.com/fixie-ai/ultravox)
@@ -113,6 +114,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Updated the default mode for `CartesiaTTSService` and
`CartesiaHttpTTSService` to `sonic-2`.
### Deprecated
- `TTSService` parameter `text_filter` is now deprecated, use `text_filters`
instead which is now a list. This allows passing multiple filters that will be
executed in order.
### Removed
- Removed deprecated `audio.resample_audio()`, use `create_default_resampler()`

View File

@@ -60,7 +60,7 @@ async def main():
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
# text_filter=MarkdownTextFilter(),
# text_filters=[MarkdownTextFilter()],
)
llm = NimLLMService(

View File

@@ -119,7 +119,7 @@ async def main():
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id=VOICE_IDS["narrator"],
text_aggregator=pattern_aggregator,
text_aggregators=[pattern_aggregator],
)
# Initialize LLM

View File

@@ -97,7 +97,7 @@ async def main():
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
text_filter=MarkdownTextFilter(),
text_filters=[MarkdownTextFilter()],
)
llm = GoogleLLMService(

View File

@@ -8,7 +8,7 @@ import asyncio
import io
import wave
from abc import abstractmethod
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Tuple, Type
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Sequence, Tuple, Type
from loguru import logger
@@ -239,8 +239,9 @@ class TTSService(AIService):
# TTS output sample rate
sample_rate: Optional[int] = None,
# Text aggregator to aggregate incoming tokens and decide when to push to the TTS.
text_aggregator: Optional[BaseTextAggregator] = None,
text_aggregators: Sequence[BaseTextAggregator] = [],
# Text filter executed after text has been aggregated.
text_filters: Sequence[BaseTextFilter] = [],
text_filter: Optional[BaseTextFilter] = None,
**kwargs,
):
@@ -256,8 +257,21 @@ class TTSService(AIService):
self._sample_rate = 0
self._voice_id: str = ""
self._settings: Dict[str, Any] = {}
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
self._text_filter: Optional[BaseTextFilter] = text_filter
# Ensure there's at least one text aggregator.
self._text_aggregators: Sequence[BaseTextAggregator] = text_aggregators or [
SimpleTextAggregator()
]
self._text_filters: Sequence[BaseTextFilter] = text_filters
if text_filter:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Parameter 'text_filter' is deprecated, use 'text_filters' instead.",
DeprecationWarning,
)
self._text_filters = [text_filter]
self._stop_frame_task: Optional[asyncio.Task] = None
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
@@ -317,8 +331,9 @@ class TTSService(AIService):
self.set_model_name(value)
elif key == "voice":
self.set_voice(value)
elif key == "text_filter" and self._text_filter:
self._text_filter.update_settings(value)
elif key == "text_filter":
for filter in self._text_filters:
filter.update_settings(value)
else:
logger.warning(f"Unknown setting for TTS service: {key}")
@@ -343,8 +358,8 @@ class TTSService(AIService):
# pause to avoid audio overlapping.
await self._maybe_pause_frame_processing()
sentence = self._text_aggregator.text
self._text_aggregator.reset()
sentence = self._text_aggregators[-1].text
self._reset_aggregators()
self._processing_text = False
await self._push_tts_frames(sentence)
if isinstance(frame, LLMFullResponseEndFrame):
@@ -390,9 +405,10 @@ class TTSService(AIService):
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
self._processing_text = False
self._text_aggregator.handle_interruption()
if self._text_filter:
self._text_filter.handle_interruption()
for aggregator in self._text_aggregators:
aggregator.handle_interruption()
for filter in self._text_filters:
filter.handle_interruption()
async def _maybe_pause_frame_processing(self):
if self._processing_text and self._pause_frame_processing:
@@ -402,12 +418,25 @@ class TTSService(AIService):
if self._pause_frame_processing:
await self.resume_processing_frames()
def _reset_aggregators(self):
for aggregator in self._text_aggregators:
aggregator.reset()
async def _process_text_frame(self, frame: TextFrame):
text: Optional[str] = None
if not self._aggregate_sentences:
text = frame.text
else:
text = self._text_aggregator.aggregate(frame.text)
current_text = frame.text
# Process all aggregators except the last one.
for aggregator in self._text_aggregators[:-1]:
aggregator.aggregate(current_text)
current_text = aggregator.text
# The last aggregator decides whether we are sending text to the
# TTS or not.
text = self._text_aggregators[-1].aggregate(current_text)
if text:
await self._push_tts_frames(text)
@@ -427,11 +456,16 @@ class TTSService(AIService):
self._processing_text = True
await self.start_processing_metrics()
if self._text_filter:
self._text_filter.reset_interruption()
text = self._text_filter.filter(text)
# Process all filter.
for filter in self._text_filters:
filter.reset_interruption()
text = filter.filter(text)
await self.process_generator(self.run_tts(text))
await self.stop_processing_metrics()
if self._push_text_frames:
# We send the original text after the audio. This way, if we are
# interrupted, the text is not added to the assistant context.