diff --git a/CHANGELOG.md b/CHANGELOG.md index 7289ff4c5..95cc4d564 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 composing a best effort of the perceived llm output in a more digestable form and to do so whether or not it is processed by a TTS or if even a TTS exists. +- Introduced `LLMTextProcessor`: A new processor meant to allow customization for how + LLMTextFrames should be aggregated and considered. It's purpose is to turn + `LLMTextFrame`s into `AggregatedTextFrame`s. By default, a TTSService will still + aggregate `LLMTextFrame`s by sentence for the service to consume. However, if you + wish to override how the llm text is aggregated, you should no longer override the + TTS's internal text_aggregator, but instead, insert this processor between your LLM + and TTS in the pipeline. + ### Changed - ⚠️ Breaking change: `LLMContext.create_image_message()`, @@ -159,6 +167,11 @@ Croatian, Hungarian, Malay, Norwegian, Nynorsk, Slovak, Slovenian, Swedish, and - `english_normalization` input parameter for `MiniMaxHttpTTSService` is deprecated, use `test_normalization` instead. +- The TTS constructor field, `text_aggregator` is deprecated in favor of the new + `LLMTextProcessor`. TTSServices still have an internal aggregator for support of default + behavior, but if you want to override the aggregation behavior, you should use the new + processor. + ### Fixed - Fixed a `SimliVideoService` connection issue. diff --git a/src/pipecat/processors/aggregators/llm_text_processor.py b/src/pipecat/processors/aggregators/llm_text_processor.py new file mode 100644 index 000000000..44a8dc24e --- /dev/null +++ b/src/pipecat/processors/aggregators/llm_text_processor.py @@ -0,0 +1,106 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""LLM text processor module for processing and aggregating raw LLM output text. + +This processor will convert LLMTextFrames into AggregatedTextFrames based on the +configured text aggregator. Using the customizable aggregator, it provides +functionality to handle or manipulate LLM text frames before they are sent to other +components such as TTS services or context aggregators. It can be used to pre-aggregate +and categorize, modify, or filter direct output tokens from the LLM. +""" + +from typing import Optional + +from pipecat.frames.frames import ( + AggregatedTextFrame, + EndFrame, + Frame, + InterruptionFrame, + LLMFullResponseEndFrame, + LLMTextFrame, +) +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.utils.text.base_text_aggregator import BaseTextAggregator +from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator + + +class LLMTextProcessor(FrameProcessor): + """A processor for handling or manipulating LLM text frames before they are processed further. + + This processor will convert LLMTextFrames into AggregatedTextFrames based on the configured + text aggregator. Using the customizable aggregator, it provides functionality to handle or + manipulate LLM text frames before they are sent to other components such as TTS services or + context aggregators. It can be used to pre-aggregate and categorize, modify, or filter direct + output tokens from the LLM. + """ + + def __init__(self, *, text_aggregator: Optional[BaseTextAggregator] = None, **kwargs): + """Initialize the LLM text processor. + + Args: + text_aggregator: An optional text aggregator to use for processing LLM text frames. By + default, a SimpleTextAggregator aggregating by sentence will be used. + **kwargs: Additional arguments passed to parent class. + + TODO: Allow transformations per aggregation type or all (and deprecate the TTS filters). + """ + super().__init__(**kwargs) + self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator() + + async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process an LLMTextFrames using the aggregator to generate AggregatedTextFrames. + + Args: + frame: The frame to process. + direction: The direction of frame flow in the pipeline. + """ + await super().process_frame(frame, direction) + + if isinstance(frame, InterruptionFrame): + await self._handle_interruption(frame) + await self.push_frame(frame, direction) + elif isinstance(frame, LLMTextFrame): + await self._handle_llm_text(frame) + elif isinstance(frame, LLMFullResponseEndFrame): + await self._handle_llm_end(frame.skip_tts) + await self.push_frame(frame, direction) + elif isinstance(frame, EndFrame): + await self._handle_llm_end() + await self.push_frame(frame, direction) + else: + await self.push_frame(frame, direction) + + async def _handle_interruption(self, _): + """Handle interruptions by resetting the text aggregator.""" + await self._text_aggregator.handle_interruption() + + async def reset(self): + """Reset the internal state of the text processor and its aggregator.""" + await self._text_aggregator.reset() + + async def _handle_llm_text(self, in_frame: LLMTextFrame): + aggregation = await self._text_aggregator.aggregate(in_frame.text) + if aggregation: + out_frame = AggregatedTextFrame( + text=aggregation.text, + aggregated_by=aggregation.type, + ) + out_frame.skip_tts = in_frame.skip_tts + await self.push_frame(out_frame) + + async def _handle_llm_end(self, skip_tts: bool = False): + # Flush any remaining aggregated text at the end of the LLM response + aggregation = self._text_aggregator.text + await self._text_aggregator.reset() + text = aggregation.text.strip() + if text: + out_frame = AggregatedTextFrame( + text=text, + aggregated_by=aggregation.type, + ) + out_frame.skip_tts = skip_tts + await self.push_frame(out_frame) diff --git a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index a4f8933ef..2267ee29f 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -122,6 +122,10 @@ class TTSService(AIService): pause_frame_processing: Whether to pause frame processing during audio generation. sample_rate: Output sample rate for generated audio. text_aggregator: Custom text aggregator for processing incoming text. + + .. deprecated:: 0.0.95 + Use an LLMTextProcessor before the TTSService for custom text aggregation. + text_filters: Sequence of text filters to apply after aggregation. text_filter: Single text filter (deprecated, use text_filters). @@ -144,6 +148,16 @@ class TTSService(AIService): self._voice_id: str = "" self._settings: Dict[str, Any] = {} self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator() + if text_aggregator: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "Parameter 'text_aggregator' is deprecated. Use an LLMTextProcessor before the TTSService for custom text aggregation.", + DeprecationWarning, + ) + self._text_filters: Sequence[BaseTextFilter] = text_filters or [] self._transport_destination: Optional[str] = transport_destination self._tracing_enabled: bool = False @@ -338,6 +352,8 @@ class TTSService(AIService): and frame.skip_tts ): await self.push_frame(frame, direction) + elif isinstance(frame, AggregatedTextFrame): + await self._push_tts_frames(frame) elif ( isinstance(frame, TextFrame) and not isinstance(frame, InterimTranscriptionFrame) @@ -482,8 +498,11 @@ class TTSService(AIService): async def _push_tts_frames( self, src_frame: AggregatedTextFrame, includes_inter_frame_spaces: Optional[bool] = False ): + type = src_frame.aggregated_by + text = src_frame.text + # Remove leading newlines only - text = src_frame.text.lstrip("\n") + text = text.lstrip("\n") # Don't send only whitespace. This causes problems for some TTS models. But also don't # strip all whitespace, as whitespace can influence prosody. @@ -497,20 +516,35 @@ class TTSService(AIService): await self.start_processing_metrics() - # Process all filter. + # Process all filters. for filter in self._text_filters: await filter.reset_interruption() text = await filter.filter(text) - if text: - await self.process_generator(self.run_tts(text)) + if not text.strip(): + await self.stop_processing_metrics() + return + + # To support use cases that may want to know the text before it's spoken, we + # push the AggregatedTextFrame version before transforming and sending to TTS. + # However, we do not want to add this text to the assistant context until it + # is spoken, so we set append_to_context to False. + src_frame.append_to_context = False + await self.push_frame(src_frame) + + await self.process_generator(self.run_tts(text)) await self.stop_processing_metrics() if self._push_text_frames: - # We send the original text after the audio. This way, if we are - # interrupted, the text is not added to the assistant context. - frame = TTSTextFrame(text, aggregated_by=src_frame.aggregated_by) + # In TTS services that support word timestamps, the TTSTextFrames + # are pushed as words are spoken. However, in the case where the TTS service + # does not support word timestamps (i.e. _push_text_frames is True), we send + # the original (non-transformed) text after the TTS generation has completed. + # This way, if we are interrupted, the text is not added to the assistant + # context and the context that IS added does not include TTS-specific tags + # or transformations. + frame = TTSTextFrame(text, aggregated_by=type) frame.includes_inter_frame_spaces = includes_inter_frame_spaces await self.push_frame(frame) diff --git a/tests/test_piper_tts.py b/tests/test_piper_tts.py index 75893f93f..a006f555c 100644 --- a/tests/test_piper_tts.py +++ b/tests/test_piper_tts.py @@ -13,6 +13,7 @@ import pytest from aiohttp import web from pipecat.frames.frames import ( + AggregatedTextFrame, ErrorFrame, TTSAudioRawFrame, TTSSpeakFrame, @@ -74,6 +75,7 @@ async def test_run_piper_tts_success(aiohttp_client): ] expected_returned_frames = [ + AggregatedTextFrame, TTSStartedFrame, TTSAudioRawFrame, TTSAudioRawFrame, @@ -121,7 +123,7 @@ async def test_run_piper_tts_error(aiohttp_client): TTSSpeakFrame(text="Error case."), ] - expected_down_frames = [TTSStoppedFrame, TTSTextFrame] + expected_down_frames = [AggregatedTextFrame, TTSStoppedFrame, TTSTextFrame] expected_up_frames = [ErrorFrame]