Added ability to transform text just-in-time before it gets sent to the TTS
This commit is contained in:
committed by
Mattie Ruth
parent
e1528d0f0c
commit
ecbc41045c
@@ -162,6 +162,12 @@ Croatian, Hungarian, Malay, Norwegian, Nynorsk, Slovak, Slovenian, Swedish, and
|
||||
- `TTSService` base class updates:
|
||||
- `TTSService`s now accept a new `skip_aggregator_types` to avoid speaking certain aggregation
|
||||
types (now determined/returned by the aggregator)
|
||||
- Introduced the ability to do a just-in-time transform of text before it gets sent to the
|
||||
TTS service via callbacks you can set up via a new init field, `text_transforms` or a new
|
||||
method `add_text_transformer()`. This makes it possible to do things like introduce
|
||||
TTS-specific tags for spelling or emotion or change the pronunciation of something on the
|
||||
fly. `remove_text_transformer()` has also been added to support removing a registered
|
||||
transform callback.
|
||||
|
||||
### Deprecated
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
@@ -105,6 +107,14 @@ class TTSService(AIService):
|
||||
text_aggregator: Optional[BaseTextAggregator] = None,
|
||||
# Types of text aggregations that should not be spoken.
|
||||
skip_aggregator_types: Optional[List[str]] = [],
|
||||
# A list of callables to transform text just before sending it to TTS.
|
||||
# Each callable takes the aggregated text and its type, and returns the transformed text.
|
||||
# To register, provide a list of tuples of (aggregation_type | '*', transform_function).
|
||||
text_transforms: Optional[
|
||||
List[
|
||||
Tuple[AggregationType | str, Callable[[str, str | AggregationType], Awaitable[str]]]
|
||||
]
|
||||
] = None,
|
||||
# Text filter executed after text has been aggregated.
|
||||
text_filters: Optional[Sequence[BaseTextFilter]] = None,
|
||||
text_filter: Optional[BaseTextFilter] = None,
|
||||
@@ -123,12 +133,17 @@ class TTSService(AIService):
|
||||
silence_time_s: Duration of silence to push when push_silence_after_stop is True.
|
||||
pause_frame_processing: Whether to pause frame processing during audio generation.
|
||||
sample_rate: Output sample rate for generated audio.
|
||||
skip_aggregator_types: List of aggregation types that should not be spoken.
|
||||
text_aggregator: Custom text aggregator for processing incoming text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
skip_aggregator_types: List of aggregation types that should not be spoken.
|
||||
text_transforms: A list of callables to transform text just before sending it
|
||||
to TTS. Each callable takes the aggregated text and its type, and returns the
|
||||
transformed text. To register, provide a list of tuples of
|
||||
(aggregation_type | '*', transform_function).
|
||||
|
||||
text_filters: Sequence of text filters to apply after aggregation.
|
||||
text_filter: Single text filter (deprecated, use text_filters).
|
||||
|
||||
@@ -162,6 +177,10 @@ class TTSService(AIService):
|
||||
)
|
||||
|
||||
self._skip_aggregator_types: List[str] = skip_aggregator_types or []
|
||||
self._text_transforms: List[
|
||||
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
|
||||
] = text_transforms or []
|
||||
# TODO: Deprecate _text_filters when added to LLMTextProcessor
|
||||
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
|
||||
self._transport_destination: Optional[str] = transport_destination
|
||||
self._tracing_enabled: bool = False
|
||||
@@ -301,6 +320,39 @@ class TTSService(AIService):
|
||||
await self.cancel_task(self._stop_frame_task)
|
||||
self._stop_frame_task = None
|
||||
|
||||
def add_text_transformer(
    self,
    transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
    aggregation_type: AggregationType | str = "*",
) -> None:
    """Register a callback that transforms text just before it is sent to TTS.

    Registering the same function more than once adds it more than once.

    Args:
        transform_function: Async callable that receives the text and its
            aggregation type and returns the transformed text.
            Ex.: async def my_transform(text: str, aggregation_type: str) -> str:
        aggregation_type: The aggregation type this transformer is registered for.
            Defaults to "*", indicating the function should handle all text
            before it is sent to TTS.
    """
    # Build a new list rather than appending in place: the current list may be the
    # very object the caller passed as `text_transforms`, and it must not be mutated
    # behind the caller's back. This also matches remove_text_transformer, which
    # rebinds the attribute instead of editing the list.
    self._text_transforms = [
        *self._text_transforms,
        (aggregation_type, transform_function),
    ]
|
||||
|
||||
def remove_text_transformer(
    self,
    transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
    aggregation_type: AggregationType | str = "*",
):
    """Unregister a previously added text transformer.

    Every registration whose aggregation type and function both equal the
    given arguments is dropped; all other registrations keep their order.
    Nothing happens if no registration matches.

    Args:
        transform_function: The function to remove.
        aggregation_type: The aggregation type the function was registered
            under. Defaults to "*".
    """
    remaining = []
    for entry in self._text_transforms:
        registered_type, registered_func = entry
        is_target = (
            registered_type == aggregation_type and registered_func == transform_function
        )
        if is_target:
            continue
        remaining.append((registered_type, registered_func))
    # Rebind instead of editing in place, so the previous list object is untouched.
    self._text_transforms = remaining
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
for key, value in settings.items():
|
||||
if key in self._settings:
|
||||
@@ -542,7 +594,16 @@ class TTSService(AIService):
|
||||
src_frame.append_to_context = False
|
||||
await self.push_frame(src_frame)
|
||||
|
||||
await self.process_generator(self.run_tts(text))
|
||||
# Note: Text transformations are meant to only affect the text sent to the TTS for
|
||||
# TTS-specific purposes. This allows for explicit TTS modifications (e.g., inserting
|
||||
# TTS supported tags for spelling or emotion or replacing an @ with "at"). For TTS
|
||||
# services that support word-level timestamps, this CAN affect the resulting context
|
||||
# since the TTSTextFrames are generated from the TTS output stream
|
||||
transformed_text = text
|
||||
for aggregation_type, transform in self._text_transforms:
|
||||
if aggregation_type == type or aggregation_type == "*":
|
||||
transformed_text = await transform(transformed_text, type)
|
||||
await self.process_generator(self.run_tts(transformed_text))
|
||||
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user