Merge pull request #1395 from pipecat-ai/aleix/multiple-text-filters-and-aggregators
TTSService: allow passing multiple text filters and aggregators
This commit is contained in:
11
CHANGELOG.md
11
CHANGELOG.md
@@ -16,8 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- Added new `BaseTextAggregator`. Text aggregators are used by the TTS service
|
||||
to aggregate LLM tokens and decide when the aggregated text should be pushed
|
||||
to the TTS service. It also allows for the text to be manipulated while it's
|
||||
being aggregated.
|
||||
to the TTS service. They also allow for the text to be manipulated while it's
|
||||
being aggregated. Multiple text aggregators can be passed with
|
||||
`text_aggregators` to the TTS service.
|
||||
|
||||
- Added new `UltravoxSTTService`.
|
||||
(see https://github.com/fixie-ai/ultravox)
|
||||
@@ -113,6 +114,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Updated the default model for `CartesiaTTSService` and
|
||||
`CartesiaHttpTTSService` to `sonic-2`.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- `TTSService` parameter `text_filter` is now deprecated; use `text_filters`
|
||||
instead, which is now a list. This allows passing multiple filters that will be
|
||||
executed in order.
|
||||
|
||||
### Removed
|
||||
|
||||
- Removed deprecated `audio.resample_audio()`, use `create_default_resampler()`
|
||||
|
||||
@@ -60,7 +60,7 @@ async def main():
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
# text_filter=MarkdownTextFilter(),
|
||||
# text_filters=[MarkdownTextFilter()],
|
||||
)
|
||||
|
||||
llm = NimLLMService(
|
||||
|
||||
@@ -119,7 +119,7 @@ async def main():
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id=VOICE_IDS["narrator"],
|
||||
text_aggregator=pattern_aggregator,
|
||||
text_aggregators=[pattern_aggregator],
|
||||
)
|
||||
|
||||
# Initialize LLM
|
||||
|
||||
@@ -97,7 +97,7 @@ async def main():
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
text_filter=MarkdownTextFilter(),
|
||||
text_filters=[MarkdownTextFilter()],
|
||||
)
|
||||
|
||||
llm = GoogleLLMService(
|
||||
|
||||
@@ -8,7 +8,7 @@ import asyncio
|
||||
import io
|
||||
import wave
|
||||
from abc import abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Tuple, Type
|
||||
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Sequence, Tuple, Type
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -239,8 +239,9 @@ class TTSService(AIService):
|
||||
# TTS output sample rate
|
||||
sample_rate: Optional[int] = None,
|
||||
# Text aggregators to aggregate incoming tokens; the last one decides when to push to the TTS.
|
||||
text_aggregator: Optional[BaseTextAggregator] = None,
|
||||
text_aggregators: Sequence[BaseTextAggregator] = [],
|
||||
# Text filters executed, in order, after text has been aggregated.
|
||||
text_filters: Sequence[BaseTextFilter] = [],
|
||||
text_filter: Optional[BaseTextFilter] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -256,8 +257,21 @@ class TTSService(AIService):
|
||||
self._sample_rate = 0
|
||||
self._voice_id: str = ""
|
||||
self._settings: Dict[str, Any] = {}
|
||||
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
|
||||
self._text_filter: Optional[BaseTextFilter] = text_filter
|
||||
# Ensure there's at least one text aggregator.
|
||||
self._text_aggregators: Sequence[BaseTextAggregator] = text_aggregators or [
|
||||
SimpleTextAggregator()
|
||||
]
|
||||
self._text_filters: Sequence[BaseTextFilter] = text_filters
|
||||
if text_filter:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'text_filter' is deprecated, use 'text_filters' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self._text_filters = [text_filter]
|
||||
|
||||
self._stop_frame_task: Optional[asyncio.Task] = None
|
||||
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
|
||||
@@ -317,8 +331,9 @@ class TTSService(AIService):
|
||||
self.set_model_name(value)
|
||||
elif key == "voice":
|
||||
self.set_voice(value)
|
||||
elif key == "text_filter" and self._text_filter:
|
||||
self._text_filter.update_settings(value)
|
||||
elif key == "text_filter":
|
||||
for filter in self._text_filters:
|
||||
filter.update_settings(value)
|
||||
else:
|
||||
logger.warning(f"Unknown setting for TTS service: {key}")
|
||||
|
||||
@@ -343,8 +358,8 @@ class TTSService(AIService):
|
||||
# pause to avoid audio overlapping.
|
||||
await self._maybe_pause_frame_processing()
|
||||
|
||||
sentence = self._text_aggregator.text
|
||||
self._text_aggregator.reset()
|
||||
sentence = self._text_aggregators[-1].text
|
||||
self._reset_aggregators()
|
||||
self._processing_text = False
|
||||
await self._push_tts_frames(sentence)
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
@@ -390,9 +405,10 @@ class TTSService(AIService):
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
self._processing_text = False
|
||||
self._text_aggregator.handle_interruption()
|
||||
if self._text_filter:
|
||||
self._text_filter.handle_interruption()
|
||||
for aggregator in self._text_aggregators:
|
||||
aggregator.handle_interruption()
|
||||
for filter in self._text_filters:
|
||||
filter.handle_interruption()
|
||||
|
||||
async def _maybe_pause_frame_processing(self):
|
||||
if self._processing_text and self._pause_frame_processing:
|
||||
@@ -402,12 +418,25 @@ class TTSService(AIService):
|
||||
if self._pause_frame_processing:
|
||||
await self.resume_processing_frames()
|
||||
|
||||
def _reset_aggregators(self):
|
||||
for aggregator in self._text_aggregators:
|
||||
aggregator.reset()
|
||||
|
||||
async def _process_text_frame(self, frame: TextFrame):
|
||||
text: Optional[str] = None
|
||||
if not self._aggregate_sentences:
|
||||
text = frame.text
|
||||
else:
|
||||
text = self._text_aggregator.aggregate(frame.text)
|
||||
current_text = frame.text
|
||||
|
||||
# Process all aggregators except the last one.
|
||||
for aggregator in self._text_aggregators[:-1]:
|
||||
aggregator.aggregate(current_text)
|
||||
current_text = aggregator.text
|
||||
|
||||
# The last aggregator decides whether we are sending text to the
|
||||
# TTS or not.
|
||||
text = self._text_aggregators[-1].aggregate(current_text)
|
||||
|
||||
if text:
|
||||
await self._push_tts_frames(text)
|
||||
@@ -427,11 +456,16 @@ class TTSService(AIService):
|
||||
self._processing_text = True
|
||||
|
||||
await self.start_processing_metrics()
|
||||
if self._text_filter:
|
||||
self._text_filter.reset_interruption()
|
||||
text = self._text_filter.filter(text)
|
||||
|
||||
# Process all filters.
|
||||
for filter in self._text_filters:
|
||||
filter.reset_interruption()
|
||||
text = filter.filter(text)
|
||||
|
||||
await self.process_generator(self.run_tts(text))
|
||||
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
if self._push_text_frames:
|
||||
# We send the original text after the audio. This way, if we are
|
||||
# interrupted, the text is not added to the assistant context.
|
||||
|
||||
Reference in New Issue
Block a user