diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c29b8e98..f90dd80dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,8 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added new `BaseTextAggregator`. Text aggregators are used by the TTS service to aggregate LLM tokens and decide when the aggregated text should be pushed - to the TTS service. It also allows for the text to be manipulated while it's - being aggregated. + to the TTS service. They also allow for the text to be manipulated while it's + being aggregated. Multiple text aggregators can be passed with + `text_aggregators` to the TTS service. - Added new `UltravoxSTTService`. (see https://github.com/fixie-ai/ultravox) @@ -113,6 +114,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Updated the default mode for `CartesiaTTSService` and `CartesiaHttpTTSService` to `sonic-2`. +### Deprecated + +- `TTSService` parameter `text_filter` is now deprecated, use `text_filters` + instead which is now a list. This allows passing multiple filters that will be + executed in order. + ### Removed - Removed deprecated `audio.resample_audio()`, use `create_default_resampler()` diff --git a/examples/foundational/14j-function-calling-nim.py b/examples/foundational/14j-function-calling-nim.py index d703d637a..ea8e25cf6 100644 --- a/examples/foundational/14j-function-calling-nim.py +++ b/examples/foundational/14j-function-calling-nim.py @@ -60,7 +60,7 @@ async def main(): tts = CartesiaTTSService( api_key=os.getenv("CARTESIA_API_KEY"), voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady - # text_filter=MarkdownTextFilter(), + # text_filters=[MarkdownTextFilter()], ) llm = NimLLMService( diff --git a/examples/foundational/35-pattern-pair-voice-switching.py b/examples/foundational/35-pattern-pair-voice-switching.py index bb9587706..7d0094132 100644 --- a/examples/foundational/35-pattern-pair-voice-switching.py +++ b/examples/foundational/35-pattern-pair-voice-switching.py @@ -119,7 +119,7 @@ async def main(): tts = CartesiaTTSService( api_key=os.getenv("CARTESIA_API_KEY"), voice_id=VOICE_IDS["narrator"], - text_aggregator=pattern_aggregator, + text_aggregators=[pattern_aggregator], ) # Initialize LLM diff --git a/examples/news-chatbot/server/news_bot.py b/examples/news-chatbot/server/news_bot.py index 2a389094e..b9f60200f 100644 --- a/examples/news-chatbot/server/news_bot.py +++ b/examples/news-chatbot/server/news_bot.py @@ -97,7 +97,7 @@ async def main(): tts = CartesiaTTSService( api_key=os.getenv("CARTESIA_API_KEY"), voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady - text_filter=MarkdownTextFilter(), + text_filters=[MarkdownTextFilter()], ) llm = GoogleLLMService( diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index 3fe33d69e..904e5cf90 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -8,7 +8,7 @@ import asyncio import io import wave from abc import abstractmethod -from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Tuple, Type +from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Sequence, Tuple, Type from loguru import logger @@ -239,8 +239,9 @@ class TTSService(AIService): # TTS output sample rate sample_rate: Optional[int] = None, # Text aggregator to aggregate incoming tokens and decide when to push to the TTS. - text_aggregator: Optional[BaseTextAggregator] = None, + text_aggregators: Sequence[BaseTextAggregator] = [], # Text filter executed after text has been aggregated. + text_filters: Sequence[BaseTextFilter] = [], text_filter: Optional[BaseTextFilter] = None, **kwargs, ): @@ -256,8 +257,21 @@ class TTSService(AIService): self._sample_rate = 0 self._voice_id: str = "" self._settings: Dict[str, Any] = {} - self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator() - self._text_filter: Optional[BaseTextFilter] = text_filter + # Ensure there's at least one text aggregator. + self._text_aggregators: Sequence[BaseTextAggregator] = text_aggregators or [ + SimpleTextAggregator() + ] + self._text_filters: Sequence[BaseTextFilter] = text_filters + if text_filter: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "Parameter 'text_filter' is deprecated, use 'text_filters' instead.", + DeprecationWarning, + ) + self._text_filters = [text_filter] self._stop_frame_task: Optional[asyncio.Task] = None self._stop_frame_queue: asyncio.Queue = asyncio.Queue() @@ -317,8 +331,9 @@ class TTSService(AIService): self.set_model_name(value) elif key == "voice": self.set_voice(value) - elif key == "text_filter" and self._text_filter: - self._text_filter.update_settings(value) + elif key == "text_filter": + for filter in self._text_filters: + filter.update_settings(value) else: logger.warning(f"Unknown setting for TTS service: {key}") @@ -343,8 +358,8 @@ class TTSService(AIService): # pause to avoid audio overlapping. await self._maybe_pause_frame_processing() - sentence = self._text_aggregator.text - self._text_aggregator.reset() + sentence = self._text_aggregators[-1].text + self._reset_aggregators() self._processing_text = False await self._push_tts_frames(sentence) if isinstance(frame, LLMFullResponseEndFrame): @@ -390,9 +405,10 @@ class TTSService(AIService): async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection): self._processing_text = False - self._text_aggregator.handle_interruption() - if self._text_filter: - self._text_filter.handle_interruption() + for aggregator in self._text_aggregators: + aggregator.handle_interruption() + for filter in self._text_filters: + filter.handle_interruption() async def _maybe_pause_frame_processing(self): if self._processing_text and self._pause_frame_processing: @@ -402,12 +418,25 @@ class TTSService(AIService): if self._pause_frame_processing: await self.resume_processing_frames() + def _reset_aggregators(self): + for aggregator in self._text_aggregators: + aggregator.reset() + async def _process_text_frame(self, frame: TextFrame): text: Optional[str] = None if not self._aggregate_sentences: text = frame.text else: - text = self._text_aggregator.aggregate(frame.text) + current_text = frame.text + + # Process all aggregators except the last one. + for aggregator in self._text_aggregators[:-1]: + aggregator.aggregate(current_text) + current_text = aggregator.text + + # The last aggregator decides whether we are sending text to the + # TTS or not. + text = self._text_aggregators[-1].aggregate(current_text) if text: await self._push_tts_frames(text) @@ -427,11 +456,16 @@ class TTSService(AIService): self._processing_text = True await self.start_processing_metrics() - if self._text_filter: - self._text_filter.reset_interruption() - text = self._text_filter.filter(text) + + # Process all filter. + for filter in self._text_filters: + filter.reset_interruption() + text = filter.filter(text) + await self.process_generator(self.run_tts(text)) + await self.stop_processing_metrics() + if self._push_text_frames: # We send the original text after the audio. This way, if we are # interrupted, the text is not added to the assistant context.