Merge pull request #3696 from pipecat-ai/mb/streaming-tts-input
Improve streaming TTS input support, add TextAggregationMetricsData
This commit is contained in:
1
changelog/3696.added.md
Normal file
1
changelog/3696.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `TextAggregationMetricsData` metric measuring the time from the first LLM token to the first complete sentence, representing the latency cost of sentence aggregation in the TTS pipeline.
|
||||
1
changelog/3696.changed.md
Normal file
1
changelog/3696.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `text_aggregation_mode` parameter to `TTSService` and all TTS subclasses with a new `TextAggregationMode` enum (`SENTENCE`, `TOKEN`). All text now flows through text aggregators regardless of mode, enabling pattern detection and tag handling in TOKEN mode.
|
||||
1
changelog/3696.deprecated.md
Normal file
1
changelog/3696.deprecated.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Deprecated `aggregate_sentences` parameter on `TTSService` and all TTS subclasses. Use `text_aggregation_mode=TextAggregationMode.SENTENCE` or `text_aggregation_mode=TextAggregationMode.TOKEN` instead.
|
||||
@@ -24,6 +24,7 @@ from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.tts_service import TextAggregationMode
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -56,6 +57,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
# Alternatively, you can use TextAggregationMode.TOKEN to stream tokens instead of
|
||||
# sentencesfor faster response times.
|
||||
# text_aggregation_mode=TextAggregationMode.TOKEN,
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
@@ -14,7 +14,6 @@ and LLM processing.
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -36,6 +35,7 @@ from pipecat.audio.turn.base_turn_analyzer import BaseTurnParams
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.metrics.metrics import MetricsData
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.text.base_text_aggregator import AggregationType
|
||||
from pipecat.utils.time import nanoseconds_to_str
|
||||
from pipecat.utils.utils import obj_count, obj_id
|
||||
|
||||
@@ -393,16 +393,6 @@ class LLMTextFrame(TextFrame):
|
||||
self.includes_inter_frame_spaces = True
|
||||
|
||||
|
||||
class AggregationType(str, Enum):
|
||||
"""Built-in aggregation strings."""
|
||||
|
||||
SENTENCE = "sentence"
|
||||
WORD = "word"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class AggregatedTextFrame(TextFrame):
|
||||
"""Text frame representing an aggregation of TextFrames.
|
||||
|
||||
@@ -87,6 +87,19 @@ class TTSUsageMetricsData(MetricsData):
|
||||
value: int
|
||||
|
||||
|
||||
class TextAggregationMetricsData(MetricsData):
|
||||
"""Text aggregation time metrics data.
|
||||
|
||||
Measures the time from the first LLM token to the first complete sentence,
|
||||
representing the latency cost of sentence aggregation in the TTS pipeline.
|
||||
|
||||
Parameters:
|
||||
value: Aggregation time in seconds.
|
||||
"""
|
||||
|
||||
value: float
|
||||
|
||||
|
||||
class TurnMetricsData(MetricsData):
|
||||
"""Metrics data for turn detection predictions.
|
||||
|
||||
|
||||
@@ -485,10 +485,23 @@ class FrameProcessor(BaseObject):
|
||||
if frame:
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def start_text_aggregation_metrics(self):
|
||||
"""Start text aggregation time metrics collection."""
|
||||
if self.can_generate_metrics() and self.metrics_enabled:
|
||||
await self._metrics.start_text_aggregation_metrics()
|
||||
|
||||
async def stop_text_aggregation_metrics(self):
|
||||
"""Stop text aggregation time metrics collection and push results."""
|
||||
if self.can_generate_metrics() and self.metrics_enabled:
|
||||
frame = await self._metrics.stop_text_aggregation_metrics()
|
||||
if frame:
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def stop_all_metrics(self):
|
||||
"""Stop all active metrics collection."""
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
await self.stop_text_aggregation_metrics()
|
||||
|
||||
def create_task(self, coroutine: Coroutine, name: Optional[str] = None) -> asyncio.Task:
|
||||
"""Create a new task managed by this processor.
|
||||
|
||||
@@ -17,6 +17,7 @@ from pipecat.metrics.metrics import (
|
||||
LLMUsageMetricsData,
|
||||
MetricsData,
|
||||
ProcessingMetricsData,
|
||||
TextAggregationMetricsData,
|
||||
TTFBMetricsData,
|
||||
TTSUsageMetricsData,
|
||||
)
|
||||
@@ -43,6 +44,7 @@ class FrameProcessorMetrics(BaseObject):
|
||||
self._task_manager = None
|
||||
self._start_ttfb_time = 0
|
||||
self._start_processing_time = 0
|
||||
self._start_text_aggregation_time = 0
|
||||
self._last_ttfb_time = 0
|
||||
self._should_report_ttfb = True
|
||||
|
||||
@@ -211,3 +213,24 @@ class FrameProcessorMetrics(BaseObject):
|
||||
)
|
||||
logger.debug(f"{self._processor_name()} usage characters: {characters.value}")
|
||||
return MetricsFrame(data=[characters])
|
||||
|
||||
async def start_text_aggregation_metrics(self):
|
||||
"""Start measuring text aggregation time (first token to first sentence)."""
|
||||
self._start_text_aggregation_time = time.time()
|
||||
|
||||
async def stop_text_aggregation_metrics(self):
|
||||
"""Stop text aggregation measurement and generate metrics frame.
|
||||
|
||||
Returns:
|
||||
MetricsFrame containing text aggregation time, or None if not measuring.
|
||||
"""
|
||||
if self._start_text_aggregation_time == 0:
|
||||
return None
|
||||
|
||||
value = time.time() - self._start_text_aggregation_time
|
||||
logger.debug(f"{self._processor_name()} text aggregation time: {value}")
|
||||
aggregation = TextAggregationMetricsData(
|
||||
processor=self._processor_name(), value=value, model=self._model_name()
|
||||
)
|
||||
self._start_text_aggregation_time = 0
|
||||
return MetricsFrame(data=[aggregation])
|
||||
|
||||
@@ -28,7 +28,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TTSService
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -128,7 +128,8 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
encoding: str = "pcm_s16le",
|
||||
container: str = "raw",
|
||||
params: Optional[InputParams] = None,
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Async TTS service.
|
||||
@@ -144,13 +145,19 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
encoding: Audio encoding format.
|
||||
container: Audio container format.
|
||||
params: Additional input parameters for voice customization.
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
aggregate_sentences: Deprecated. Use text_aggregation_mode instead.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
text_aggregation_mode: How to aggregate text before synthesis.
|
||||
**kwargs: Additional arguments passed to the parent service.
|
||||
"""
|
||||
params = params or AsyncAITTSService.InputParams()
|
||||
|
||||
super().__init__(
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
pause_frame_processing=True,
|
||||
push_stop_frames=True,
|
||||
sample_rate=sample_rate,
|
||||
|
||||
@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.azure.common import language_to_azure_language
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.services.tts_service import TextAggregationMode, TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -256,7 +256,8 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
|
||||
voice: str = "en-US-SaraNeural",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[AzureBaseTTSService.InputParams] = None,
|
||||
aggregate_sentences: bool = True,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Azure streaming TTS service.
|
||||
@@ -267,13 +268,19 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
|
||||
voice: Voice name to use for synthesis. Defaults to "en-US-SaraNeural".
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
params: Voice and synthesis parameters configuration.
|
||||
aggregate_sentences: Whether to aggregate sentences before synthesis.
|
||||
**kwargs: Additional arguments passed to the parent TTSService.
|
||||
aggregate_sentences: Deprecated. Use text_aggregation_mode instead.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
text_aggregation_mode: How to aggregate text before synthesis.
|
||||
**kwargs: Additional arguments passed to parent WordTTSService.
|
||||
"""
|
||||
params = params or AzureBaseTTSService.InputParams()
|
||||
|
||||
super().__init__(
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
push_text_frames=False, # We'll push text frames based on word timestamps
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
|
||||
@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TTSService
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator
|
||||
@@ -272,7 +272,8 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
container: str = "raw",
|
||||
params: Optional[InputParams] = None,
|
||||
text_aggregator: Optional[BaseTextAggregator] = None,
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Cartesia TTS service.
|
||||
@@ -292,13 +293,21 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
text_aggregation_mode: How to aggregate incoming text before synthesis.
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
**kwargs: Additional arguments passed to the parent service.
|
||||
"""
|
||||
# Aggregating sentences still gives cleaner-sounding results and fewer
|
||||
# artifacts than streaming one word at a time. On average, waiting for a
|
||||
# full sentence should only "cost" us 15ms or so with GPT-4o or a Llama
|
||||
# 3 model, and it's worth it for the better audio quality.
|
||||
# By default, we aggregate sentences before sending to TTS. This adds
|
||||
# ~200-300ms of latency per sentence (waiting for the sentence-ending
|
||||
# punctuation token from the LLM). Setting
|
||||
# text_aggregation_mode=TextAggregationMode.TOKEN streams tokens
|
||||
# directly, which reduces latency. Streaming quality is good but less
|
||||
# tested than sentence aggregation.
|
||||
# TODO: Consider making TOKEN the default for Cartesia in 1.0.
|
||||
#
|
||||
# We also don't want to automatically push LLM response text frames,
|
||||
# because the context aggregators will add them to the LLM context even
|
||||
@@ -308,6 +317,7 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
params = params or CartesiaTTSService.InputParams()
|
||||
|
||||
super().__init__(
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
push_text_frames=False,
|
||||
pause_frame_processing=True,
|
||||
@@ -337,7 +347,9 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
# The preferred way of taking advantage of Cartesia SSML Tags is
|
||||
# to use an LLMTextProcessor and/or a text_transformer to identify
|
||||
# and insert these tags for the purpose of the TTS service alone.
|
||||
self._text_aggregator = SkipTagsAggregator([("<spell>", "</spell>")])
|
||||
self._text_aggregator = SkipTagsAggregator(
|
||||
[("<spell>", "</spell>")], aggregation_type=self._text_aggregation_mode
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._cartesia_version = cartesia_version
|
||||
@@ -639,7 +651,10 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
Yields:
|
||||
Frame: Audio frames containing the synthesized speech.
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
if not self._is_streaming_tokens:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
else:
|
||||
logger.trace(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
|
||||
@@ -47,6 +47,7 @@ from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import (
|
||||
AudioContextTTSService,
|
||||
TextAggregationMode,
|
||||
TTSService,
|
||||
)
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
@@ -365,7 +366,8 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
url: str = "wss://api.elevenlabs.io",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the ElevenLabs TTS service.
|
||||
@@ -377,13 +379,20 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
url: WebSocket URL for ElevenLabs TTS API.
|
||||
sample_rate: Audio sample rate. If None, uses default.
|
||||
params: Additional input parameters for voice customization.
|
||||
text_aggregation_mode: How to aggregate incoming text before synthesis.
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
**kwargs: Additional arguments passed to the parent service.
|
||||
"""
|
||||
# Aggregating sentences still gives cleaner-sounding results and fewer
|
||||
# artifacts than streaming one word at a time. On average, waiting for a
|
||||
# full sentence should only "cost" us 15ms or so with GPT-4o or a Llama
|
||||
# 3 model, and it's worth it for the better audio quality.
|
||||
# By default, we aggregate sentences before sending to TTS. This adds
|
||||
# ~200-300ms of latency per sentence (waiting for the sentence-ending
|
||||
# punctuation token from the LLM). Setting
|
||||
# text_aggregation_mode=TextAggregationMode.TOKEN streams tokens
|
||||
# directly. To use this mode, you must set auto_mode=False. This
|
||||
# eliminates aggregation time, but slows down ElevenLabs.
|
||||
#
|
||||
# We also don't want to automatically push LLM response text frames,
|
||||
# because the context aggregators will add them to the LLM context even
|
||||
@@ -397,6 +406,7 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
params = params or ElevenLabsTTSService.InputParams()
|
||||
|
||||
super().__init__(
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
@@ -893,7 +903,8 @@ class ElevenLabsHttpTTSService(TTSService):
|
||||
base_url: str = "https://api.elevenlabs.io",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the ElevenLabs HTTP TTS service.
|
||||
@@ -906,12 +917,18 @@ class ElevenLabsHttpTTSService(TTSService):
|
||||
base_url: Base URL for ElevenLabs HTTP API.
|
||||
sample_rate: Audio sample rate. If None, uses default.
|
||||
params: Additional input parameters for voice customization.
|
||||
text_aggregation_mode: How to aggregate incoming text before synthesis.
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
**kwargs: Additional arguments passed to the parent service.
|
||||
"""
|
||||
params = params or ElevenLabsHttpTTSService.InputParams()
|
||||
|
||||
super().__init__(
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
|
||||
@@ -51,7 +51,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TTSService
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
|
||||
@@ -509,7 +509,8 @@ class InworldTTSService(AudioContextTTSService):
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "LINEAR16",
|
||||
params: InputParams = None,
|
||||
aggregate_sentences: bool = True,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
append_trailing_space: bool = True,
|
||||
**kwargs: Any,
|
||||
):
|
||||
@@ -523,7 +524,12 @@ class InworldTTSService(AudioContextTTSService):
|
||||
sample_rate: Audio sample rate in Hz.
|
||||
encoding: Audio encoding format.
|
||||
params: Input parameters for Inworld WebSocket TTS configuration.
|
||||
aggregate_sentences: Whether to aggregate sentences before synthesis.
|
||||
aggregate_sentences: Deprecated. Use text_aggregation_mode instead.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
text_aggregation_mode: How to aggregate text before synthesis.
|
||||
append_trailing_space: Whether to append a trailing space to text before sending to TTS.
|
||||
**kwargs: Additional arguments passed to the parent class.
|
||||
"""
|
||||
@@ -536,6 +542,7 @@ class InworldTTSService(AudioContextTTSService):
|
||||
supports_word_timestamps=True,
|
||||
sample_rate=sample_rate,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
append_trailing_space=append_trailing_space,
|
||||
settings=InworldTTSSettings(
|
||||
model=model,
|
||||
|
||||
@@ -36,7 +36,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import InterruptibleTTSService, TTSService
|
||||
from pipecat.services.tts_service import InterruptibleTTSService, TextAggregationMode, TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -119,7 +119,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
sample_rate: Optional[int] = 22050,
|
||||
encoding: str = "pcm_linear",
|
||||
params: Optional[InputParams] = None,
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Neuphonic TTS service.
|
||||
@@ -131,13 +132,19 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
sample_rate: Audio sample rate in Hz. Defaults to 22050.
|
||||
encoding: Audio encoding format. Defaults to "pcm_linear".
|
||||
params: Additional input parameters for TTS configuration.
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
aggregate_sentences: Deprecated. Use text_aggregation_mode instead.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
text_aggregation_mode: How to aggregate text before synthesis.
|
||||
**kwargs: Additional arguments passed to parent InterruptibleTTSService.
|
||||
"""
|
||||
params = params or NeuphonicTTSService.InputParams()
|
||||
|
||||
super().__init__(
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
push_stop_frames=True,
|
||||
stop_frame_timeout_s=2.0,
|
||||
sample_rate=sample_rate,
|
||||
|
||||
@@ -35,6 +35,7 @@ from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import (
|
||||
AudioContextTTSService,
|
||||
InterruptibleTTSService,
|
||||
TextAggregationMode,
|
||||
TTSService,
|
||||
)
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
@@ -181,7 +182,8 @@ class RimeTTSService(AudioContextTTSService):
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
text_aggregator: Optional[BaseTextAggregator] = None,
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Rime TTS service.
|
||||
@@ -198,13 +200,19 @@ class RimeTTSService(AudioContextTTSService):
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
text_aggregation_mode: How to aggregate incoming text before synthesis.
|
||||
aggregate_sentences: Deprecated. Use text_aggregation_mode instead.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
# Initialize with parent class settings for proper frame handling
|
||||
params = params or RimeTTSService.InputParams()
|
||||
|
||||
super().__init__(
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
@@ -243,7 +251,9 @@ class RimeTTSService(AudioContextTTSService):
|
||||
# The preferred way of taking advantage of Rime spelling is
|
||||
# to use an LLMTextProcessor and/or a text_transformer to identify
|
||||
# and insert these tags for the purpose of the TTS service alone.
|
||||
self._text_aggregator = SkipTagsAggregator([("spell(", ")")])
|
||||
self._text_aggregator = SkipTagsAggregator(
|
||||
[("spell(", ")")], aggregation_type=self._text_aggregation_mode
|
||||
)
|
||||
|
||||
# Store service configuration
|
||||
self._api_key = api_key
|
||||
@@ -826,7 +836,8 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
audio_format: str = "pcm",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Rime Non-JSON WebSocket TTS service.
|
||||
@@ -839,13 +850,21 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
audio_format: Audio format to use.
|
||||
sample_rate: Audio sample rate in Hz.
|
||||
params: Additional configuration parameters.
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
aggregate_sentences: Deprecated. Use text_aggregation_mode instead.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead. Set to ``TextAggregationMode.SENTENCE``
|
||||
to aggregate text into sentences before synthesis, or
|
||||
``TextAggregationMode.TOKEN`` to stream tokens directly for lower latency.
|
||||
|
||||
text_aggregation_mode: How to aggregate text before synthesis.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
params = params or RimeNonJsonTTSService.InputParams()
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
append_trailing_space=True,
|
||||
|
||||
@@ -63,7 +63,7 @@ from pipecat.frames.frames import (
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.sarvam._sdk import sdk_headers
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import InterruptibleTTSService, TTSService
|
||||
from pipecat.services.tts_service import InterruptibleTTSService, TextAggregationMode, TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -785,7 +785,8 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
model: str = "bulbul:v2",
|
||||
voice_id: Optional[str] = None,
|
||||
url: str = "wss://api.sarvam.ai/text-to-speech/ws",
|
||||
aggregate_sentences: Optional[bool] = True,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
@@ -799,7 +800,12 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
- "bulbul:v3-beta": Advanced model with temperature control
|
||||
voice_id: Speaker voice ID. If None, uses model-appropriate default.
|
||||
url: WebSocket URL for the TTS backend (default production URL).
|
||||
aggregate_sentences: Merge multiple sentences into one audio chunk (default True).
|
||||
aggregate_sentences: Deprecated. Use text_aggregation_mode instead.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead.
|
||||
|
||||
text_aggregation_mode: How to aggregate text before synthesis.
|
||||
sample_rate: Output audio sample rate in Hz (8000, 16000, 22050, 24000).
|
||||
If None, uses model-specific default.
|
||||
params: Optional input parameters to override defaults.
|
||||
@@ -834,6 +840,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
# Initialize parent class first
|
||||
super().__init__(
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
push_text_frames=True,
|
||||
pause_frame_processing=True,
|
||||
push_stop_frames=True,
|
||||
|
||||
@@ -11,6 +11,7 @@ import uuid
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
@@ -72,6 +73,23 @@ class TTSContext:
|
||||
append_to_context: bool = True
|
||||
|
||||
|
||||
class TextAggregationMode(str, Enum):
|
||||
"""Controls how incoming text is aggregated before TTS synthesis.
|
||||
|
||||
Parameters:
|
||||
SENTENCE: Buffer text until sentence boundaries are detected before synthesis.
|
||||
Produces more natural speech but adds latency (~200-300ms per sentence).
|
||||
TOKEN: Stream text tokens directly to TTS as they arrive.
|
||||
Reduces latency but may affect speech quality depending on the TTS provider.
|
||||
"""
|
||||
|
||||
SENTENCE = "sentence"
|
||||
TOKEN = "token"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class TTSService(AIService):
|
||||
"""Base class for text-to-speech services.
|
||||
|
||||
@@ -109,7 +127,8 @@ class TTSService(AIService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aggregate_sentences: bool = True,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
# if True, TTSService will push TextFrames and LLMFullResponseEndFrames,
|
||||
# otherwise subclass must do it
|
||||
push_text_frames: bool = True,
|
||||
@@ -153,7 +172,16 @@ class TTSService(AIService):
|
||||
"""Initialize the TTS service.
|
||||
|
||||
Args:
|
||||
text_aggregation_mode: How to aggregate incoming text before synthesis.
|
||||
TextAggregationMode.SENTENCE (default) buffers until sentence boundaries,
|
||||
TextAggregationMode.TOKEN streams tokens directly for lower latency.
|
||||
aggregate_sentences: Whether to aggregate text into sentences before synthesis.
|
||||
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``text_aggregation_mode`` instead. Set to ``TextAggregationMode.SENTENCE``
|
||||
to aggregate text into sentences before synthesis, or
|
||||
``TextAggregationMode.TOKEN`` to stream tokens directly for lower latency.
|
||||
|
||||
push_text_frames: Whether to push TextFrames and LLMFullResponseEndFrames.
|
||||
push_stop_frames: Whether to automatically push TTSStoppedFrames.
|
||||
stop_frame_timeout_s: Idle time before pushing TTSStoppedFrame when push_stop_frames is True.
|
||||
@@ -194,7 +222,31 @@ class TTSService(AIService):
|
||||
or TTSSettings(),
|
||||
**kwargs,
|
||||
)
|
||||
self._aggregate_sentences: bool = aggregate_sentences
|
||||
|
||||
# Resolve text_aggregation_mode from the new param or deprecated aggregate_sentences
|
||||
if aggregate_sentences is not None:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'aggregate_sentences' is deprecated. "
|
||||
"Use 'text_aggregation_mode=TextAggregationMode.SENTENCE' or "
|
||||
"'text_aggregation_mode=TextAggregationMode.TOKEN' instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if text_aggregation_mode is None:
|
||||
text_aggregation_mode = (
|
||||
TextAggregationMode.SENTENCE
|
||||
if aggregate_sentences
|
||||
else TextAggregationMode.TOKEN
|
||||
)
|
||||
|
||||
if text_aggregation_mode is None:
|
||||
text_aggregation_mode = TextAggregationMode.SENTENCE
|
||||
|
||||
self._text_aggregation_mode: TextAggregationMode = text_aggregation_mode
|
||||
self._push_text_frames: bool = push_text_frames
|
||||
self._push_stop_frames: bool = push_stop_frames
|
||||
self._stop_frame_timeout_s: float = stop_frame_timeout_s
|
||||
@@ -204,7 +256,9 @@ class TTSService(AIService):
|
||||
self._append_trailing_space: bool = append_trailing_space
|
||||
self._init_sample_rate = sample_rate
|
||||
self._sample_rate = 0
|
||||
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
|
||||
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator(
|
||||
aggregation_type=self._text_aggregation_mode
|
||||
)
|
||||
if text_aggregator:
|
||||
import warnings
|
||||
|
||||
@@ -240,6 +294,8 @@ class TTSService(AIService):
|
||||
|
||||
self._processing_text: bool = False
|
||||
self._tts_contexts: Dict[str, TTSContext] = {}
|
||||
self._streamed_text: str = ""
|
||||
self._text_aggregation_metrics_started: bool = False
|
||||
|
||||
# Word timestamp state (active when supports_word_timestamps=True)
|
||||
self._supports_word_timestamps: bool = supports_word_timestamps
|
||||
@@ -253,6 +309,40 @@ class TTSService(AIService):
|
||||
self._register_event_handler("on_connection_error")
|
||||
self._register_event_handler("on_tts_request")
|
||||
|
||||
@property
|
||||
def _is_streaming_tokens(self) -> bool:
|
||||
"""Whether the service is streaming tokens directly without sentence aggregation."""
|
||||
return self._text_aggregation_mode == TextAggregationMode.TOKEN
|
||||
|
||||
async def start_tts_usage_metrics(self, text: str):
|
||||
"""Record TTS usage metrics.
|
||||
|
||||
When streaming tokens, usage metrics are aggregated and reported at
|
||||
flush time instead of per token, so individual calls are skipped.
|
||||
|
||||
Args:
|
||||
text: The text being processed by TTS.
|
||||
"""
|
||||
if self._is_streaming_tokens:
|
||||
return
|
||||
await super().start_tts_usage_metrics(text)
|
||||
|
||||
async def start_text_aggregation_metrics(self):
|
||||
"""Start text aggregation metrics if not already started.
|
||||
|
||||
Only starts the metric once per LLM response. Skipped when streaming
|
||||
tokens since per-token aggregation time is not meaningful.
|
||||
"""
|
||||
if self._is_streaming_tokens or self._text_aggregation_metrics_started:
|
||||
return
|
||||
self._text_aggregation_metrics_started = True
|
||||
await super().start_text_aggregation_metrics()
|
||||
|
||||
async def stop_text_aggregation_metrics(self):
|
||||
"""Stop text aggregation metrics and reset the started flag."""
|
||||
self._text_aggregation_metrics_started = False
|
||||
await super().stop_text_aggregation_metrics()
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
"""Get the current sample rate for audio output.
|
||||
@@ -511,6 +601,7 @@ class TTSService(AIService):
|
||||
and not isinstance(frame, InterimTranscriptionFrame)
|
||||
and not isinstance(frame, TranscriptionFrame)
|
||||
):
|
||||
await self.start_text_aggregation_metrics()
|
||||
await self._process_text_frame(frame)
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruption(frame, direction)
|
||||
@@ -526,9 +617,17 @@ class TTSService(AIService):
|
||||
|
||||
# Flush any remaining text (including text waiting for lookahead)
|
||||
remaining = await self._text_aggregator.flush()
|
||||
# Stop the aggregation metric (no-op if already stopped on first sentence).
|
||||
await self.stop_text_aggregation_metrics()
|
||||
if remaining:
|
||||
await self._push_tts_frames(AggregatedTextFrame(remaining.text, remaining.type))
|
||||
|
||||
# Log accumulated streamed text and emit aggregated usage metric.
|
||||
if self._streamed_text:
|
||||
logger.debug(f"{self}: Generating TTS [{self._streamed_text}]")
|
||||
await super().start_tts_usage_metrics(self._streamed_text)
|
||||
self._streamed_text = ""
|
||||
|
||||
# Reset aggregator state
|
||||
self._processing_text = False
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
@@ -678,6 +777,8 @@ class TTSService(AIService):
|
||||
await filter.handle_interruption()
|
||||
|
||||
self._llm_response_started = False
|
||||
self._streamed_text = ""
|
||||
self._text_aggregation_metrics_started = False
|
||||
if self._supports_word_timestamps:
|
||||
await self.reset_word_timestamps()
|
||||
|
||||
@@ -690,26 +791,18 @@ class TTSService(AIService):
|
||||
await self.resume_processing_frames()
|
||||
|
||||
async def _process_text_frame(self, frame: TextFrame):
|
||||
text: Optional[str] = None
|
||||
includes_inter_frame_spaces: bool = False
|
||||
if not self._aggregate_sentences:
|
||||
text = frame.text
|
||||
includes_inter_frame_spaces = frame.includes_inter_frame_spaces
|
||||
aggregated_by = "token"
|
||||
|
||||
if text:
|
||||
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
|
||||
)
|
||||
else:
|
||||
async for aggregate in self._text_aggregator.aggregate(frame.text):
|
||||
text = aggregate.text
|
||||
aggregated_by = aggregate.type
|
||||
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
|
||||
)
|
||||
async for aggregate in self._text_aggregator.aggregate(frame.text):
|
||||
includes_inter_frame_spaces = (
|
||||
frame.includes_inter_frame_spaces
|
||||
if aggregate.type == AggregationType.TOKEN
|
||||
else False
|
||||
)
|
||||
if aggregate.type != AggregationType.TOKEN:
|
||||
# Stop the aggregation metric on the first sentence only.
|
||||
await self.stop_text_aggregation_metrics()
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(aggregate.text, aggregate.type), includes_inter_frame_spaces
|
||||
)
|
||||
|
||||
async def _push_tts_frames(
|
||||
self,
|
||||
@@ -739,7 +832,15 @@ class TTSService(AIService):
|
||||
# or when we received an LLMFullResponseEndFrame
|
||||
self._processing_text = True
|
||||
|
||||
await self.start_processing_metrics()
|
||||
# Accumulate text for a single debug log at flush time when streaming tokens.
|
||||
if self._is_streaming_tokens:
|
||||
self._streamed_text += text
|
||||
|
||||
# Skip per-token processing metrics when streaming. The per-token
|
||||
# processing time is just websocket send overhead (~0.1ms) and not
|
||||
# meaningful. TTFB captures the important timing for streaming TTS.
|
||||
if not self._is_streaming_tokens:
|
||||
await self.start_processing_metrics()
|
||||
|
||||
# Process all filters.
|
||||
for filter in self._text_filters:
|
||||
@@ -747,7 +848,8 @@ class TTSService(AIService):
|
||||
text = await filter.filter(text)
|
||||
|
||||
if not text.strip():
|
||||
await self.stop_processing_metrics()
|
||||
if not self._is_streaming_tokens:
|
||||
await self.stop_processing_metrics()
|
||||
return
|
||||
|
||||
# Create context ID and store metadata
|
||||
@@ -785,7 +887,8 @@ class TTSService(AIService):
|
||||
|
||||
await self.process_generator(self.run_tts(prepared_text, context_id))
|
||||
|
||||
await self.stop_processing_metrics()
|
||||
if not self._is_streaming_tokens:
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
if self._push_text_frames:
|
||||
# In TTS services that support word timestamps, the TTSTextFrames
|
||||
|
||||
@@ -21,6 +21,7 @@ class AggregationType(str, Enum):
|
||||
"""Built-in aggregation strings."""
|
||||
|
||||
SENTENCE = "sentence"
|
||||
TOKEN = "token"
|
||||
WORD = "word"
|
||||
|
||||
def __str__(self):
|
||||
@@ -66,6 +67,25 @@ class BaseTextAggregator(ABC):
|
||||
logic, text manipulation behavior, and state management for interruptions.
|
||||
"""
|
||||
|
||||
def __init__(self, *, aggregation_type: AggregationType = AggregationType.SENTENCE):
|
||||
"""Initialize the base text aggregator.
|
||||
|
||||
Args:
|
||||
aggregation_type: The aggregation strategy to use. SENTENCE buffers
|
||||
text until sentence boundaries are detected, TOKEN passes text
|
||||
through immediately, and WORD buffers until word boundaries.
|
||||
"""
|
||||
self._aggregation_type = AggregationType(aggregation_type)
|
||||
|
||||
@property
|
||||
def aggregation_type(self) -> AggregationType:
|
||||
"""Get the aggregation type for this aggregator.
|
||||
|
||||
Returns:
|
||||
The aggregation type.
|
||||
"""
|
||||
return self._aggregation_type
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def text(self) -> Aggregation:
|
||||
|
||||
@@ -96,8 +96,11 @@ class PatternPairAggregator(SimpleTextAggregator):
|
||||
|
||||
Creates an empty aggregator with no patterns or handlers registered.
|
||||
Text buffering and pattern detection will begin when text is aggregated.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional arguments passed to SimpleTextAggregator (e.g. aggregation_type).
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self._patterns = {}
|
||||
self._handlers = {}
|
||||
self._last_processed_position = 0 # Track where we last checked for complete patterns
|
||||
@@ -146,7 +149,7 @@ class PatternPairAggregator(SimpleTextAggregator):
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
"""
|
||||
if type in [AggregationType.SENTENCE, AggregationType.WORD]:
|
||||
if type in [AggregationType.SENTENCE, AggregationType.WORD, AggregationType.TOKEN]:
|
||||
raise ValueError(
|
||||
f"The aggregation type '{type}' is reserved for default behavior and can not be used for custom patterns."
|
||||
)
|
||||
@@ -321,6 +324,9 @@ class PatternPairAggregator(SimpleTextAggregator):
|
||||
and uses the parent's lookahead logic for sentence detection when no
|
||||
patterns are active.
|
||||
|
||||
In TOKEN mode, pattern detection still works but non-pattern text is
|
||||
yielded as TOKEN aggregations instead of waiting for sentence boundaries.
|
||||
|
||||
Args:
|
||||
text: Text to aggregate.
|
||||
|
||||
@@ -370,18 +376,35 @@ class PatternPairAggregator(SimpleTextAggregator):
|
||||
# boundaries when a pattern begins (e.g., "Here is code <code>..." yields "Here is code")
|
||||
result = self._text[: pattern_start[0]]
|
||||
self._text = self._text[pattern_start[0] :]
|
||||
yield PatternMatch(
|
||||
content=result.strip(), type=AggregationType.SENTENCE, full_match=result
|
||||
agg_type = (
|
||||
AggregationType.TOKEN
|
||||
if self._aggregation_type == AggregationType.TOKEN
|
||||
else AggregationType.SENTENCE
|
||||
)
|
||||
yield PatternMatch(content=result.strip(), type=agg_type, full_match=result)
|
||||
continue
|
||||
|
||||
# Use parent's lookahead logic for sentence detection
|
||||
aggregation = await super()._check_sentence_with_lookahead(char)
|
||||
if aggregation:
|
||||
# Convert to PatternMatch for consistency with return type
|
||||
if self._aggregation_type != AggregationType.TOKEN:
|
||||
# Use parent's lookahead logic for sentence detection
|
||||
aggregation = await super()._check_sentence_with_lookahead(char)
|
||||
if aggregation:
|
||||
# Convert to PatternMatch for consistency with return type
|
||||
yield PatternMatch(
|
||||
content=aggregation.text,
|
||||
type=aggregation.type,
|
||||
full_match=aggregation.text,
|
||||
)
|
||||
|
||||
# In TOKEN mode, yield any accumulated text after processing all chars,
|
||||
# but only if there's no incomplete pattern being buffered.
|
||||
if self._aggregation_type == AggregationType.TOKEN and self._text:
|
||||
if self._match_start_of_pattern(self._text) is None:
|
||||
yield PatternMatch(
|
||||
content=aggregation.text, type=aggregation.type, full_match=aggregation.text
|
||||
content=self._text,
|
||||
type=AggregationType.TOKEN,
|
||||
full_match=self._text,
|
||||
)
|
||||
self._text = ""
|
||||
|
||||
async def handle_interruption(self):
|
||||
"""Handle interruptions by clearing the buffer and pattern state.
|
||||
|
||||
@@ -25,11 +25,15 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
most straightforward implementation of text aggregation for TTS processing.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the simple text aggregator.
|
||||
|
||||
Creates an empty text buffer ready to begin accumulating text tokens.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional arguments passed to BaseTextAggregator (e.g. aggregation_type).
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._text = ""
|
||||
self._needs_lookahead: bool = False
|
||||
|
||||
@@ -43,19 +47,25 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
return Aggregation(text=self._text.strip(" "), type=AggregationType.SENTENCE)
|
||||
|
||||
async def aggregate(self, text: str) -> AsyncIterator[Aggregation]:
|
||||
"""Aggregate text and yield completed sentences.
|
||||
"""Aggregate text and yield completed aggregations.
|
||||
|
||||
Processes the input text character-by-character. When sentence-ending
|
||||
punctuation is detected, it waits for non-whitespace lookahead before
|
||||
calling NLTK. This prevents false positives like "$29." being detected
|
||||
as a sentence when it's actually "$29.95".
|
||||
In SENTENCE mode, processes the input text character-by-character. When
|
||||
sentence-ending punctuation is detected, it waits for non-whitespace
|
||||
lookahead before calling NLTK.
|
||||
|
||||
In TOKEN mode, yields the text immediately without buffering.
|
||||
|
||||
Args:
|
||||
text: Text to aggregate.
|
||||
|
||||
Yields:
|
||||
Complete sentences as Aggregation objects.
|
||||
Aggregation objects (sentences in SENTENCE mode, tokens in TOKEN mode).
|
||||
"""
|
||||
if self._aggregation_type == AggregationType.TOKEN:
|
||||
if text:
|
||||
yield Aggregation(text=text, type=AggregationType.TOKEN)
|
||||
return
|
||||
|
||||
# Process text character by character
|
||||
for char in text:
|
||||
self._text += char
|
||||
@@ -114,11 +124,15 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
"""Flush any remaining text in the buffer.
|
||||
|
||||
Returns any text remaining in the buffer. This is called at the end
|
||||
of a stream to ensure all text is processed.
|
||||
of a stream to ensure all text is processed. In TOKEN mode, returns
|
||||
None since tokens are yielded immediately.
|
||||
|
||||
Returns:
|
||||
Any remaining text as a sentence, or None if buffer is empty.
|
||||
Any remaining text as a sentence, or None if buffer is empty or in TOKEN mode.
|
||||
"""
|
||||
if self._aggregation_type == AggregationType.TOKEN:
|
||||
return None
|
||||
|
||||
if self._text:
|
||||
# Return whatever we have in the buffer
|
||||
result = self._text
|
||||
|
||||
@@ -14,7 +14,7 @@ as a unit regardless of internal punctuation.
|
||||
from typing import AsyncIterator, Optional, Sequence
|
||||
|
||||
from pipecat.utils.string import StartEndTags, parse_start_end_tags
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType
|
||||
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
|
||||
|
||||
|
||||
@@ -31,14 +31,15 @@ class SkipTagsAggregator(SimpleTextAggregator):
|
||||
identified and that content within tags is never split at sentence boundaries.
|
||||
"""
|
||||
|
||||
def __init__(self, tags: Sequence[StartEndTags]):
|
||||
def __init__(self, tags: Sequence[StartEndTags], **kwargs):
|
||||
"""Initialize the skip tags aggregator.
|
||||
|
||||
Args:
|
||||
tags: Sequence of StartEndTags objects defining the tag pairs
|
||||
that should prevent sentence boundary detection.
|
||||
**kwargs: Additional arguments passed to SimpleTextAggregator (e.g. aggregation_type).
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self._tags = tags
|
||||
self._current_tag: Optional[StartEndTags] = None
|
||||
self._current_tag_index: int = 0
|
||||
@@ -50,13 +51,33 @@ class SkipTagsAggregator(SimpleTextAggregator):
|
||||
uses the parent's lookahead logic for sentence detection when not
|
||||
inside tags.
|
||||
|
||||
In TOKEN mode, text is passed through immediately unless we're inside
|
||||
a tag, in which case we buffer until the closing tag is found.
|
||||
|
||||
Args:
|
||||
text: Text to aggregate.
|
||||
|
||||
Yields:
|
||||
Aggregation objects containing text up to a sentence boundary,
|
||||
marked as SENTENCE type.
|
||||
marked as SENTENCE type (or TOKEN type in TOKEN mode).
|
||||
"""
|
||||
if self._aggregation_type == AggregationType.TOKEN:
|
||||
# In TOKEN mode, process chars for tag tracking but yield the
|
||||
# full input as a single token when not inside a tag.
|
||||
for char in text:
|
||||
self._text += char
|
||||
|
||||
# Update tag state
|
||||
(self._current_tag, self._current_tag_index) = parse_start_end_tags(
|
||||
self._text, self._tags, self._current_tag, self._current_tag_index
|
||||
)
|
||||
|
||||
# After processing all chars: if not inside a tag, yield accumulated text
|
||||
if not self._current_tag and self._text:
|
||||
yield Aggregation(text=self._text, type=AggregationType.TOKEN)
|
||||
self._text = ""
|
||||
return
|
||||
|
||||
# Process text character by character
|
||||
for char in text:
|
||||
self._text += char
|
||||
|
||||
@@ -194,5 +194,66 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
|
||||
class TestPatternPairAggregatorTokenMode(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
from pipecat.utils.text.base_text_aggregator import AggregationType
|
||||
|
||||
self.aggregator = PatternPairAggregator(aggregation_type=AggregationType.TOKEN)
|
||||
self.handler = AsyncMock()
|
||||
self.aggregator.add_pattern(
|
||||
type="think",
|
||||
start_pattern="<think>",
|
||||
end_pattern="</think>",
|
||||
action=MatchAction.REMOVE,
|
||||
)
|
||||
self.aggregator.on_pattern_match("think", self.handler)
|
||||
|
||||
async def test_token_no_patterns(self):
|
||||
"""Non-pattern text passes through as TOKEN, one per aggregate call."""
|
||||
results = []
|
||||
for token in ["Hello", " world", "."]:
|
||||
async for r in self.aggregator.aggregate(token):
|
||||
results.append(r)
|
||||
|
||||
self.assertEqual(len(results), 3)
|
||||
self.assertEqual(results[0].text, "Hello")
|
||||
self.assertEqual(results[1].text, " world")
|
||||
self.assertEqual(results[2].text, ".")
|
||||
for r in results:
|
||||
self.assertEqual(r.type, "token")
|
||||
|
||||
async def test_token_pattern_detection(self):
|
||||
"""Pattern detection still works with word-by-word token delivery."""
|
||||
results = []
|
||||
for token in ["Hi ", "<think>", "secret", "</think>", " bye"]:
|
||||
async for r in self.aggregator.aggregate(token):
|
||||
results.append(r)
|
||||
|
||||
# Handler called once when the pattern completes
|
||||
self.handler.assert_called_once()
|
||||
call_args = self.handler.call_args[0][0]
|
||||
self.assertEqual(call_args.text, "secret")
|
||||
|
||||
# "Hi " yields before pattern starts, pattern is removed, " bye" yields after
|
||||
self.assertEqual(len(results), 2)
|
||||
self.assertEqual(results[0].text, "Hi ")
|
||||
self.assertEqual(results[0].type, "token")
|
||||
self.assertEqual(results[1].text, " bye")
|
||||
self.assertEqual(results[1].type, "token")
|
||||
|
||||
async def test_token_incomplete_pattern_buffers(self):
|
||||
"""Incomplete pattern is buffered across calls, not leaked to output."""
|
||||
results = []
|
||||
for token in ["Hi ", "<think>", "partial"]:
|
||||
async for r in self.aggregator.aggregate(token):
|
||||
results.append(r)
|
||||
|
||||
# Only "Hi " should be yielded; "<think>partial" stays buffered
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0].text, "Hi ")
|
||||
self.assertEqual(results[0].type, "token")
|
||||
self.handler.assert_not_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -181,5 +181,39 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
assert result.text == "こんにちは。"
|
||||
|
||||
|
||||
class TestSimpleTextAggregatorTokenMode(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
from pipecat.utils.text.base_text_aggregator import AggregationType
|
||||
|
||||
self.aggregator = SimpleTextAggregator(aggregation_type=AggregationType.TOKEN)
|
||||
|
||||
async def test_token_passthrough(self):
|
||||
"""TOKEN mode yields text immediately without buffering."""
|
||||
results = [agg async for agg in self.aggregator.aggregate("Hello")]
|
||||
assert len(results) == 1
|
||||
assert results[0].text == "Hello"
|
||||
assert results[0].type == "token"
|
||||
|
||||
async def test_token_multiple_calls(self):
|
||||
"""Each aggregate call yields its text independently."""
|
||||
r1 = [agg async for agg in self.aggregator.aggregate("Hello ")]
|
||||
r2 = [agg async for agg in self.aggregator.aggregate("world.")]
|
||||
assert len(r1) == 1
|
||||
assert r1[0].text == "Hello "
|
||||
assert len(r2) == 1
|
||||
assert r2[0].text == "world."
|
||||
|
||||
async def test_token_empty_text(self):
|
||||
"""Empty text yields nothing."""
|
||||
results = [agg async for agg in self.aggregator.aggregate("")]
|
||||
assert len(results) == 0
|
||||
|
||||
async def test_token_flush_returns_none(self):
|
||||
"""Flush returns None in TOKEN mode since nothing is buffered."""
|
||||
await self.aggregator.aggregate("Hello").__anext__()
|
||||
result = await self.aggregator.flush()
|
||||
assert result is None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -64,5 +64,60 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(self.aggregator.text.type, "sentence")
|
||||
|
||||
|
||||
class TestSkipTagsAggregatorTokenMode(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
from pipecat.utils.text.base_text_aggregator import AggregationType
|
||||
|
||||
self.aggregator = SkipTagsAggregator(
|
||||
[("<spell>", "</spell>")], aggregation_type=AggregationType.TOKEN
|
||||
)
|
||||
|
||||
async def test_token_no_tags(self):
|
||||
"""No tags: text passes through immediately as TOKEN."""
|
||||
results = [agg async for agg in self.aggregator.aggregate("Hello!")]
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0].text, "Hello!")
|
||||
self.assertEqual(results[0].type, "token")
|
||||
|
||||
async def test_token_inside_tag_buffers(self):
|
||||
"""Inside a tag, text is buffered until the closing tag is found."""
|
||||
results = [agg async for agg in self.aggregator.aggregate("<spell>foo@bar")]
|
||||
# Still inside tag, nothing yielded
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
# Close the tag
|
||||
results = [agg async for agg in self.aggregator.aggregate("</spell>")]
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0].text, "<spell>foo@bar</spell>")
|
||||
self.assertEqual(results[0].type, "token")
|
||||
|
||||
async def test_token_flush_unclosed_tag(self):
|
||||
"""Flush with unclosed tag returns remaining text."""
|
||||
async for _ in self.aggregator.aggregate("<spell>unclosed"):
|
||||
pass
|
||||
result = await self.aggregator.flush()
|
||||
# TOKEN mode flush returns None (parent behavior)
|
||||
self.assertIsNone(result)
|
||||
|
||||
async def test_token_text_around_tags(self):
|
||||
"""Simulate word-by-word token delivery with tags."""
|
||||
results = []
|
||||
# Simulate LLM streaming tokens one at a time
|
||||
for token in ["Hi ", "<spell>", "X", "</spell>", " bye"]:
|
||||
async for agg in self.aggregator.aggregate(token):
|
||||
results.append(agg)
|
||||
|
||||
self.assertEqual(len(results), 3)
|
||||
# Text before tag passes through immediately
|
||||
self.assertEqual(results[0].text, "Hi ")
|
||||
self.assertEqual(results[0].type, "token")
|
||||
# Tagged content is buffered until the closing tag, then yielded whole
|
||||
self.assertEqual(results[1].text, "<spell>X</spell>")
|
||||
self.assertEqual(results[1].type, "token")
|
||||
# Text after tag passes through immediately
|
||||
self.assertEqual(results[2].text, " bye")
|
||||
self.assertEqual(results[2].type, "token")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user