From ecbc41045c3241bdfb043bdf2f51691cba3e2088 Mon Sep 17 00:00:00 2001 From: mattie ruth backman Date: Mon, 17 Nov 2025 18:14:45 -0500 Subject: [PATCH] Added ability to transform text just-in-time before it gets sent to the TTS --- CHANGELOG.md | 6 +++ src/pipecat/services/tts_service.py | 65 ++++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 89f6d38e3..07d922632 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -162,6 +162,12 @@ Croatian, Hungarian, Malay, Norwegian, Nynorsk, Slovak, Slovenian, Swedish, and - `TTSService` base class updates: - `TTSService`s now accept a new `skip_aggregator_types` to avoid speaking certain aggregation types (now determined/returned by the aggregator) + - Introduced the ability to do a just-in-time transform of text before it gets sent to the + TTS service via callbacks you can set up via a new init field, `text_transforms` or a new + method `add_text_transformer()`. This makes it possible to do things like introduce + TTS-specific tags for spelling or emotion or change the pronunciation of something on the + fly. `remove_text_transformer` has also been added to support removing a registered + transform callback. ### Deprecated diff --git a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index 637077f10..8d7c4e6bb 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -12,6 +12,8 @@ from typing import ( Any, AsyncGenerator, AsyncIterator, + Awaitable, + Callable, Dict, List, Mapping, @@ -105,6 +107,14 @@ class TTSService(AIService): text_aggregator: Optional[BaseTextAggregator] = None, # Types of text aggregations that should not be spoken. skip_aggregator_types: Optional[List[str]] = [], + # A list of callables to transform text just before sending it to TTS. + # Each callable takes the aggregated text and its type, and returns the transformed text. 
+ # To register, provide a list of tuples of (aggregation_type | '*', transform_function). + text_transforms: Optional[ + List[ + Tuple[AggregationType | str, Callable[[str, str | AggregationType], Awaitable[str]]] + ] + ] = None, # Text filter executed after text has been aggregated. text_filters: Optional[Sequence[BaseTextFilter]] = None, text_filter: Optional[BaseTextFilter] = None, @@ -123,12 +133,17 @@ class TTSService(AIService): silence_time_s: Duration of silence to push when push_silence_after_stop is True. pause_frame_processing: Whether to pause frame processing during audio generation. sample_rate: Output sample rate for generated audio. - skip_aggregator_types: List of aggregation types that should not be spoken. text_aggregator: Custom text aggregator for processing incoming text. .. deprecated:: 0.0.95 Use an LLMTextProcessor before the TTSService for custom text aggregation. + skip_aggregator_types: List of aggregation types that should not be spoken. + text_transforms: A list of callables to transform text just before sending it + to TTS. Each callable takes the aggregated text and its type, and returns the + transformed text. To register, provide a list of tuples of + (aggregation_type | '*', transform_function). + text_filters: Sequence of text filters to apply after aggregation. text_filter: Single text filter (deprecated, use text_filters). 
@@ -162,6 +177,10 @@ class TTSService(AIService): ) self._skip_aggregator_types: List[str] = skip_aggregator_types or [] + self._text_transforms: List[ + Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]] + ] = text_transforms or [] + # TODO: Deprecate _text_filters when added to LLMTextProcessor self._text_filters: Sequence[BaseTextFilter] = text_filters or [] self._transport_destination: Optional[str] = transport_destination self._tracing_enabled: bool = False @@ -301,6 +320,39 @@ class TTSService(AIService): await self.cancel_task(self._stop_frame_task) self._stop_frame_task = None + def add_text_transformer( + self, + transform_function: Callable[[str, AggregationType | str], Awaitable[str]], + aggregation_type: AggregationType | str = "*", + ): + """Transform text for a specific aggregation type. + + Args: + transform_function: The function to apply for transformation. This function should take + the text and aggregation type as input and return the transformed text. + Ex.: async def my_transform(text: str, aggregation_type: str) -> str: + aggregation_type: The type of aggregation to transform. This value defaults to "*" indicating + the function should handle all text before sending to TTS. + """ + self._text_transforms.append((aggregation_type, transform_function)) + + def remove_text_transformer( + self, + transform_function: Callable[[str, AggregationType | str], Awaitable[str]], + aggregation_type: AggregationType | str = "*", + ): + """Remove a text transformer for a specific aggregation type. + + Args: + transform_function: The function to remove. + aggregation_type: The type of aggregation to remove the transformer for. 
+ """ + self._text_transforms = [ + (agg_type, func) + for agg_type, func in self._text_transforms + if not (agg_type == aggregation_type and func == transform_function) + ] + async def _update_settings(self, settings: Mapping[str, Any]): for key, value in settings.items(): if key in self._settings: @@ -542,7 +594,16 @@ class TTSService(AIService): src_frame.append_to_context = False await self.push_frame(src_frame) - await self.process_generator(self.run_tts(text)) + # Note: Text transformations are meant to only affect the text sent to the TTS for + # TTS-specific purposes. This allows for explicit TTS modifications (e.g., inserting + # TTS supported tags for spelling or emotion or replacing an @ with "at"). For TTS + # services that support word-level timestamps, this CAN affect the resulting context + # since the TTSTextFrames are generated from the TTS output stream + transformed_text = text + for aggregation_type, transform in self._text_transforms: + if aggregation_type == type or aggregation_type == "*": + transformed_text = await transform(transformed_text, type) + await self.process_generator(self.run_tts(transformed_text)) await self.stop_processing_metrics()