Added ability to transform text just-in-time before it gets sent to the TTS

This commit is contained in:
mattie ruth backman
2025-11-17 18:14:45 -05:00
committed by Mattie Ruth
parent e1528d0f0c
commit ecbc41045c
2 changed files with 69 additions and 2 deletions

View File

@@ -162,6 +162,12 @@ Croatian, Hungarian, Malay, Norwegian, Nynorsk, Slovak, Slovenian, Swedish, and
- `TTSService` base class updates:
- `TTSService`s now accept a new `skip_aggregator_types` to avoid speaking certain aggregation
types (now determined/returned by the aggregator)
- Introduced the ability to do a just-in-time transform of text before it gets sent to the
TTS service via callbacks you can set up via a new init field, `text_transforms` or a new
method `add_text_transformer()`. This makes it possible to do things like introduce
TTS-specific tags for spelling or emotion or change the pronunciation of something on the
fly. `remove_text_transformer` has also been added to support removing a registered
transform callback.
### Deprecated

View File

@@ -12,6 +12,8 @@ from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Dict,
List,
Mapping,
@@ -105,6 +107,14 @@ class TTSService(AIService):
text_aggregator: Optional[BaseTextAggregator] = None,
# Types of text aggregations that should not be spoken.
skip_aggregator_types: Optional[List[str]] = [],
# A list of callables to transform text before just before sending it to TTS.
# Each callable takes the aggregated text and its type, and returns the transformed text.
# To register, provide a list of tuples of (aggregation_type | '*', transform_function).
text_transforms: Optional[
List[
Tuple[AggregationType | str, Callable[[str, str | AggregationType], Awaitable[str]]]
]
] = None,
# Text filter executed after text has been aggregated.
text_filters: Optional[Sequence[BaseTextFilter]] = None,
text_filter: Optional[BaseTextFilter] = None,
@@ -123,12 +133,17 @@ class TTSService(AIService):
silence_time_s: Duration of silence to push when push_silence_after_stop is True.
pause_frame_processing: Whether to pause frame processing during audio generation.
sample_rate: Output sample rate for generated audio.
skip_aggregator_types: List of aggregation types that should not be spoken.
text_aggregator: Custom text aggregator for processing incoming text.
.. deprecated:: 0.0.95
Use an LLMTextProcessor before the TTSService for custom text aggregation.
skip_aggregator_types: List of aggregation types that should not be spoken.
text_transforms: A list of callables to transform text before just before sending it
to TTS. Each callable takes the aggregated text and its type, and returns the
transformed text. To register, provide a list of tuples of
(aggregation_type | '*', transform_function).
text_filters: Sequence of text filters to apply after aggregation.
text_filter: Single text filter (deprecated, use text_filters).
@@ -162,6 +177,10 @@ class TTSService(AIService):
)
self._skip_aggregator_types: List[str] = skip_aggregator_types or []
self._text_transforms: List[
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
] = text_transforms or []
# TODO: Deprecate _text_filters when added to LLMTextProcessor
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
self._transport_destination: Optional[str] = transport_destination
self._tracing_enabled: bool = False
@@ -301,6 +320,39 @@ class TTSService(AIService):
await self.cancel_task(self._stop_frame_task)
self._stop_frame_task = None
def add_text_transformer(
self,
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
aggregation_type: AggregationType | str = "*",
):
"""Transform text for a specific aggregation type.
Args:
transform_function: The function to apply for transformation. This function should take
the text and aggregation type as input and return the transformed text.
Ex.: async def my_transform(text: str, aggregation_type: str) -> str:
aggregation_type: The type of aggregation to transform. This value defaults to "*" indicating
the function should handle all text before sending to TTS.
"""
self._text_transforms.append((aggregation_type, transform_function))
def remove_text_transformer(
self,
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
aggregation_type: AggregationType | str = "*",
):
"""Remove a text transformer for a specific aggregation type.
Args:
transform_function: The function to remove.
aggregation_type: The type of aggregation to remove the transformer for.
"""
self._text_transforms = [
(agg_type, func)
for agg_type, func in self._text_transforms
if not (agg_type == aggregation_type and func == transform_function)
]
async def _update_settings(self, settings: Mapping[str, Any]):
for key, value in settings.items():
if key in self._settings:
@@ -542,7 +594,16 @@ class TTSService(AIService):
src_frame.append_to_context = False
await self.push_frame(src_frame)
await self.process_generator(self.run_tts(text))
# Note: Text transformations are meant to only affect the text sent to the TTS for
# TTS-specific purposes. This allows for explicit TTS modifications (e.g., inserting
# TTS supported tags for spelling or emotion or replacing an @ with "at"). For TTS
# services that support word-level timestamps, this CAN affect the resulting context
# since the TTSTextFrames are generated from the TTS output stream
transformed_text = text
for aggregation_type, transform in self._text_transforms:
if aggregation_type == type or aggregation_type == "*":
transformed_text = await transform(transformed_text, type)
await self.process_generator(self.run_tts(transformed_text))
await self.stop_processing_metrics()