Add text_aggregation_mode parameter to TTSService
Move the sentence vs token aggregation concern into text aggregators so all text flows through them regardless of mode. This enables pattern detection and tag handling to work in TOKEN mode. - Add TextAggregationMode enum (SENTENCE, TOKEN) as the user-facing TTS setting, separate from the internal AggregationType - Add TOKEN mode support to Simple, SkipTags, and PatternPair aggregators - Add text_aggregation_mode parameter to TTSService and all TTS subclasses - Deprecate aggregate_sentences in favor of text_aggregation_mode - Merge TTSService._process_text_frame() into a single codepath
This commit is contained in:
1
changelog/3696.changed.md
Normal file
1
changelog/3696.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `text_aggregation_mode` parameter to `TTSService` and all TTS subclasses with a new `TextAggregationMode` enum (`SENTENCE`, `TOKEN`). All text now flows through text aggregators regardless of mode, enabling pattern detection and tag handling in TOKEN mode.
|
||||
1
changelog/3696.deprecated.md
Normal file
1
changelog/3696.deprecated.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Deprecated `aggregate_sentences` parameter on `TTSService` and all TTS subclasses. Use `text_aggregation_mode=TextAggregationMode.SENTENCE` or `text_aggregation_mode=TextAggregationMode.TOKEN` instead.
|
||||
@@ -24,6 +24,7 @@ from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.tts_service import TextAggregationMode
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -56,6 +57,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
text_aggregation_mode=TextAggregationMode.TOKEN,
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
@@ -14,7 +14,6 @@ and LLM processing.
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -36,6 +35,7 @@ from pipecat.audio.turn.base_turn_analyzer import BaseTurnParams
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.metrics.metrics import MetricsData
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.text.base_text_aggregator import AggregationType
|
||||
from pipecat.utils.time import nanoseconds_to_str
|
||||
from pipecat.utils.utils import obj_count, obj_id
|
||||
|
||||
@@ -393,16 +393,6 @@ class LLMTextFrame(TextFrame):
|
||||
self.includes_inter_frame_spaces = True
|
||||
|
||||
|
||||
class AggregationType(str, Enum):
|
||||
"""Built-in aggregation strings."""
|
||||
|
||||
SENTENCE = "sentence"
|
||||
WORD = "word"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class AggregatedTextFrame(TextFrame):
|
||||
"""Text frame representing an aggregation of TextFrames.
|
||||
|
||||
@@ -28,7 +28,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TTSService
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -128,7 +128,8 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
encoding: str = "pcm_s16le",
|
||||
container: str = "raw",
|
||||
params: Optional[InputParams] = None,
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Async TTS service.
|
||||
@@ -144,13 +145,19 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
encoding: Audio encoding format.
|
||||
container: Audio container format.
|
||||
params: Additional input parameters for voice customization.
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
aggregate_sentences: Deprecated. Use text_aggregation_mode instead.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
text_aggregation_mode: How to aggregate text before synthesis.
|
||||
**kwargs: Additional arguments passed to the parent service.
|
||||
"""
|
||||
params = params or AsyncAITTSService.InputParams()
|
||||
|
||||
super().__init__(
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
pause_frame_processing=True,
|
||||
push_stop_frames=True,
|
||||
sample_rate=sample_rate,
|
||||
|
||||
@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.azure.common import language_to_azure_language
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.services.tts_service import TextAggregationMode, TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -256,7 +256,8 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
|
||||
voice: str = "en-US-SaraNeural",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[AzureBaseTTSService.InputParams] = None,
|
||||
aggregate_sentences: bool = True,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Azure streaming TTS service.
|
||||
@@ -267,13 +268,19 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
|
||||
voice: Voice name to use for synthesis. Defaults to "en-US-SaraNeural".
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
params: Voice and synthesis parameters configuration.
|
||||
aggregate_sentences: Whether to aggregate sentences before synthesis.
|
||||
**kwargs: Additional arguments passed to the parent TTSService.
|
||||
aggregate_sentences: Deprecated. Use text_aggregation_mode instead.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
text_aggregation_mode: How to aggregate text before synthesis.
|
||||
**kwargs: Additional arguments passed to parent WordTTSService.
|
||||
"""
|
||||
params = params or AzureBaseTTSService.InputParams()
|
||||
|
||||
super().__init__(
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
push_text_frames=False, # We'll push text frames based on word timestamps
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
|
||||
@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TTSService
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator
|
||||
@@ -272,7 +272,8 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
container: str = "raw",
|
||||
params: Optional[InputParams] = None,
|
||||
text_aggregator: Optional[BaseTextAggregator] = None,
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Cartesia TTS service.
|
||||
@@ -292,13 +293,19 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
text_aggregation_mode: How to aggregate incoming text before synthesis.
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
**kwargs: Additional arguments passed to the parent service.
|
||||
"""
|
||||
# Aggregating sentences still gives cleaner-sounding results and fewer
|
||||
# artifacts than streaming one word at a time. On average, waiting for a
|
||||
# full sentence should only "cost" us 15ms or so with GPT-4o or a Llama
|
||||
# 3 model, and it's worth it for the better audio quality.
|
||||
# By default, we aggregate sentences before sending to TTS. This adds
|
||||
# ~200-300ms of latency per sentence (waiting for the sentence-ending
|
||||
# punctuation token from the LLM). Setting aggregate_sentences=False
|
||||
# streams tokens directly, which reduces latency. Streaming quality
|
||||
# is good but less tested than sentence aggregation.
|
||||
#
|
||||
# We also don't want to automatically push LLM response text frames,
|
||||
# because the context aggregators will add them to the LLM context even
|
||||
@@ -308,6 +315,7 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
params = params or CartesiaTTSService.InputParams()
|
||||
|
||||
super().__init__(
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
push_text_frames=False,
|
||||
pause_frame_processing=True,
|
||||
@@ -337,7 +345,9 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
# The preferred way of taking advantage of Cartesia SSML Tags is
|
||||
# to use an LLMTextProcessor and/or a text_transformer to identify
|
||||
# and insert these tags for the purpose of the TTS service alone.
|
||||
self._text_aggregator = SkipTagsAggregator([("<spell>", "</spell>")])
|
||||
self._text_aggregator = SkipTagsAggregator(
|
||||
[("<spell>", "</spell>")], aggregation_type=self._text_aggregation_mode
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._cartesia_version = cartesia_version
|
||||
@@ -639,7 +649,10 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
Yields:
|
||||
Frame: Audio frames containing the synthesized speech.
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
if not self._is_streaming_tokens:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
else:
|
||||
logger.trace(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
@@ -654,7 +667,9 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
|
||||
try:
|
||||
await self._get_websocket().send(msg)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
# Usage metrics are aggregated at flush time when streaming tokens.
|
||||
if not self._is_streaming_tokens:
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
@@ -47,6 +47,7 @@ from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import (
|
||||
AudioContextTTSService,
|
||||
TextAggregationMode,
|
||||
TTSService,
|
||||
)
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
@@ -365,7 +366,8 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
url: str = "wss://api.elevenlabs.io",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the ElevenLabs TTS service.
|
||||
@@ -377,13 +379,19 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
url: WebSocket URL for ElevenLabs TTS API.
|
||||
sample_rate: Audio sample rate. If None, uses default.
|
||||
params: Additional input parameters for voice customization.
|
||||
text_aggregation_mode: How to aggregate incoming text before synthesis.
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
**kwargs: Additional arguments passed to the parent service.
|
||||
"""
|
||||
# Aggregating sentences still gives cleaner-sounding results and fewer
|
||||
# artifacts than streaming one word at a time. On average, waiting for a
|
||||
# full sentence should only "cost" us 15ms or so with GPT-4o or a Llama
|
||||
# 3 model, and it's worth it for the better audio quality.
|
||||
# By default, we aggregate sentences before sending to TTS. This adds
|
||||
# ~200-300ms of latency per sentence (waiting for the sentence-ending
|
||||
# punctuation token from the LLM). Setting aggregate_sentences=False
|
||||
# streams tokens directly. To use this mode, you must set auto_mode=False.
|
||||
# This eliminates aggregation time, but slows down ElevenLabs.
|
||||
#
|
||||
# We also don't want to automatically push LLM response text frames,
|
||||
# because the context aggregators will add them to the LLM context even
|
||||
@@ -397,6 +405,7 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
params = params or ElevenLabsTTSService.InputParams()
|
||||
|
||||
super().__init__(
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
@@ -893,7 +902,8 @@ class ElevenLabsHttpTTSService(TTSService):
|
||||
base_url: str = "https://api.elevenlabs.io",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the ElevenLabs HTTP TTS service.
|
||||
@@ -906,12 +916,18 @@ class ElevenLabsHttpTTSService(TTSService):
|
||||
base_url: Base URL for ElevenLabs HTTP API.
|
||||
sample_rate: Audio sample rate. If None, uses default.
|
||||
params: Additional input parameters for voice customization.
|
||||
text_aggregation_mode: How to aggregate incoming text before synthesis.
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
**kwargs: Additional arguments passed to the parent service.
|
||||
"""
|
||||
params = params or ElevenLabsHttpTTSService.InputParams()
|
||||
|
||||
super().__init__(
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
|
||||
@@ -51,7 +51,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TTSService
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
|
||||
@@ -509,7 +509,8 @@ class InworldTTSService(AudioContextTTSService):
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "LINEAR16",
|
||||
params: InputParams = None,
|
||||
aggregate_sentences: bool = True,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
append_trailing_space: bool = True,
|
||||
**kwargs: Any,
|
||||
):
|
||||
@@ -523,7 +524,12 @@ class InworldTTSService(AudioContextTTSService):
|
||||
sample_rate: Audio sample rate in Hz.
|
||||
encoding: Audio encoding format.
|
||||
params: Input parameters for Inworld WebSocket TTS configuration.
|
||||
aggregate_sentences: Whether to aggregate sentences before synthesis.
|
||||
aggregate_sentences: Deprecated. Use text_aggregation_mode instead.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
text_aggregation_mode: How to aggregate text before synthesis.
|
||||
append_trailing_space: Whether to append a trailing space to text before sending to TTS.
|
||||
**kwargs: Additional arguments passed to the parent class.
|
||||
"""
|
||||
@@ -536,6 +542,7 @@ class InworldTTSService(AudioContextTTSService):
|
||||
supports_word_timestamps=True,
|
||||
sample_rate=sample_rate,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
append_trailing_space=append_trailing_space,
|
||||
settings=InworldTTSSettings(
|
||||
model=model,
|
||||
|
||||
@@ -36,7 +36,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import InterruptibleTTSService, TTSService
|
||||
from pipecat.services.tts_service import InterruptibleTTSService, TextAggregationMode, TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -119,7 +119,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
sample_rate: Optional[int] = 22050,
|
||||
encoding: str = "pcm_linear",
|
||||
params: Optional[InputParams] = None,
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Neuphonic TTS service.
|
||||
@@ -131,13 +132,19 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
sample_rate: Audio sample rate in Hz. Defaults to 22050.
|
||||
encoding: Audio encoding format. Defaults to "pcm_linear".
|
||||
params: Additional input parameters for TTS configuration.
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
aggregate_sentences: Deprecated. Use text_aggregation_mode instead.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
text_aggregation_mode: How to aggregate text before synthesis.
|
||||
**kwargs: Additional arguments passed to parent InterruptibleTTSService.
|
||||
"""
|
||||
params = params or NeuphonicTTSService.InputParams()
|
||||
|
||||
super().__init__(
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
push_stop_frames=True,
|
||||
stop_frame_timeout_s=2.0,
|
||||
sample_rate=sample_rate,
|
||||
|
||||
@@ -35,6 +35,7 @@ from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import (
|
||||
AudioContextTTSService,
|
||||
InterruptibleTTSService,
|
||||
TextAggregationMode,
|
||||
TTSService,
|
||||
)
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
@@ -181,7 +182,8 @@ class RimeTTSService(AudioContextTTSService):
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
text_aggregator: Optional[BaseTextAggregator] = None,
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Rime TTS service.
|
||||
@@ -198,13 +200,19 @@ class RimeTTSService(AudioContextTTSService):
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
text_aggregation_mode: How to aggregate incoming text before synthesis.
|
||||
aggregate_sentences: Deprecated. Use text_aggregation_mode instead.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
# Initialize with parent class settings for proper frame handling
|
||||
params = params or RimeTTSService.InputParams()
|
||||
|
||||
super().__init__(
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
@@ -243,7 +251,9 @@ class RimeTTSService(AudioContextTTSService):
|
||||
# The preferred way of taking advantage of Rime spelling is
|
||||
# to use an LLMTextProcessor and/or a text_transformer to identify
|
||||
# and insert these tags for the purpose of the TTS service alone.
|
||||
self._text_aggregator = SkipTagsAggregator([("spell(", ")")])
|
||||
self._text_aggregator = SkipTagsAggregator(
|
||||
[("spell(", ")")], aggregation_type=self._text_aggregation_mode
|
||||
)
|
||||
|
||||
# Store service configuration
|
||||
self._api_key = api_key
|
||||
@@ -826,7 +836,8 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
audio_format: str = "pcm",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Rime Non-JSON WebSocket TTS service.
|
||||
@@ -839,13 +850,21 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
audio_format: Audio format to use.
|
||||
sample_rate: Audio sample rate in Hz.
|
||||
params: Additional configuration parameters.
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
aggregate_sentences: Deprecated. Use text_aggregation_mode instead.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead. Set to ``TextAggregationMode.SENTENCE``
|
||||
to aggregate text into sentences before synthesis, or
|
||||
``TextAggregationMode.TOKEN`` to stream tokens directly for lower latency.
|
||||
|
||||
text_aggregation_mode: How to aggregate text before synthesis.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
params = params or RimeNonJsonTTSService.InputParams()
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
append_trailing_space=True,
|
||||
|
||||
@@ -63,7 +63,7 @@ from pipecat.frames.frames import (
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.sarvam._sdk import sdk_headers
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import InterruptibleTTSService, TTSService
|
||||
from pipecat.services.tts_service import InterruptibleTTSService, TextAggregationMode, TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -785,7 +785,8 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
model: str = "bulbul:v2",
|
||||
voice_id: Optional[str] = None,
|
||||
url: str = "wss://api.sarvam.ai/text-to-speech/ws",
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
@@ -799,7 +800,12 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
- "bulbul:v3-beta": Advanced model with temperature control
|
||||
voice_id: Speaker voice ID. If None, uses model-appropriate default.
|
||||
url: WebSocket URL for the TTS backend (default production URL).
|
||||
aggregate_sentences: Merge multiple sentences into one audio chunk (default True).
|
||||
aggregate_sentences: Deprecated. Use text_aggregation_mode instead.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
text_aggregation_mode: How to aggregate text before synthesis.
|
||||
sample_rate: Output audio sample rate in Hz (8000, 16000, 22050, 24000).
|
||||
If None, uses model-specific default.
|
||||
params: Optional input parameters to override defaults.
|
||||
@@ -834,6 +840,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
# Initialize parent class first
|
||||
super().__init__(
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
push_text_frames=True,
|
||||
pause_frame_processing=True,
|
||||
push_stop_frames=True,
|
||||
|
||||
@@ -11,6 +11,7 @@ import uuid
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
@@ -72,6 +73,23 @@ class TTSContext:
|
||||
append_to_context: bool = True
|
||||
|
||||
|
||||
class TextAggregationMode(str, Enum):
|
||||
"""Controls how incoming text is aggregated before TTS synthesis.
|
||||
|
||||
Parameters:
|
||||
SENTENCE: Buffer text until sentence boundaries are detected before synthesis.
|
||||
Produces more natural speech but adds latency (~200-300ms per sentence).
|
||||
TOKEN: Stream text tokens directly to TTS as they arrive.
|
||||
Reduces latency but may affect speech quality depending on the TTS provider.
|
||||
"""
|
||||
|
||||
SENTENCE = "sentence"
|
||||
TOKEN = "token"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class TTSService(AIService):
|
||||
"""Base class for text-to-speech services.
|
||||
|
||||
@@ -109,7 +127,8 @@ class TTSService(AIService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aggregate_sentences: bool = True,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
# if True, TTSService will push TextFrames and LLMFullResponseEndFrames,
|
||||
# otherwise subclass must do it
|
||||
push_text_frames: bool = True,
|
||||
@@ -153,7 +172,16 @@ class TTSService(AIService):
|
||||
"""Initialize the TTS service.
|
||||
|
||||
Args:
|
||||
text_aggregation_mode: How to aggregate incoming text before synthesis.
|
||||
TextAggregationMode.SENTENCE (default) buffers until sentence boundaries,
|
||||
TextAggregationMode.TOKEN streams tokens directly for lower latency.
|
||||
aggregate_sentences: Whether to aggregate text into sentences before synthesis.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead. Set to ``TextAggregationMode.SENTENCE``
|
||||
to aggregate text into sentences before synthesis, or
|
||||
``TextAggregationMode.TOKEN`` to stream tokens directly for lower latency.
|
||||
|
||||
push_text_frames: Whether to push TextFrames and LLMFullResponseEndFrames.
|
||||
push_stop_frames: Whether to automatically push TTSStoppedFrames.
|
||||
stop_frame_timeout_s: Idle time before pushing TTSStoppedFrame when push_stop_frames is True.
|
||||
@@ -194,7 +222,33 @@ class TTSService(AIService):
|
||||
or TTSSettings(),
|
||||
**kwargs,
|
||||
)
|
||||
self._aggregate_sentences: bool = aggregate_sentences
|
||||
|
||||
# Resolve text_aggregation_mode from the new param or deprecated aggregate_sentences
|
||||
if aggregate_sentences is not None:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'aggregate_sentences' is deprecated. "
|
||||
"Use 'text_aggregation_mode=TextAggregationMode.SENTENCE' or "
|
||||
"'text_aggregation_mode=TextAggregationMode.TOKEN' instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if text_aggregation_mode is None:
|
||||
text_aggregation_mode = (
|
||||
TextAggregationMode.SENTENCE
|
||||
if aggregate_sentences
|
||||
else TextAggregationMode.TOKEN
|
||||
)
|
||||
|
||||
if text_aggregation_mode is None:
|
||||
text_aggregation_mode = TextAggregationMode.SENTENCE
|
||||
|
||||
self._text_aggregation_mode: TextAggregationMode = text_aggregation_mode
|
||||
# Keep for backward compat with subclasses that read self._aggregate_sentences
|
||||
self._aggregate_sentences: bool = text_aggregation_mode != TextAggregationMode.TOKEN
|
||||
self._push_text_frames: bool = push_text_frames
|
||||
self._push_stop_frames: bool = push_stop_frames
|
||||
self._stop_frame_timeout_s: float = stop_frame_timeout_s
|
||||
@@ -204,7 +258,9 @@ class TTSService(AIService):
|
||||
self._append_trailing_space: bool = append_trailing_space
|
||||
self._init_sample_rate = sample_rate
|
||||
self._sample_rate = 0
|
||||
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
|
||||
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator(
|
||||
aggregation_type=self._text_aggregation_mode
|
||||
)
|
||||
if text_aggregator:
|
||||
import warnings
|
||||
|
||||
@@ -240,6 +296,8 @@ class TTSService(AIService):
|
||||
|
||||
self._processing_text: bool = False
|
||||
self._tts_contexts: Dict[str, TTSContext] = {}
|
||||
self._streaming_text_log: str = ""
|
||||
self._aggregation_logged: bool = False
|
||||
|
||||
# Word timestamp state (active when supports_word_timestamps=True)
|
||||
self._supports_word_timestamps: bool = supports_word_timestamps
|
||||
@@ -253,6 +311,11 @@ class TTSService(AIService):
|
||||
self._register_event_handler("on_connection_error")
|
||||
self._register_event_handler("on_tts_request")
|
||||
|
||||
@property
|
||||
def _is_streaming_tokens(self) -> bool:
|
||||
"""Whether the service is streaming tokens directly without sentence aggregation."""
|
||||
return self._text_aggregation_mode == TextAggregationMode.TOKEN
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
"""Get the current sample rate for audio output.
|
||||
@@ -511,6 +574,9 @@ class TTSService(AIService):
|
||||
and not isinstance(frame, InterimTranscriptionFrame)
|
||||
and not isinstance(frame, TranscriptionFrame)
|
||||
):
|
||||
if not self._is_streaming_tokens and not self._aggregation_logged:
|
||||
await self.start_text_aggregation_metrics()
|
||||
self._aggregation_logged = True
|
||||
await self._process_text_frame(frame)
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruption(frame, direction)
|
||||
@@ -527,8 +593,18 @@ class TTSService(AIService):
|
||||
# Flush any remaining text (including text waiting for lookahead)
|
||||
remaining = await self._text_aggregator.flush()
|
||||
if remaining:
|
||||
# If this is the first (and only) sentence, stop the aggregation metric.
|
||||
await self.stop_text_aggregation_metrics()
|
||||
await self._push_tts_frames(AggregatedTextFrame(remaining.text, remaining.type))
|
||||
|
||||
self._aggregation_logged = False
|
||||
|
||||
# Log accumulated streamed text and emit aggregated usage metric.
|
||||
if self._streaming_text_log:
|
||||
logger.debug(f"{self}: Generating TTS [{self._streaming_text_log}]")
|
||||
await self.start_tts_usage_metrics(self._streaming_text_log)
|
||||
self._streaming_text_log = ""
|
||||
|
||||
# Reset aggregator state
|
||||
self._processing_text = False
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
@@ -690,26 +766,18 @@ class TTSService(AIService):
|
||||
await self.resume_processing_frames()
|
||||
|
||||
async def _process_text_frame(self, frame: TextFrame):
|
||||
text: Optional[str] = None
|
||||
includes_inter_frame_spaces: bool = False
|
||||
if not self._aggregate_sentences:
|
||||
text = frame.text
|
||||
includes_inter_frame_spaces = frame.includes_inter_frame_spaces
|
||||
aggregated_by = "token"
|
||||
|
||||
if text:
|
||||
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
|
||||
)
|
||||
else:
|
||||
async for aggregate in self._text_aggregator.aggregate(frame.text):
|
||||
text = aggregate.text
|
||||
aggregated_by = aggregate.type
|
||||
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
|
||||
)
|
||||
async for aggregate in self._text_aggregator.aggregate(frame.text):
|
||||
includes_inter_frame_spaces = (
|
||||
frame.includes_inter_frame_spaces
|
||||
if aggregate.type == AggregationType.TOKEN
|
||||
else False
|
||||
)
|
||||
if aggregate.type != AggregationType.TOKEN:
|
||||
# Stop the aggregation metric on the first sentence only.
|
||||
await self.stop_text_aggregation_metrics()
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(aggregate.text, aggregate.type), includes_inter_frame_spaces
|
||||
)
|
||||
|
||||
async def _push_tts_frames(
|
||||
self,
|
||||
@@ -739,7 +807,15 @@ class TTSService(AIService):
|
||||
# or when we received an LLMFullResponseEndFrame
|
||||
self._processing_text = True
|
||||
|
||||
await self.start_processing_metrics()
|
||||
# Accumulate text for a single debug log at flush time when streaming tokens.
|
||||
if self._is_streaming_tokens:
|
||||
self._streaming_text_log += text
|
||||
|
||||
# Skip per-token processing metrics when streaming. The per-token
|
||||
# processing time is just websocket send overhead (~0.1ms) and not
|
||||
# meaningful. TTFB captures the important timing for streaming TTS.
|
||||
if not self._is_streaming_tokens:
|
||||
await self.start_processing_metrics()
|
||||
|
||||
# Process all filters.
|
||||
for filter in self._text_filters:
|
||||
@@ -747,7 +823,8 @@ class TTSService(AIService):
|
||||
text = await filter.filter(text)
|
||||
|
||||
if not text.strip():
|
||||
await self.stop_processing_metrics()
|
||||
if not self._is_streaming_tokens:
|
||||
await self.stop_processing_metrics()
|
||||
return
|
||||
|
||||
# Create context ID and store metadata
|
||||
@@ -785,7 +862,8 @@ class TTSService(AIService):
|
||||
|
||||
await self.process_generator(self.run_tts(prepared_text, context_id))
|
||||
|
||||
await self.stop_processing_metrics()
|
||||
if not self._is_streaming_tokens:
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
if self._push_text_frames:
|
||||
# In TTS services that support word timestamps, the TTSTextFrames
|
||||
|
||||
@@ -21,6 +21,7 @@ class AggregationType(str, Enum):
|
||||
"""Built-in aggregation strings."""
|
||||
|
||||
SENTENCE = "sentence"
|
||||
TOKEN = "token"
|
||||
WORD = "word"
|
||||
|
||||
def __str__(self):
|
||||
@@ -66,6 +67,25 @@ class BaseTextAggregator(ABC):
|
||||
logic, text manipulation behavior, and state management for interruptions.
|
||||
"""
|
||||
|
||||
def __init__(self, *, aggregation_type: AggregationType = AggregationType.SENTENCE):
|
||||
"""Initialize the base text aggregator.
|
||||
|
||||
Args:
|
||||
aggregation_type: The aggregation strategy to use. SENTENCE buffers
|
||||
text until sentence boundaries are detected, TOKEN passes text
|
||||
through immediately, and WORD buffers until word boundaries.
|
||||
"""
|
||||
self._aggregation_type = AggregationType(aggregation_type)
|
||||
|
||||
@property
|
||||
def aggregation_type(self) -> AggregationType:
|
||||
"""Get the aggregation type for this aggregator.
|
||||
|
||||
Returns:
|
||||
The aggregation type.
|
||||
"""
|
||||
return self._aggregation_type
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def text(self) -> Aggregation:
|
||||
|
||||
@@ -96,8 +96,11 @@ class PatternPairAggregator(SimpleTextAggregator):
|
||||
|
||||
Creates an empty aggregator with no patterns or handlers registered.
|
||||
Text buffering and pattern detection will begin when text is aggregated.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional arguments passed to SimpleTextAggregator (e.g. aggregation_type).
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self._patterns = {}
|
||||
self._handlers = {}
|
||||
self._last_processed_position = 0 # Track where we last checked for complete patterns
|
||||
@@ -146,7 +149,7 @@ class PatternPairAggregator(SimpleTextAggregator):
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
"""
|
||||
if type in [AggregationType.SENTENCE, AggregationType.WORD]:
|
||||
if type in [AggregationType.SENTENCE, AggregationType.WORD, AggregationType.TOKEN]:
|
||||
raise ValueError(
|
||||
f"The aggregation type '{type}' is reserved for default behavior and can not be used for custom patterns."
|
||||
)
|
||||
@@ -321,6 +324,9 @@ class PatternPairAggregator(SimpleTextAggregator):
|
||||
and uses the parent's lookahead logic for sentence detection when no
|
||||
patterns are active.
|
||||
|
||||
In TOKEN mode, pattern detection still works but non-pattern text is
|
||||
yielded as TOKEN aggregations instead of waiting for sentence boundaries.
|
||||
|
||||
Args:
|
||||
text: Text to aggregate.
|
||||
|
||||
@@ -370,18 +376,35 @@ class PatternPairAggregator(SimpleTextAggregator):
|
||||
# boundaries when a pattern begins (e.g., "Here is code <code>..." yields "Here is code")
|
||||
result = self._text[: pattern_start[0]]
|
||||
self._text = self._text[pattern_start[0] :]
|
||||
yield PatternMatch(
|
||||
content=result.strip(), type=AggregationType.SENTENCE, full_match=result
|
||||
agg_type = (
|
||||
AggregationType.TOKEN
|
||||
if self._aggregation_type == AggregationType.TOKEN
|
||||
else AggregationType.SENTENCE
|
||||
)
|
||||
yield PatternMatch(content=result.strip(), type=agg_type, full_match=result)
|
||||
continue
|
||||
|
||||
# Use parent's lookahead logic for sentence detection
|
||||
aggregation = await super()._check_sentence_with_lookahead(char)
|
||||
if aggregation:
|
||||
# Convert to PatternMatch for consistency with return type
|
||||
if self._aggregation_type != AggregationType.TOKEN:
|
||||
# Use parent's lookahead logic for sentence detection
|
||||
aggregation = await super()._check_sentence_with_lookahead(char)
|
||||
if aggregation:
|
||||
# Convert to PatternMatch for consistency with return type
|
||||
yield PatternMatch(
|
||||
content=aggregation.text,
|
||||
type=aggregation.type,
|
||||
full_match=aggregation.text,
|
||||
)
|
||||
|
||||
# In TOKEN mode, yield any accumulated text after processing all chars,
|
||||
# but only if there's no incomplete pattern being buffered.
|
||||
if self._aggregation_type == AggregationType.TOKEN and self._text:
|
||||
if self._match_start_of_pattern(self._text) is None:
|
||||
yield PatternMatch(
|
||||
content=aggregation.text, type=aggregation.type, full_match=aggregation.text
|
||||
content=self._text,
|
||||
type=AggregationType.TOKEN,
|
||||
full_match=self._text,
|
||||
)
|
||||
self._text = ""
|
||||
|
||||
async def handle_interruption(self):
|
||||
"""Handle interruptions by clearing the buffer and pattern state.
|
||||
|
||||
@@ -25,11 +25,15 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
most straightforward implementation of text aggregation for TTS processing.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the simple text aggregator.
|
||||
|
||||
Creates an empty text buffer ready to begin accumulating text tokens.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional arguments passed to BaseTextAggregator (e.g. aggregation_type).
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._text = ""
|
||||
self._needs_lookahead: bool = False
|
||||
|
||||
@@ -43,19 +47,25 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
return Aggregation(text=self._text.strip(" "), type=AggregationType.SENTENCE)
|
||||
|
||||
async def aggregate(self, text: str) -> AsyncIterator[Aggregation]:
|
||||
"""Aggregate text and yield completed sentences.
|
||||
"""Aggregate text and yield completed aggregations.
|
||||
|
||||
Processes the input text character-by-character. When sentence-ending
|
||||
punctuation is detected, it waits for non-whitespace lookahead before
|
||||
calling NLTK. This prevents false positives like "$29." being detected
|
||||
as a sentence when it's actually "$29.95".
|
||||
In SENTENCE mode, processes the input text character-by-character. When
|
||||
sentence-ending punctuation is detected, it waits for non-whitespace
|
||||
lookahead before calling NLTK.
|
||||
|
||||
In TOKEN mode, yields the text immediately without buffering.
|
||||
|
||||
Args:
|
||||
text: Text to aggregate.
|
||||
|
||||
Yields:
|
||||
Complete sentences as Aggregation objects.
|
||||
Aggregation objects (sentences in SENTENCE mode, tokens in TOKEN mode).
|
||||
"""
|
||||
if self._aggregation_type == AggregationType.TOKEN:
|
||||
if text:
|
||||
yield Aggregation(text=text, type=AggregationType.TOKEN)
|
||||
return
|
||||
|
||||
# Process text character by character
|
||||
for char in text:
|
||||
self._text += char
|
||||
@@ -114,11 +124,15 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
"""Flush any remaining text in the buffer.
|
||||
|
||||
Returns any text remaining in the buffer. This is called at the end
|
||||
of a stream to ensure all text is processed.
|
||||
of a stream to ensure all text is processed. In TOKEN mode, returns
|
||||
None since tokens are yielded immediately.
|
||||
|
||||
Returns:
|
||||
Any remaining text as a sentence, or None if buffer is empty.
|
||||
Any remaining text as a sentence, or None if buffer is empty or in TOKEN mode.
|
||||
"""
|
||||
if self._aggregation_type == AggregationType.TOKEN:
|
||||
return None
|
||||
|
||||
if self._text:
|
||||
# Return whatever we have in the buffer
|
||||
result = self._text
|
||||
|
||||
@@ -14,7 +14,7 @@ as a unit regardless of internal punctuation.
|
||||
from typing import AsyncIterator, Optional, Sequence
|
||||
|
||||
from pipecat.utils.string import StartEndTags, parse_start_end_tags
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType
|
||||
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
|
||||
|
||||
|
||||
@@ -31,14 +31,15 @@ class SkipTagsAggregator(SimpleTextAggregator):
|
||||
identified and that content within tags is never split at sentence boundaries.
|
||||
"""
|
||||
|
||||
def __init__(self, tags: Sequence[StartEndTags]):
|
||||
def __init__(self, tags: Sequence[StartEndTags], **kwargs):
|
||||
"""Initialize the skip tags aggregator.
|
||||
|
||||
Args:
|
||||
tags: Sequence of StartEndTags objects defining the tag pairs
|
||||
that should prevent sentence boundary detection.
|
||||
**kwargs: Additional arguments passed to SimpleTextAggregator (e.g. aggregation_type).
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self._tags = tags
|
||||
self._current_tag: Optional[StartEndTags] = None
|
||||
self._current_tag_index: int = 0
|
||||
@@ -50,13 +51,33 @@ class SkipTagsAggregator(SimpleTextAggregator):
|
||||
uses the parent's lookahead logic for sentence detection when not
|
||||
inside tags.
|
||||
|
||||
In TOKEN mode, text is passed through immediately unless we're inside
|
||||
a tag, in which case we buffer until the closing tag is found.
|
||||
|
||||
Args:
|
||||
text: Text to aggregate.
|
||||
|
||||
Yields:
|
||||
Aggregation objects containing text up to a sentence boundary,
|
||||
marked as SENTENCE type.
|
||||
marked as SENTENCE type (or TOKEN type in TOKEN mode).
|
||||
"""
|
||||
if self._aggregation_type == AggregationType.TOKEN:
|
||||
# In TOKEN mode, process chars for tag tracking but yield the
|
||||
# full input as a single token when not inside a tag.
|
||||
for char in text:
|
||||
self._text += char
|
||||
|
||||
# Update tag state
|
||||
(self._current_tag, self._current_tag_index) = parse_start_end_tags(
|
||||
self._text, self._tags, self._current_tag, self._current_tag_index
|
||||
)
|
||||
|
||||
# After processing all chars: if not inside a tag, yield accumulated text
|
||||
if not self._current_tag and self._text:
|
||||
yield Aggregation(text=self._text, type=AggregationType.TOKEN)
|
||||
self._text = ""
|
||||
return
|
||||
|
||||
# Process text character by character
|
||||
for char in text:
|
||||
self._text += char
|
||||
|
||||
@@ -194,5 +194,66 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
|
||||
class TestPatternPairAggregatorTokenMode(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
from pipecat.utils.text.base_text_aggregator import AggregationType
|
||||
|
||||
self.aggregator = PatternPairAggregator(aggregation_type=AggregationType.TOKEN)
|
||||
self.handler = AsyncMock()
|
||||
self.aggregator.add_pattern(
|
||||
type="think",
|
||||
start_pattern="<think>",
|
||||
end_pattern="</think>",
|
||||
action=MatchAction.REMOVE,
|
||||
)
|
||||
self.aggregator.on_pattern_match("think", self.handler)
|
||||
|
||||
async def test_token_no_patterns(self):
|
||||
"""Non-pattern text passes through as TOKEN, one per aggregate call."""
|
||||
results = []
|
||||
for token in ["Hello", " world", "."]:
|
||||
async for r in self.aggregator.aggregate(token):
|
||||
results.append(r)
|
||||
|
||||
self.assertEqual(len(results), 3)
|
||||
self.assertEqual(results[0].text, "Hello")
|
||||
self.assertEqual(results[1].text, " world")
|
||||
self.assertEqual(results[2].text, ".")
|
||||
for r in results:
|
||||
self.assertEqual(r.type, "token")
|
||||
|
||||
async def test_token_pattern_detection(self):
|
||||
"""Pattern detection still works with word-by-word token delivery."""
|
||||
results = []
|
||||
for token in ["Hi ", "<think>", "secret", "</think>", " bye"]:
|
||||
async for r in self.aggregator.aggregate(token):
|
||||
results.append(r)
|
||||
|
||||
# Handler called once when the pattern completes
|
||||
self.handler.assert_called_once()
|
||||
call_args = self.handler.call_args[0][0]
|
||||
self.assertEqual(call_args.text, "secret")
|
||||
|
||||
# "Hi " yields before pattern starts, pattern is removed, " bye" yields after
|
||||
self.assertEqual(len(results), 2)
|
||||
self.assertEqual(results[0].text, "Hi ")
|
||||
self.assertEqual(results[0].type, "token")
|
||||
self.assertEqual(results[1].text, " bye")
|
||||
self.assertEqual(results[1].type, "token")
|
||||
|
||||
async def test_token_incomplete_pattern_buffers(self):
|
||||
"""Incomplete pattern is buffered across calls, not leaked to output."""
|
||||
results = []
|
||||
for token in ["Hi ", "<think>", "partial"]:
|
||||
async for r in self.aggregator.aggregate(token):
|
||||
results.append(r)
|
||||
|
||||
# Only "Hi " should be yielded; "<think>partial" stays buffered
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0].text, "Hi ")
|
||||
self.assertEqual(results[0].type, "token")
|
||||
self.handler.assert_not_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -181,5 +181,39 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
assert result.text == "こんにちは。"
|
||||
|
||||
|
||||
class TestSimpleTextAggregatorTokenMode(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
from pipecat.utils.text.base_text_aggregator import AggregationType
|
||||
|
||||
self.aggregator = SimpleTextAggregator(aggregation_type=AggregationType.TOKEN)
|
||||
|
||||
async def test_token_passthrough(self):
|
||||
"""TOKEN mode yields text immediately without buffering."""
|
||||
results = [agg async for agg in self.aggregator.aggregate("Hello")]
|
||||
assert len(results) == 1
|
||||
assert results[0].text == "Hello"
|
||||
assert results[0].type == "token"
|
||||
|
||||
async def test_token_multiple_calls(self):
|
||||
"""Each aggregate call yields its text independently."""
|
||||
r1 = [agg async for agg in self.aggregator.aggregate("Hello ")]
|
||||
r2 = [agg async for agg in self.aggregator.aggregate("world.")]
|
||||
assert len(r1) == 1
|
||||
assert r1[0].text == "Hello "
|
||||
assert len(r2) == 1
|
||||
assert r2[0].text == "world."
|
||||
|
||||
async def test_token_empty_text(self):
|
||||
"""Empty text yields nothing."""
|
||||
results = [agg async for agg in self.aggregator.aggregate("")]
|
||||
assert len(results) == 0
|
||||
|
||||
async def test_token_flush_returns_none(self):
|
||||
"""Flush returns None in TOKEN mode since nothing is buffered."""
|
||||
await self.aggregator.aggregate("Hello").__anext__()
|
||||
result = await self.aggregator.flush()
|
||||
assert result is None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -64,5 +64,60 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(self.aggregator.text.type, "sentence")
|
||||
|
||||
|
||||
class TestSkipTagsAggregatorTokenMode(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
from pipecat.utils.text.base_text_aggregator import AggregationType
|
||||
|
||||
self.aggregator = SkipTagsAggregator(
|
||||
[("<spell>", "</spell>")], aggregation_type=AggregationType.TOKEN
|
||||
)
|
||||
|
||||
async def test_token_no_tags(self):
|
||||
"""No tags: text passes through immediately as TOKEN."""
|
||||
results = [agg async for agg in self.aggregator.aggregate("Hello!")]
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0].text, "Hello!")
|
||||
self.assertEqual(results[0].type, "token")
|
||||
|
||||
async def test_token_inside_tag_buffers(self):
|
||||
"""Inside a tag, text is buffered until the closing tag is found."""
|
||||
results = [agg async for agg in self.aggregator.aggregate("<spell>foo@bar")]
|
||||
# Still inside tag, nothing yielded
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
# Close the tag
|
||||
results = [agg async for agg in self.aggregator.aggregate("</spell>")]
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0].text, "<spell>foo@bar</spell>")
|
||||
self.assertEqual(results[0].type, "token")
|
||||
|
||||
async def test_token_flush_unclosed_tag(self):
|
||||
"""Flush with unclosed tag returns remaining text."""
|
||||
async for _ in self.aggregator.aggregate("<spell>unclosed"):
|
||||
pass
|
||||
result = await self.aggregator.flush()
|
||||
# TOKEN mode flush returns None (parent behavior)
|
||||
self.assertIsNone(result)
|
||||
|
||||
async def test_token_text_around_tags(self):
|
||||
"""Simulate word-by-word token delivery with tags."""
|
||||
results = []
|
||||
# Simulate LLM streaming tokens one at a time
|
||||
for token in ["Hi ", "<spell>", "X", "</spell>", " bye"]:
|
||||
async for agg in self.aggregator.aggregate(token):
|
||||
results.append(agg)
|
||||
|
||||
self.assertEqual(len(results), 3)
|
||||
# Text before tag passes through immediately
|
||||
self.assertEqual(results[0].text, "Hi ")
|
||||
self.assertEqual(results[0].type, "token")
|
||||
# Tagged content is buffered until the closing tag, then yielded whole
|
||||
self.assertEqual(results[1].text, "<spell>X</spell>")
|
||||
self.assertEqual(results[1].type, "token")
|
||||
# Text after tag passes through immediately
|
||||
self.assertEqual(results[2].text, " bye")
|
||||
self.assertEqual(results[2].type, "token")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user