diff --git a/changelog/3696.added.md b/changelog/3696.added.md new file mode 100644 index 000000000..39726d930 --- /dev/null +++ b/changelog/3696.added.md @@ -0,0 +1 @@ +- Added `TextAggregationMetricsData` metric measuring the time from the first LLM token to the first complete sentence, representing the latency cost of sentence aggregation in the TTS pipeline. diff --git a/changelog/3696.changed.md b/changelog/3696.changed.md new file mode 100644 index 000000000..a495560ba --- /dev/null +++ b/changelog/3696.changed.md @@ -0,0 +1 @@ +- Added `text_aggregation_mode` parameter to `TTSService` and all TTS subclasses with a new `TextAggregationMode` enum (`SENTENCE`, `TOKEN`). All text now flows through text aggregators regardless of mode, enabling pattern detection and tag handling in TOKEN mode. diff --git a/changelog/3696.deprecated.md b/changelog/3696.deprecated.md new file mode 100644 index 000000000..7b371fc21 --- /dev/null +++ b/changelog/3696.deprecated.md @@ -0,0 +1 @@ +- ⚠️ Deprecated `aggregate_sentences` parameter on `TTSService` and all TTS subclasses. Use `text_aggregation_mode=TextAggregationMode.SENTENCE` or `text_aggregation_mode=TextAggregationMode.TOKEN` instead. 
diff --git a/examples/foundational/07-interruptible.py b/examples/foundational/07-interruptible.py index c5964506a..074e091ea 100644 --- a/examples/foundational/07-interruptible.py +++ b/examples/foundational/07-interruptible.py @@ -24,6 +24,7 @@ from pipecat.runner.utils import create_transport from pipecat.services.cartesia.tts import CartesiaTTSService from pipecat.services.deepgram.stt import DeepgramSTTService from pipecat.services.openai.llm import OpenAILLMService +from pipecat.services.tts_service import TextAggregationMode from pipecat.transports.base_transport import BaseTransport, TransportParams from pipecat.transports.daily.transport import DailyParams from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams @@ -56,6 +57,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): tts = CartesiaTTSService( api_key=os.getenv("CARTESIA_API_KEY"), voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady + # Alternatively, you can use TextAggregationMode.TOKEN to stream tokens instead of + # sentences for faster response times. + # text_aggregation_mode=TextAggregationMode.TOKEN, ) llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY")) diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index d359bcfb1..55ae975d1 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -14,7 +14,6 @@ and LLM processing. 
import asyncio import time from dataclasses import dataclass, field -from enum import Enum from typing import ( TYPE_CHECKING, Any, @@ -36,6 +35,7 @@ from pipecat.audio.turn.base_turn_analyzer import BaseTurnParams from pipecat.audio.vad.vad_analyzer import VADParams from pipecat.metrics.metrics import MetricsData from pipecat.transcriptions.language import Language +from pipecat.utils.text.base_text_aggregator import AggregationType from pipecat.utils.time import nanoseconds_to_str from pipecat.utils.utils import obj_count, obj_id @@ -393,16 +393,6 @@ class LLMTextFrame(TextFrame): self.includes_inter_frame_spaces = True -class AggregationType(str, Enum): - """Built-in aggregation strings.""" - - SENTENCE = "sentence" - WORD = "word" - - def __str__(self): - return self.value - - @dataclass class AggregatedTextFrame(TextFrame): """Text frame representing an aggregation of TextFrames. diff --git a/src/pipecat/metrics/metrics.py b/src/pipecat/metrics/metrics.py index ccf30227a..2030306e5 100644 --- a/src/pipecat/metrics/metrics.py +++ b/src/pipecat/metrics/metrics.py @@ -87,6 +87,19 @@ class TTSUsageMetricsData(MetricsData): value: int +class TextAggregationMetricsData(MetricsData): + """Text aggregation time metrics data. + + Measures the time from the first LLM token to the first complete sentence, + representing the latency cost of sentence aggregation in the TTS pipeline. + + Parameters: + value: Aggregation time in seconds. + """ + + value: float + + class TurnMetricsData(MetricsData): """Metrics data for turn detection predictions. 
diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index bcdb2d57b..baa52cc70 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -485,10 +485,23 @@ class FrameProcessor(BaseObject): if frame: await self.push_frame(frame) + async def start_text_aggregation_metrics(self): + """Start text aggregation time metrics collection.""" + if self.can_generate_metrics() and self.metrics_enabled: + await self._metrics.start_text_aggregation_metrics() + + async def stop_text_aggregation_metrics(self): + """Stop text aggregation time metrics collection and push results.""" + if self.can_generate_metrics() and self.metrics_enabled: + frame = await self._metrics.stop_text_aggregation_metrics() + if frame: + await self.push_frame(frame) + async def stop_all_metrics(self): """Stop all active metrics collection.""" await self.stop_ttfb_metrics() await self.stop_processing_metrics() + await self.stop_text_aggregation_metrics() def create_task(self, coroutine: Coroutine, name: Optional[str] = None) -> asyncio.Task: """Create a new task managed by this processor. 
diff --git a/src/pipecat/processors/metrics/frame_processor_metrics.py b/src/pipecat/processors/metrics/frame_processor_metrics.py index c82fd9698..7a52895a2 100644 --- a/src/pipecat/processors/metrics/frame_processor_metrics.py +++ b/src/pipecat/processors/metrics/frame_processor_metrics.py @@ -17,6 +17,7 @@ from pipecat.metrics.metrics import ( LLMUsageMetricsData, MetricsData, ProcessingMetricsData, + TextAggregationMetricsData, TTFBMetricsData, TTSUsageMetricsData, ) @@ -43,6 +44,7 @@ class FrameProcessorMetrics(BaseObject): self._task_manager = None self._start_ttfb_time = 0 self._start_processing_time = 0 + self._start_text_aggregation_time = 0 self._last_ttfb_time = 0 self._should_report_ttfb = True @@ -211,3 +213,24 @@ class FrameProcessorMetrics(BaseObject): ) logger.debug(f"{self._processor_name()} usage characters: {characters.value}") return MetricsFrame(data=[characters]) + + async def start_text_aggregation_metrics(self): + """Start measuring text aggregation time (first token to first sentence).""" + self._start_text_aggregation_time = time.time() + + async def stop_text_aggregation_metrics(self): + """Stop text aggregation measurement and generate metrics frame. + + Returns: + MetricsFrame containing text aggregation time, or None if not measuring. 
+ """ + if self._start_text_aggregation_time == 0: + return None + + value = time.time() - self._start_text_aggregation_time + logger.debug(f"{self._processor_name()} text aggregation time: {value}") + aggregation = TextAggregationMetricsData( + processor=self._processor_name(), value=value, model=self._model_name() + ) + self._start_text_aggregation_time = 0 + return MetricsFrame(data=[aggregation]) diff --git a/src/pipecat/services/asyncai/tts.py b/src/pipecat/services/asyncai/tts.py index 334f80d80..4f1fd5a58 100644 --- a/src/pipecat/services/asyncai/tts.py +++ b/src/pipecat/services/asyncai/tts.py @@ -28,7 +28,7 @@ from pipecat.frames.frames import ( ) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven -from pipecat.services.tts_service import AudioContextTTSService, TTSService +from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.tracing.service_decorators import traced_tts @@ -128,7 +128,8 @@ class AsyncAITTSService(AudioContextTTSService): encoding: str = "pcm_s16le", container: str = "raw", params: Optional[InputParams] = None, - aggregate_sentences: Optional[bool] = True, + aggregate_sentences: Optional[bool] = None, + text_aggregation_mode: Optional[TextAggregationMode] = None, **kwargs, ): """Initialize the Async TTS service. @@ -144,13 +145,19 @@ class AsyncAITTSService(AudioContextTTSService): encoding: Audio encoding format. container: Audio container format. params: Additional input parameters for voice customization. - aggregate_sentences: Whether to aggregate sentences within the TTSService. + aggregate_sentences: Deprecated. Use text_aggregation_mode instead. + + .. deprecated:: 0.0.104 + Use ``text_aggregation_mode`` instead. + + text_aggregation_mode: How to aggregate text before synthesis. 
**kwargs: Additional arguments passed to the parent service. """ params = params or AsyncAITTSService.InputParams() super().__init__( aggregate_sentences=aggregate_sentences, + text_aggregation_mode=text_aggregation_mode, pause_frame_processing=True, push_stop_frames=True, sample_rate=sample_rate, diff --git a/src/pipecat/services/azure/tts.py b/src/pipecat/services/azure/tts.py index b3534b28e..f68694eb5 100644 --- a/src/pipecat/services/azure/tts.py +++ b/src/pipecat/services/azure/tts.py @@ -27,7 +27,7 @@ from pipecat.frames.frames import ( from pipecat.processors.frame_processor import FrameDirection from pipecat.services.azure.common import language_to_azure_language from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven -from pipecat.services.tts_service import TTSService +from pipecat.services.tts_service import TextAggregationMode, TTSService from pipecat.transcriptions.language import Language from pipecat.utils.tracing.service_decorators import traced_tts @@ -256,7 +256,8 @@ class AzureTTSService(TTSService, AzureBaseTTSService): voice: str = "en-US-SaraNeural", sample_rate: Optional[int] = None, params: Optional[AzureBaseTTSService.InputParams] = None, - aggregate_sentences: bool = True, + aggregate_sentences: Optional[bool] = None, + text_aggregation_mode: Optional[TextAggregationMode] = None, **kwargs, ): """Initialize the Azure streaming TTS service. @@ -267,13 +268,19 @@ class AzureTTSService(TTSService, AzureBaseTTSService): voice: Voice name to use for synthesis. Defaults to "en-US-SaraNeural". sample_rate: Audio sample rate in Hz. If None, uses service default. params: Voice and synthesis parameters configuration. - aggregate_sentences: Whether to aggregate sentences before synthesis. - **kwargs: Additional arguments passed to the parent TTSService. + aggregate_sentences: Deprecated. Use text_aggregation_mode instead. + + .. deprecated:: 0.0.104 + Use ``text_aggregation_mode`` instead. 
+ + text_aggregation_mode: How to aggregate text before synthesis. + **kwargs: Additional arguments passed to the parent TTSService. """ params = params or AzureBaseTTSService.InputParams() super().__init__( aggregate_sentences=aggregate_sentences, + text_aggregation_mode=text_aggregation_mode, push_text_frames=False, # We'll push text frames based on word timestamps push_stop_frames=True, pause_frame_processing=True, diff --git a/src/pipecat/services/cartesia/tts.py b/src/pipecat/services/cartesia/tts.py index edf838e59..2e637c339 100644 --- a/src/pipecat/services/cartesia/tts.py +++ b/src/pipecat/services/cartesia/tts.py @@ -27,7 +27,7 @@ from pipecat.frames.frames import ( TTSStoppedFrame, ) from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven -from pipecat.services.tts_service import AudioContextTTSService, TTSService +from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.text.base_text_aggregator import BaseTextAggregator from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator @@ -272,7 +272,8 @@ class CartesiaTTSService(AudioContextTTSService): container: str = "raw", params: Optional[InputParams] = None, text_aggregator: Optional[BaseTextAggregator] = None, - aggregate_sentences: Optional[bool] = True, + text_aggregation_mode: Optional[TextAggregationMode] = None, + aggregate_sentences: Optional[bool] = None, **kwargs, ): """Initialize the Cartesia TTS service. @@ -292,13 +293,21 @@ class CartesiaTTSService(AudioContextTTSService): .. deprecated:: 0.0.95 Use an LLMTextProcessor before the TTSService for custom text aggregation. + text_aggregation_mode: How to aggregate incoming text before synthesis. aggregate_sentences: Whether to aggregate sentences within the TTSService. + + .. deprecated:: 0.0.104 + Use ``text_aggregation_mode`` instead. 
+ **kwargs: Additional arguments passed to the parent service. """ - # Aggregating sentences still gives cleaner-sounding results and fewer - # artifacts than streaming one word at a time. On average, waiting for a - # full sentence should only "cost" us 15ms or so with GPT-4o or a Llama - # 3 model, and it's worth it for the better audio quality. + # By default, we aggregate sentences before sending to TTS. This adds + # ~200-300ms of latency per sentence (waiting for the sentence-ending + # punctuation token from the LLM). Setting + # text_aggregation_mode=TextAggregationMode.TOKEN streams tokens + # directly, which reduces latency. Streaming quality is good but less + # tested than sentence aggregation. + # TODO: Consider making TOKEN the default for Cartesia in 1.0. # # We also don't want to automatically push LLM response text frames, # because the context aggregators will add them to the LLM context even @@ -308,6 +317,7 @@ class CartesiaTTSService(AudioContextTTSService): params = params or CartesiaTTSService.InputParams() super().__init__( + text_aggregation_mode=text_aggregation_mode, aggregate_sentences=aggregate_sentences, push_text_frames=False, pause_frame_processing=True, @@ -337,7 +347,9 @@ class CartesiaTTSService(AudioContextTTSService): # The preferred way of taking advantage of Cartesia SSML Tags is # to use an LLMTextProcessor and/or a text_transformer to identify # and insert these tags for the purpose of the TTS service alone. - self._text_aggregator = SkipTagsAggregator([("", "")]) + self._text_aggregator = SkipTagsAggregator( + [("", "")], aggregation_type=self._text_aggregation_mode + ) self._api_key = api_key self._cartesia_version = cartesia_version @@ -639,7 +651,10 @@ class CartesiaTTSService(AudioContextTTSService): Yields: Frame: Audio frames containing the synthesized speech. 
""" - logger.debug(f"{self}: Generating TTS [{text}]") + if not self._is_streaming_tokens: + logger.debug(f"{self}: Generating TTS [{text}]") + else: + logger.trace(f"{self}: Generating TTS [{text}]") try: if not self._websocket or self._websocket.state is State.CLOSED: diff --git a/src/pipecat/services/elevenlabs/tts.py b/src/pipecat/services/elevenlabs/tts.py index c68d005f1..1811ed971 100644 --- a/src/pipecat/services/elevenlabs/tts.py +++ b/src/pipecat/services/elevenlabs/tts.py @@ -47,6 +47,7 @@ from pipecat.processors.frame_processor import FrameDirection from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven from pipecat.services.tts_service import ( AudioContextTTSService, + TextAggregationMode, TTSService, ) from pipecat.transcriptions.language import Language, resolve_language @@ -365,7 +366,8 @@ class ElevenLabsTTSService(AudioContextTTSService): url: str = "wss://api.elevenlabs.io", sample_rate: Optional[int] = None, params: Optional[InputParams] = None, - aggregate_sentences: Optional[bool] = True, + text_aggregation_mode: Optional[TextAggregationMode] = None, + aggregate_sentences: Optional[bool] = None, **kwargs, ): """Initialize the ElevenLabs TTS service. @@ -377,13 +379,20 @@ class ElevenLabsTTSService(AudioContextTTSService): url: WebSocket URL for ElevenLabs TTS API. sample_rate: Audio sample rate. If None, uses default. params: Additional input parameters for voice customization. + text_aggregation_mode: How to aggregate incoming text before synthesis. aggregate_sentences: Whether to aggregate sentences within the TTSService. + + .. deprecated:: 0.0.104 + Use ``text_aggregation_mode`` instead. + **kwargs: Additional arguments passed to the parent service. """ - # Aggregating sentences still gives cleaner-sounding results and fewer - # artifacts than streaming one word at a time. 
On average, waiting for a - # full sentence should only "cost" us 15ms or so with GPT-4o or a Llama - # 3 model, and it's worth it for the better audio quality. + # By default, we aggregate sentences before sending to TTS. This adds + # ~200-300ms of latency per sentence (waiting for the sentence-ending + # punctuation token from the LLM). Setting + # text_aggregation_mode=TextAggregationMode.TOKEN streams tokens + # directly. To use this mode, you must set auto_mode=False. This + # eliminates aggregation time, but slows down ElevenLabs. # # We also don't want to automatically push LLM response text frames, # because the context aggregators will add them to the LLM context even @@ -397,6 +406,7 @@ class ElevenLabsTTSService(AudioContextTTSService): params = params or ElevenLabsTTSService.InputParams() super().__init__( + text_aggregation_mode=text_aggregation_mode, aggregate_sentences=aggregate_sentences, push_text_frames=False, push_stop_frames=True, @@ -893,7 +903,8 @@ class ElevenLabsHttpTTSService(TTSService): base_url: str = "https://api.elevenlabs.io", sample_rate: Optional[int] = None, params: Optional[InputParams] = None, - aggregate_sentences: Optional[bool] = True, + text_aggregation_mode: Optional[TextAggregationMode] = None, + aggregate_sentences: Optional[bool] = None, **kwargs, ): """Initialize the ElevenLabs HTTP TTS service. @@ -906,12 +917,18 @@ class ElevenLabsHttpTTSService(TTSService): base_url: Base URL for ElevenLabs HTTP API. sample_rate: Audio sample rate. If None, uses default. params: Additional input parameters for voice customization. + text_aggregation_mode: How to aggregate incoming text before synthesis. aggregate_sentences: Whether to aggregate sentences within the TTSService. + + .. deprecated:: 0.0.104 + Use ``text_aggregation_mode`` instead. + **kwargs: Additional arguments passed to the parent service. 
""" params = params or ElevenLabsHttpTTSService.InputParams() super().__init__( + text_aggregation_mode=text_aggregation_mode, aggregate_sentences=aggregate_sentences, push_text_frames=False, push_stop_frames=True, diff --git a/src/pipecat/services/inworld/tts.py b/src/pipecat/services/inworld/tts.py index 2fb86b4a6..d3f64c16f 100644 --- a/src/pipecat/services/inworld/tts.py +++ b/src/pipecat/services/inworld/tts.py @@ -51,7 +51,7 @@ from pipecat.frames.frames import ( TTSStoppedFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.tts_service import AudioContextTTSService, TTSService +from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService from pipecat.utils.tracing.service_decorators import traced_tts @@ -509,7 +509,8 @@ class InworldTTSService(AudioContextTTSService): sample_rate: Optional[int] = None, encoding: str = "LINEAR16", params: InputParams = None, - aggregate_sentences: bool = True, + aggregate_sentences: Optional[bool] = None, + text_aggregation_mode: Optional[TextAggregationMode] = None, append_trailing_space: bool = True, **kwargs: Any, ): @@ -523,7 +524,12 @@ class InworldTTSService(AudioContextTTSService): sample_rate: Audio sample rate in Hz. encoding: Audio encoding format. params: Input parameters for Inworld WebSocket TTS configuration. - aggregate_sentences: Whether to aggregate sentences before synthesis. + aggregate_sentences: Deprecated. Use text_aggregation_mode instead. + + .. deprecated:: 0.0.104 + Use ``text_aggregation_mode`` instead. + + text_aggregation_mode: How to aggregate text before synthesis. append_trailing_space: Whether to append a trailing space to text before sending to TTS. **kwargs: Additional arguments passed to the parent class. 
""" @@ -536,6 +542,7 @@ class InworldTTSService(AudioContextTTSService): supports_word_timestamps=True, sample_rate=sample_rate, aggregate_sentences=aggregate_sentences, + text_aggregation_mode=text_aggregation_mode, append_trailing_space=append_trailing_space, settings=InworldTTSSettings( model=model, diff --git a/src/pipecat/services/neuphonic/tts.py b/src/pipecat/services/neuphonic/tts.py index 81b366a8b..63411c3eb 100644 --- a/src/pipecat/services/neuphonic/tts.py +++ b/src/pipecat/services/neuphonic/tts.py @@ -36,7 +36,7 @@ from pipecat.frames.frames import ( ) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven -from pipecat.services.tts_service import InterruptibleTTSService, TTSService +from pipecat.services.tts_service import InterruptibleTTSService, TextAggregationMode, TTSService from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.tracing.service_decorators import traced_tts @@ -119,7 +119,8 @@ class NeuphonicTTSService(InterruptibleTTSService): sample_rate: Optional[int] = 22050, encoding: str = "pcm_linear", params: Optional[InputParams] = None, - aggregate_sentences: Optional[bool] = True, + aggregate_sentences: Optional[bool] = None, + text_aggregation_mode: Optional[TextAggregationMode] = None, **kwargs, ): """Initialize the Neuphonic TTS service. @@ -131,13 +132,19 @@ class NeuphonicTTSService(InterruptibleTTSService): sample_rate: Audio sample rate in Hz. Defaults to 22050. encoding: Audio encoding format. Defaults to "pcm_linear". params: Additional input parameters for TTS configuration. - aggregate_sentences: Whether to aggregate sentences within the TTSService. + aggregate_sentences: Deprecated. Use text_aggregation_mode instead. + + .. deprecated:: 0.0.104 + Use ``text_aggregation_mode`` instead. + + text_aggregation_mode: How to aggregate text before synthesis. 
**kwargs: Additional arguments passed to parent InterruptibleTTSService. """ params = params or NeuphonicTTSService.InputParams() super().__init__( aggregate_sentences=aggregate_sentences, + text_aggregation_mode=text_aggregation_mode, push_stop_frames=True, stop_frame_timeout_s=2.0, sample_rate=sample_rate, diff --git a/src/pipecat/services/rime/tts.py b/src/pipecat/services/rime/tts.py index 059db8178..d5f97e028 100644 --- a/src/pipecat/services/rime/tts.py +++ b/src/pipecat/services/rime/tts.py @@ -35,6 +35,7 @@ from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven from pipecat.services.tts_service import ( AudioContextTTSService, InterruptibleTTSService, + TextAggregationMode, TTSService, ) from pipecat.transcriptions.language import Language, resolve_language @@ -181,7 +182,8 @@ class RimeTTSService(AudioContextTTSService): sample_rate: Optional[int] = None, params: Optional[InputParams] = None, text_aggregator: Optional[BaseTextAggregator] = None, - aggregate_sentences: Optional[bool] = True, + text_aggregation_mode: Optional[TextAggregationMode] = None, + aggregate_sentences: Optional[bool] = None, **kwargs, ): """Initialize Rime TTS service. @@ -198,13 +200,19 @@ class RimeTTSService(AudioContextTTSService): .. deprecated:: 0.0.95 Use an LLMTextProcessor before the TTSService for custom text aggregation. - aggregate_sentences: Whether to aggregate sentences within the TTSService. + text_aggregation_mode: How to aggregate incoming text before synthesis. + aggregate_sentences: Deprecated. Use text_aggregation_mode instead. + + .. deprecated:: 0.0.104 + Use ``text_aggregation_mode`` instead. + **kwargs: Additional arguments passed to parent class. 
""" # Initialize with parent class settings for proper frame handling params = params or RimeTTSService.InputParams() super().__init__( + text_aggregation_mode=text_aggregation_mode, aggregate_sentences=aggregate_sentences, push_text_frames=False, push_stop_frames=True, @@ -243,7 +251,9 @@ class RimeTTSService(AudioContextTTSService): # The preferred way of taking advantage of Rime spelling is # to use an LLMTextProcessor and/or a text_transformer to identify # and insert these tags for the purpose of the TTS service alone. - self._text_aggregator = SkipTagsAggregator([("spell(", ")")]) + self._text_aggregator = SkipTagsAggregator( + [("spell(", ")")], aggregation_type=self._text_aggregation_mode + ) # Store service configuration self._api_key = api_key @@ -826,7 +836,8 @@ class RimeNonJsonTTSService(InterruptibleTTSService): audio_format: str = "pcm", sample_rate: Optional[int] = None, params: Optional[InputParams] = None, - aggregate_sentences: Optional[bool] = True, + aggregate_sentences: Optional[bool] = None, + text_aggregation_mode: Optional[TextAggregationMode] = None, **kwargs, ): """Initialize Rime Non-JSON WebSocket TTS service. @@ -839,13 +850,21 @@ class RimeNonJsonTTSService(InterruptibleTTSService): audio_format: Audio format to use. sample_rate: Audio sample rate in Hz. params: Additional configuration parameters. - aggregate_sentences: Whether to aggregate sentences within the TTSService. + aggregate_sentences: Deprecated. Use text_aggregation_mode instead. + + .. deprecated:: 0.0.104 + Use ``text_aggregation_mode`` instead. Set to ``TextAggregationMode.SENTENCE`` + to aggregate text into sentences before synthesis, or + ``TextAggregationMode.TOKEN`` to stream tokens directly for lower latency. + + text_aggregation_mode: How to aggregate text before synthesis. **kwargs: Additional arguments passed to parent class. 
""" params = params or RimeNonJsonTTSService.InputParams() super().__init__( sample_rate=sample_rate, aggregate_sentences=aggregate_sentences, + text_aggregation_mode=text_aggregation_mode, push_stop_frames=True, pause_frame_processing=True, append_trailing_space=True, diff --git a/src/pipecat/services/sarvam/tts.py b/src/pipecat/services/sarvam/tts.py index 7b63828a1..87604a9f9 100644 --- a/src/pipecat/services/sarvam/tts.py +++ b/src/pipecat/services/sarvam/tts.py @@ -63,7 +63,7 @@ from pipecat.frames.frames import ( from pipecat.processors.frame_processor import FrameDirection from pipecat.services.sarvam._sdk import sdk_headers from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven -from pipecat.services.tts_service import InterruptibleTTSService, TTSService +from pipecat.services.tts_service import InterruptibleTTSService, TextAggregationMode, TTSService from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.tracing.service_decorators import traced_tts @@ -785,7 +785,8 @@ class SarvamTTSService(InterruptibleTTSService): model: str = "bulbul:v2", voice_id: Optional[str] = None, url: str = "wss://api.sarvam.ai/text-to-speech/ws", - aggregate_sentences: Optional[bool] = True, + aggregate_sentences: Optional[bool] = None, + text_aggregation_mode: Optional[TextAggregationMode] = None, sample_rate: Optional[int] = None, params: Optional[InputParams] = None, **kwargs, @@ -799,7 +800,12 @@ class SarvamTTSService(InterruptibleTTSService): - "bulbul:v3-beta": Advanced model with temperature control voice_id: Speaker voice ID. If None, uses model-appropriate default. url: WebSocket URL for the TTS backend (default production URL). - aggregate_sentences: Merge multiple sentences into one audio chunk (default True). + aggregate_sentences: Deprecated. Use text_aggregation_mode instead. + + .. deprecated:: 0.0.104 + Use ``text_aggregation_mode`` instead. + + text_aggregation_mode: How to aggregate text before synthesis. 
sample_rate: Output audio sample rate in Hz (8000, 16000, 22050, 24000). If None, uses model-specific default. params: Optional input parameters to override defaults. @@ -834,6 +840,7 @@ class SarvamTTSService(InterruptibleTTSService): # Initialize parent class first super().__init__( aggregate_sentences=aggregate_sentences, + text_aggregation_mode=text_aggregation_mode, push_text_frames=True, pause_frame_processing=True, push_stop_frames=True, diff --git a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index e36d4754f..c6d2672d6 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -11,6 +11,7 @@ import uuid import warnings from abc import abstractmethod from dataclasses import dataclass +from enum import Enum from typing import ( Any, AsyncGenerator, @@ -72,6 +73,23 @@ class TTSContext: append_to_context: bool = True +class TextAggregationMode(str, Enum): + """Controls how incoming text is aggregated before TTS synthesis. + + Parameters: + SENTENCE: Buffer text until sentence boundaries are detected before synthesis. + Produces more natural speech but adds latency (~200-300ms per sentence). + TOKEN: Stream text tokens directly to TTS as they arrive. + Reduces latency but may affect speech quality depending on the TTS provider. + """ + + SENTENCE = "sentence" + TOKEN = "token" + + def __str__(self): + return self.value + + class TTSService(AIService): """Base class for text-to-speech services. @@ -109,7 +127,8 @@ class TTSService(AIService): def __init__( self, *, - aggregate_sentences: bool = True, + text_aggregation_mode: Optional[TextAggregationMode] = None, + aggregate_sentences: Optional[bool] = None, # if True, TTSService will push TextFrames and LLMFullResponseEndFrames, # otherwise subclass must do it push_text_frames: bool = True, @@ -153,7 +172,16 @@ class TTSService(AIService): """Initialize the TTS service. Args: + text_aggregation_mode: How to aggregate incoming text before synthesis. 
+ TextAggregationMode.SENTENCE (default) buffers until sentence boundaries, + TextAggregationMode.TOKEN streams tokens directly for lower latency. aggregate_sentences: Whether to aggregate text into sentences before synthesis. + + .. deprecated:: 0.0.104 + Use ``text_aggregation_mode`` instead. Set to ``TextAggregationMode.SENTENCE`` + to aggregate text into sentences before synthesis, or + ``TextAggregationMode.TOKEN`` to stream tokens directly for lower latency. + push_text_frames: Whether to push TextFrames and LLMFullResponseEndFrames. push_stop_frames: Whether to automatically push TTSStoppedFrames. stop_frame_timeout_s: Idle time before pushing TTSStoppedFrame when push_stop_frames is True. @@ -194,7 +222,31 @@ class TTSService(AIService): or TTSSettings(), **kwargs, ) - self._aggregate_sentences: bool = aggregate_sentences + + # Resolve text_aggregation_mode from the new param or deprecated aggregate_sentences + if aggregate_sentences is not None: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "Parameter 'aggregate_sentences' is deprecated. 
" + "Use 'text_aggregation_mode=TextAggregationMode.SENTENCE' or " + "'text_aggregation_mode=TextAggregationMode.TOKEN' instead.", + DeprecationWarning, + stacklevel=2, + ) + if text_aggregation_mode is None: + text_aggregation_mode = ( + TextAggregationMode.SENTENCE + if aggregate_sentences + else TextAggregationMode.TOKEN + ) + + if text_aggregation_mode is None: + text_aggregation_mode = TextAggregationMode.SENTENCE + + self._text_aggregation_mode: TextAggregationMode = text_aggregation_mode self._push_text_frames: bool = push_text_frames self._push_stop_frames: bool = push_stop_frames self._stop_frame_timeout_s: float = stop_frame_timeout_s @@ -204,7 +256,9 @@ class TTSService(AIService): self._append_trailing_space: bool = append_trailing_space self._init_sample_rate = sample_rate self._sample_rate = 0 - self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator() + self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator( + aggregation_type=self._text_aggregation_mode + ) if text_aggregator: import warnings @@ -240,6 +294,8 @@ class TTSService(AIService): self._processing_text: bool = False self._tts_contexts: Dict[str, TTSContext] = {} + self._streamed_text: str = "" + self._text_aggregation_metrics_started: bool = False # Word timestamp state (active when supports_word_timestamps=True) self._supports_word_timestamps: bool = supports_word_timestamps @@ -253,6 +309,40 @@ class TTSService(AIService): self._register_event_handler("on_connection_error") self._register_event_handler("on_tts_request") + @property + def _is_streaming_tokens(self) -> bool: + """Whether the service is streaming tokens directly without sentence aggregation.""" + return self._text_aggregation_mode == TextAggregationMode.TOKEN + + async def start_tts_usage_metrics(self, text: str): + """Record TTS usage metrics. 
+ + When streaming tokens, usage metrics are aggregated and reported at + flush time instead of per token, so individual calls are skipped. + + Args: + text: The text being processed by TTS. + """ + if self._is_streaming_tokens: + return + await super().start_tts_usage_metrics(text) + + async def start_text_aggregation_metrics(self): + """Start text aggregation metrics if not already started. + + Only starts the metric once per LLM response. Skipped when streaming + tokens since per-token aggregation time is not meaningful. + """ + if self._is_streaming_tokens or self._text_aggregation_metrics_started: + return + self._text_aggregation_metrics_started = True + await super().start_text_aggregation_metrics() + + async def stop_text_aggregation_metrics(self): + """Stop text aggregation metrics and reset the started flag.""" + self._text_aggregation_metrics_started = False + await super().stop_text_aggregation_metrics() + @property def sample_rate(self) -> int: """Get the current sample rate for audio output. @@ -511,6 +601,7 @@ class TTSService(AIService): and not isinstance(frame, InterimTranscriptionFrame) and not isinstance(frame, TranscriptionFrame) ): + await self.start_text_aggregation_metrics() await self._process_text_frame(frame) elif isinstance(frame, InterruptionFrame): await self._handle_interruption(frame, direction) @@ -526,9 +617,17 @@ class TTSService(AIService): # Flush any remaining text (including text waiting for lookahead) remaining = await self._text_aggregator.flush() + # Stop the aggregation metric (no-op if already stopped on first sentence). + await self.stop_text_aggregation_metrics() if remaining: await self._push_tts_frames(AggregatedTextFrame(remaining.text, remaining.type)) + # Log accumulated streamed text and emit aggregated usage metric. 
+ if self._streamed_text: + logger.debug(f"{self}: Generating TTS [{self._streamed_text}]") + await super().start_tts_usage_metrics(self._streamed_text) + self._streamed_text = "" + # Reset aggregator state self._processing_text = False if isinstance(frame, LLMFullResponseEndFrame): @@ -678,6 +777,8 @@ class TTSService(AIService): await filter.handle_interruption() self._llm_response_started = False + self._streamed_text = "" + self._text_aggregation_metrics_started = False if self._supports_word_timestamps: await self.reset_word_timestamps() @@ -690,26 +791,18 @@ class TTSService(AIService): await self.resume_processing_frames() async def _process_text_frame(self, frame: TextFrame): - text: Optional[str] = None - includes_inter_frame_spaces: bool = False - if not self._aggregate_sentences: - text = frame.text - includes_inter_frame_spaces = frame.includes_inter_frame_spaces - aggregated_by = "token" - - if text: - logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}") - await self._push_tts_frames( - AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces - ) - else: - async for aggregate in self._text_aggregator.aggregate(frame.text): - text = aggregate.text - aggregated_by = aggregate.type - logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}") - await self._push_tts_frames( - AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces - ) + async for aggregate in self._text_aggregator.aggregate(frame.text): + includes_inter_frame_spaces = ( + frame.includes_inter_frame_spaces + if aggregate.type == AggregationType.TOKEN + else False + ) + if aggregate.type != AggregationType.TOKEN: + # Stop the aggregation metric on the first sentence only. 
+ await self.stop_text_aggregation_metrics() + await self._push_tts_frames( + AggregatedTextFrame(aggregate.text, aggregate.type), includes_inter_frame_spaces + ) async def _push_tts_frames( self, @@ -739,7 +832,15 @@ class TTSService(AIService): # or when we received an LLMFullResponseEndFrame self._processing_text = True - await self.start_processing_metrics() + # Accumulate text for a single debug log at flush time when streaming tokens. + if self._is_streaming_tokens: + self._streamed_text += text + + # Skip per-token processing metrics when streaming. The per-token + # processing time is just websocket send overhead (~0.1ms) and not + # meaningful. TTFB captures the important timing for streaming TTS. + if not self._is_streaming_tokens: + await self.start_processing_metrics() # Process all filters. for filter in self._text_filters: @@ -747,7 +848,8 @@ class TTSService(AIService): text = await filter.filter(text) if not text.strip(): - await self.stop_processing_metrics() + if not self._is_streaming_tokens: + await self.stop_processing_metrics() return # Create context ID and store metadata @@ -785,7 +887,8 @@ class TTSService(AIService): await self.process_generator(self.run_tts(prepared_text, context_id)) - await self.stop_processing_metrics() + if not self._is_streaming_tokens: + await self.stop_processing_metrics() if self._push_text_frames: # In TTS services that support word timestamps, the TTSTextFrames diff --git a/src/pipecat/utils/text/base_text_aggregator.py b/src/pipecat/utils/text/base_text_aggregator.py index 13691d9cd..2b050fcb7 100644 --- a/src/pipecat/utils/text/base_text_aggregator.py +++ b/src/pipecat/utils/text/base_text_aggregator.py @@ -21,6 +21,7 @@ class AggregationType(str, Enum): """Built-in aggregation strings.""" SENTENCE = "sentence" + TOKEN = "token" WORD = "word" def __str__(self): @@ -66,6 +67,25 @@ class BaseTextAggregator(ABC): logic, text manipulation behavior, and state management for interruptions. 
""" + def __init__(self, *, aggregation_type: AggregationType = AggregationType.SENTENCE): + """Initialize the base text aggregator. + + Args: + aggregation_type: The aggregation strategy to use. SENTENCE buffers + text until sentence boundaries are detected, TOKEN passes text + through immediately, and WORD buffers until word boundaries. + """ + self._aggregation_type = AggregationType(aggregation_type) + + @property + def aggregation_type(self) -> AggregationType: + """Get the aggregation type for this aggregator. + + Returns: + The aggregation type. + """ + return self._aggregation_type + @property @abstractmethod def text(self) -> Aggregation: diff --git a/src/pipecat/utils/text/pattern_pair_aggregator.py b/src/pipecat/utils/text/pattern_pair_aggregator.py index bfaf9291b..835bb8591 100644 --- a/src/pipecat/utils/text/pattern_pair_aggregator.py +++ b/src/pipecat/utils/text/pattern_pair_aggregator.py @@ -96,8 +96,11 @@ class PatternPairAggregator(SimpleTextAggregator): Creates an empty aggregator with no patterns or handlers registered. Text buffering and pattern detection will begin when text is aggregated. + + Args: + **kwargs: Additional arguments passed to SimpleTextAggregator (e.g. aggregation_type). """ - super().__init__() + super().__init__(**kwargs) self._patterns = {} self._handlers = {} self._last_processed_position = 0 # Track where we last checked for complete patterns @@ -146,7 +149,7 @@ class PatternPairAggregator(SimpleTextAggregator): Returns: Self for method chaining. """ - if type in [AggregationType.SENTENCE, AggregationType.WORD]: + if type in [AggregationType.SENTENCE, AggregationType.WORD, AggregationType.TOKEN]: raise ValueError( f"The aggregation type '{type}' is reserved for default behavior and can not be used for custom patterns." ) @@ -321,6 +324,9 @@ class PatternPairAggregator(SimpleTextAggregator): and uses the parent's lookahead logic for sentence detection when no patterns are active. 
+ In TOKEN mode, pattern detection still works but non-pattern text is + yielded as TOKEN aggregations instead of waiting for sentence boundaries. + Args: text: Text to aggregate. @@ -370,18 +376,35 @@ class PatternPairAggregator(SimpleTextAggregator): # boundaries when a pattern begins (e.g., "Here is code ..." yields "Here is code") result = self._text[: pattern_start[0]] self._text = self._text[pattern_start[0] :] - yield PatternMatch( - content=result.strip(), type=AggregationType.SENTENCE, full_match=result + agg_type = ( + AggregationType.TOKEN + if self._aggregation_type == AggregationType.TOKEN + else AggregationType.SENTENCE ) + yield PatternMatch(content=result.strip(), type=agg_type, full_match=result) continue - # Use parent's lookahead logic for sentence detection - aggregation = await super()._check_sentence_with_lookahead(char) - if aggregation: - # Convert to PatternMatch for consistency with return type + if self._aggregation_type != AggregationType.TOKEN: + # Use parent's lookahead logic for sentence detection + aggregation = await super()._check_sentence_with_lookahead(char) + if aggregation: + # Convert to PatternMatch for consistency with return type + yield PatternMatch( + content=aggregation.text, + type=aggregation.type, + full_match=aggregation.text, + ) + + # In TOKEN mode, yield any accumulated text after processing all chars, + # but only if there's no incomplete pattern being buffered. + if self._aggregation_type == AggregationType.TOKEN and self._text: + if self._match_start_of_pattern(self._text) is None: yield PatternMatch( - content=aggregation.text, type=aggregation.type, full_match=aggregation.text + content=self._text, + type=AggregationType.TOKEN, + full_match=self._text, ) + self._text = "" async def handle_interruption(self): """Handle interruptions by clearing the buffer and pattern state. 
diff --git a/src/pipecat/utils/text/simple_text_aggregator.py b/src/pipecat/utils/text/simple_text_aggregator.py index b0cc698a9..b5b179fcf 100644 --- a/src/pipecat/utils/text/simple_text_aggregator.py +++ b/src/pipecat/utils/text/simple_text_aggregator.py @@ -25,11 +25,15 @@ class SimpleTextAggregator(BaseTextAggregator): most straightforward implementation of text aggregation for TTS processing. """ - def __init__(self): + def __init__(self, **kwargs): """Initialize the simple text aggregator. Creates an empty text buffer ready to begin accumulating text tokens. + + Args: + **kwargs: Additional arguments passed to BaseTextAggregator (e.g. aggregation_type). """ + super().__init__(**kwargs) self._text = "" self._needs_lookahead: bool = False @@ -43,19 +47,25 @@ class SimpleTextAggregator(BaseTextAggregator): return Aggregation(text=self._text.strip(" "), type=AggregationType.SENTENCE) async def aggregate(self, text: str) -> AsyncIterator[Aggregation]: - """Aggregate text and yield completed sentences. + """Aggregate text and yield completed aggregations. - Processes the input text character-by-character. When sentence-ending - punctuation is detected, it waits for non-whitespace lookahead before - calling NLTK. This prevents false positives like "$29." being detected - as a sentence when it's actually "$29.95". + In SENTENCE mode, processes the input text character-by-character. When + sentence-ending punctuation is detected, it waits for non-whitespace + lookahead before calling NLTK. + + In TOKEN mode, yields the text immediately without buffering. Args: text: Text to aggregate. Yields: - Complete sentences as Aggregation objects. + Aggregation objects (sentences in SENTENCE mode, tokens in TOKEN mode). 
""" + if self._aggregation_type == AggregationType.TOKEN: + if text: + yield Aggregation(text=text, type=AggregationType.TOKEN) + return + # Process text character by character for char in text: self._text += char @@ -114,11 +124,15 @@ class SimpleTextAggregator(BaseTextAggregator): """Flush any remaining text in the buffer. Returns any text remaining in the buffer. This is called at the end - of a stream to ensure all text is processed. + of a stream to ensure all text is processed. In TOKEN mode, returns + None since tokens are yielded immediately. Returns: - Any remaining text as a sentence, or None if buffer is empty. + Any remaining text as a sentence, or None if buffer is empty or in TOKEN mode. """ + if self._aggregation_type == AggregationType.TOKEN: + return None + if self._text: # Return whatever we have in the buffer result = self._text diff --git a/src/pipecat/utils/text/skip_tags_aggregator.py b/src/pipecat/utils/text/skip_tags_aggregator.py index 4232efd7d..1b6a7f156 100644 --- a/src/pipecat/utils/text/skip_tags_aggregator.py +++ b/src/pipecat/utils/text/skip_tags_aggregator.py @@ -14,7 +14,7 @@ as a unit regardless of internal punctuation. from typing import AsyncIterator, Optional, Sequence from pipecat.utils.string import StartEndTags, parse_start_end_tags -from pipecat.utils.text.base_text_aggregator import Aggregation +from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator @@ -31,14 +31,15 @@ class SkipTagsAggregator(SimpleTextAggregator): identified and that content within tags is never split at sentence boundaries. """ - def __init__(self, tags: Sequence[StartEndTags]): + def __init__(self, tags: Sequence[StartEndTags], **kwargs): """Initialize the skip tags aggregator. Args: tags: Sequence of StartEndTags objects defining the tag pairs that should prevent sentence boundary detection. 
+ **kwargs: Additional arguments passed to SimpleTextAggregator (e.g. aggregation_type). """ - super().__init__() + super().__init__(**kwargs) self._tags = tags self._current_tag: Optional[StartEndTags] = None self._current_tag_index: int = 0 @@ -50,13 +51,33 @@ class SkipTagsAggregator(SimpleTextAggregator): uses the parent's lookahead logic for sentence detection when not inside tags. + In TOKEN mode, text is passed through immediately unless we're inside + a tag, in which case we buffer until the closing tag is found. + Args: text: Text to aggregate. Yields: Aggregation objects containing text up to a sentence boundary, - marked as SENTENCE type. + marked as SENTENCE type (or TOKEN type in TOKEN mode). """ + if self._aggregation_type == AggregationType.TOKEN: + # In TOKEN mode, process chars for tag tracking but yield the + # full input as a single token when not inside a tag. + for char in text: + self._text += char + + # Update tag state + (self._current_tag, self._current_tag_index) = parse_start_end_tags( + self._text, self._tags, self._current_tag, self._current_tag_index + ) + + # After processing all chars: if not inside a tag, yield accumulated text + if not self._current_tag and self._text: + yield Aggregation(text=self._text, type=AggregationType.TOKEN) + self._text = "" + return + # Process text character by character for char in text: self._text += char diff --git a/tests/test_pattern_pair_aggregator.py b/tests/test_pattern_pair_aggregator.py index bcc8d18f7..6c9e23552 100644 --- a/tests/test_pattern_pair_aggregator.py +++ b/tests/test_pattern_pair_aggregator.py @@ -194,5 +194,66 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase): self.assertEqual(self.aggregator.text.text, "") +class TestPatternPairAggregatorTokenMode(unittest.IsolatedAsyncioTestCase): + def setUp(self): + from pipecat.utils.text.base_text_aggregator import AggregationType + + self.aggregator = PatternPairAggregator(aggregation_type=AggregationType.TOKEN) + 
self.handler = AsyncMock() + self.aggregator.add_pattern( + type="think", + start_pattern="", + end_pattern="", + action=MatchAction.REMOVE, + ) + self.aggregator.on_pattern_match("think", self.handler) + + async def test_token_no_patterns(self): + """Non-pattern text passes through as TOKEN, one per aggregate call.""" + results = [] + for token in ["Hello", " world", "."]: + async for r in self.aggregator.aggregate(token): + results.append(r) + + self.assertEqual(len(results), 3) + self.assertEqual(results[0].text, "Hello") + self.assertEqual(results[1].text, " world") + self.assertEqual(results[2].text, ".") + for r in results: + self.assertEqual(r.type, "token") + + async def test_token_pattern_detection(self): + """Pattern detection still works with word-by-word token delivery.""" + results = [] + for token in ["Hi ", "", "secret", "", " bye"]: + async for r in self.aggregator.aggregate(token): + results.append(r) + + # Handler called once when the pattern completes + self.handler.assert_called_once() + call_args = self.handler.call_args[0][0] + self.assertEqual(call_args.text, "secret") + + # "Hi " yields before pattern starts, pattern is removed, " bye" yields after + self.assertEqual(len(results), 2) + self.assertEqual(results[0].text, "Hi ") + self.assertEqual(results[0].type, "token") + self.assertEqual(results[1].text, " bye") + self.assertEqual(results[1].type, "token") + + async def test_token_incomplete_pattern_buffers(self): + """Incomplete pattern is buffered across calls, not leaked to output.""" + results = [] + for token in ["Hi ", "", "partial"]: + async for r in self.aggregator.aggregate(token): + results.append(r) + + # Only "Hi " should be yielded; "partial" stays buffered + self.assertEqual(len(results), 1) + self.assertEqual(results[0].text, "Hi ") + self.assertEqual(results[0].type, "token") + self.handler.assert_not_called() + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_simple_text_aggregator.py 
b/tests/test_simple_text_aggregator.py index 4b3613e27..46c77df42 100644 --- a/tests/test_simple_text_aggregator.py +++ b/tests/test_simple_text_aggregator.py @@ -181,5 +181,39 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase): assert result.text == "こんにちは。" +class TestSimpleTextAggregatorTokenMode(unittest.IsolatedAsyncioTestCase): + def setUp(self): + from pipecat.utils.text.base_text_aggregator import AggregationType + + self.aggregator = SimpleTextAggregator(aggregation_type=AggregationType.TOKEN) + + async def test_token_passthrough(self): + """TOKEN mode yields text immediately without buffering.""" + results = [agg async for agg in self.aggregator.aggregate("Hello")] + assert len(results) == 1 + assert results[0].text == "Hello" + assert results[0].type == "token" + + async def test_token_multiple_calls(self): + """Each aggregate call yields its text independently.""" + r1 = [agg async for agg in self.aggregator.aggregate("Hello ")] + r2 = [agg async for agg in self.aggregator.aggregate("world.")] + assert len(r1) == 1 + assert r1[0].text == "Hello " + assert len(r2) == 1 + assert r2[0].text == "world." 
+ + async def test_token_empty_text(self): + """Empty text yields nothing.""" + results = [agg async for agg in self.aggregator.aggregate("")] + assert len(results) == 0 + + async def test_token_flush_returns_none(self): + """Flush returns None in TOKEN mode since nothing is buffered.""" + await self.aggregator.aggregate("Hello").__anext__() + result = await self.aggregator.flush() + assert result is None + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_skip_tags_aggregator.py b/tests/test_skip_tags_aggregator.py index c7fea22c3..882b26e82 100644 --- a/tests/test_skip_tags_aggregator.py +++ b/tests/test_skip_tags_aggregator.py @@ -64,5 +64,60 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase): self.assertEqual(self.aggregator.text.type, "sentence") +class TestSkipTagsAggregatorTokenMode(unittest.IsolatedAsyncioTestCase): + def setUp(self): + from pipecat.utils.text.base_text_aggregator import AggregationType + + self.aggregator = SkipTagsAggregator( + [("", "")], aggregation_type=AggregationType.TOKEN + ) + + async def test_token_no_tags(self): + """No tags: text passes through immediately as TOKEN.""" + results = [agg async for agg in self.aggregator.aggregate("Hello!")] + self.assertEqual(len(results), 1) + self.assertEqual(results[0].text, "Hello!") + self.assertEqual(results[0].type, "token") + + async def test_token_inside_tag_buffers(self): + """Inside a tag, text is buffered until the closing tag is found.""" + results = [agg async for agg in self.aggregator.aggregate("foo@bar")] + # Still inside tag, nothing yielded + self.assertEqual(len(results), 0) + + # Close the tag + results = [agg async for agg in self.aggregator.aggregate("")] + self.assertEqual(len(results), 1) + self.assertEqual(results[0].text, "foo@bar") + self.assertEqual(results[0].type, "token") + + async def test_token_flush_unclosed_tag(self): + """Flush with an unclosed tag returns None in TOKEN mode; buffered text is not returned.""" + async for _ in self.aggregator.aggregate("unclosed"): 
+ pass + result = await self.aggregator.flush() + # TOKEN mode flush returns None (parent behavior) + self.assertIsNone(result) + + async def test_token_text_around_tags(self): + """Simulate word-by-word token delivery with tags.""" + results = [] + # Simulate LLM streaming tokens one at a time + for token in ["Hi ", "", "X", "", " bye"]: + async for agg in self.aggregator.aggregate(token): + results.append(agg) + + self.assertEqual(len(results), 3) + # Text before tag passes through immediately + self.assertEqual(results[0].text, "Hi ") + self.assertEqual(results[0].type, "token") + # Tagged content is buffered until the closing tag, then yielded whole + self.assertEqual(results[1].text, "X") + self.assertEqual(results[1].type, "token") + # Text after tag passes through immediately + self.assertEqual(results[2].text, " bye") + self.assertEqual(results[2].type, "token") + + if __name__ == "__main__": unittest.main()