Refactor the services that used WordTTSService to use TTSService with supports_word_timestamps=True.

This commit is contained in:
filipi87
2026-02-24 15:48:46 -03:00
parent 081aaa50dc
commit 323477bfa4
9 changed files with 102 additions and 138 deletions

View File

@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.azure.common import language_to_azure_language
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
from pipecat.services.tts_service import TTSService, WordTTSService
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -258,7 +258,7 @@ class AzureBaseTTSService:
return escaped_text
class AzureTTSService(WordTTSService, AzureBaseTTSService):
class AzureTTSService(TTSService, AzureBaseTTSService):
"""Azure Cognitive Services streaming TTS service with word timestamps.
Provides real-time text-to-speech synthesis using Azure's WebSocket-based
@@ -286,14 +286,14 @@ class AzureTTSService(WordTTSService, AzureBaseTTSService):
sample_rate: Audio sample rate in Hz. If None, uses service default.
params: Voice and synthesis parameters configuration.
aggregate_sentences: Whether to aggregate sentences before synthesis.
**kwargs: Additional arguments passed to parent WordTTSService.
**kwargs: Additional arguments passed to the parent TTSService.
"""
# Initialize WordTTSService first to set up word timestamp tracking
super().__init__(
aggregate_sentences=aggregate_sentences,
push_text_frames=False, # We'll push text frames based on word timestamps
push_stop_frames=True,
pause_frame_processing=True,
supports_word_timestamps=True,
sample_rate=sample_rate,
**kwargs,
)

View File

@@ -29,7 +29,7 @@ from pipecat.frames.frames import (
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, is_given
from pipecat.services.tts_service import AudioContextWordTTSService, TTSService
from pipecat.services.tts_service import AudioContextTTSService, TTSService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator
@@ -229,7 +229,7 @@ class CartesiaTTSSettings(TTSSettings):
return super().from_mapping(flat)
class CartesiaTTSService(AudioContextWordTTSService):
class CartesiaTTSService(AudioContextTTSService):
"""Cartesia TTS service with WebSocket streaming and word timestamps.
Provides text-to-speech using Cartesia's streaming WebSocket API.
@@ -311,6 +311,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
aggregate_sentences=aggregate_sentences,
push_text_frames=False,
pause_frame_processing=True,
supports_word_timestamps=True,
sample_rate=sample_rate,
text_aggregator=text_aggregator,
**kwargs,

View File

@@ -46,8 +46,8 @@ from pipecat.frames.frames import (
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, is_given
from pipecat.services.tts_service import (
AudioContextWordTTSService,
WordTTSService,
AudioContextTTSService,
TTSService,
)
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -317,7 +317,7 @@ def calculate_word_times(
return (word_times, new_partial_word, new_partial_word_start_time)
class ElevenLabsTTSService(AudioContextWordTTSService):
class ElevenLabsTTSService(AudioContextTTSService):
"""ElevenLabs WebSocket-based TTS service with word timestamps.
Provides real-time text-to-speech using ElevenLabs' WebSocket streaming API.
@@ -399,6 +399,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
push_text_frames=False,
push_stop_frames=True,
pause_frame_processing=True,
supports_word_timestamps=True,
sample_rate=sample_rate,
**kwargs,
)
@@ -838,7 +839,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
yield ErrorFrame(error=f"Unknown error occurred: {e}")
class ElevenLabsHttpTTSService(WordTTSService):
class ElevenLabsHttpTTSService(TTSService):
"""ElevenLabs HTTP-based TTS service with word timestamps.
Provides text-to-speech using ElevenLabs' HTTP streaming API for simpler,
@@ -903,6 +904,7 @@ class ElevenLabsHttpTTSService(WordTTSService):
aggregate_sentences=aggregate_sentences,
push_text_frames=False,
push_stop_frames=True,
supports_word_timestamps=True,
sample_rate=sample_rate,
**kwargs,
)

View File

@@ -25,7 +25,7 @@ from pipecat.frames.frames import (
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
from pipecat.services.tts_service import AudioContextWordTTSService
from pipecat.services.tts_service import AudioContextTTSService
from pipecat.utils.tracing.service_decorators import traced_tts
try:
@@ -51,7 +51,7 @@ class GradiumTTSSettings(TTSSettings):
output_format: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
class GradiumTTSService(AudioContextWordTTSService):
class GradiumTTSService(AudioContextTTSService):
"""Text-to-Speech service using Gradium's websocket API."""
_settings: GradiumTTSSettings
@@ -91,6 +91,7 @@ class GradiumTTSService(AudioContextWordTTSService):
push_stop_frames=True,
push_text_frames=False,
pause_frame_processing=True,
supports_word_timestamps=True,
sample_rate=SAMPLE_RATE,
**kwargs,
)

View File

@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
from pipecat.services.tts_service import WordTTSService
from pipecat.services.tts_service import TTSService
from pipecat.utils.tracing.service_decorators import traced_tts
try:
@@ -64,7 +64,7 @@ class HumeTTSSettings(TTSSettings):
trailing_silence: float | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
class HumeTTSService(WordTTSService):
class HumeTTSService(TTSService):
"""Hume Octave Text-to-Speech service.
Streams PCM audio via Hume's HTTP output streaming (JSON chunks) endpoint
@@ -121,11 +121,11 @@ class HumeTTSService(WordTTSService):
f"Hume TTS streams at {HUME_SAMPLE_RATE} Hz; configured sample_rate={sample_rate}"
)
# WordTTSService sets push_text_frames=False by default, which we want
super().__init__(
sample_rate=sample_rate,
push_text_frames=False,
push_stop_frames=True,
supports_word_timestamps=True,
**kwargs,
)

View File

@@ -51,7 +51,7 @@ from pipecat.frames.frames import (
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import AudioContextWordTTSService, WordTTSService
from pipecat.services.tts_service import AudioContextTTSService, TTSService
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -102,7 +102,7 @@ class InworldTTSSettings(TTSSettings):
return super().from_mapping(flat)
class InworldHttpTTSService(WordTTSService):
class InworldHttpTTSService(TTSService):
"""Inworld AI HTTP-based TTS service.
Supports both streaming and non-streaming modes via the `streaming` parameter.
@@ -153,6 +153,7 @@ class InworldHttpTTSService(WordTTSService):
super().__init__(
push_text_frames=False,
push_stop_frames=True,
supports_word_timestamps=True,
sample_rate=sample_rate,
**kwargs,
)
@@ -467,7 +468,7 @@ class InworldHttpTTSService(WordTTSService):
)
class InworldTTSService(AudioContextWordTTSService):
class InworldTTSService(AudioContextTTSService):
"""Inworld AI WebSocket-based TTS service.
Uses bidirectional WebSocket for lower latency streaming. Supports multiple
@@ -534,6 +535,7 @@ class InworldTTSService(AudioContextWordTTSService):
push_text_frames=False,
push_stop_frames=True,
pause_frame_processing=True,
supports_word_timestamps=True,
sample_rate=sample_rate,
aggregate_sentences=aggregate_sentences,
append_trailing_space=append_trailing_space,

View File

@@ -26,7 +26,7 @@ from pipecat.frames.frames import (
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
from pipecat.services.tts_service import AudioContextWordTTSService
from pipecat.services.tts_service import AudioContextTTSService
from pipecat.utils.tracing.service_decorators import traced_tts
try:
@@ -58,7 +58,7 @@ class ResembleAITTSSettings(TTSSettings):
}
class ResembleAITTSService(AudioContextWordTTSService):
class ResembleAITTSService(AudioContextTTSService):
"""Resemble AI TTS service with WebSocket streaming and word timestamps.
Provides text-to-speech using Resemble AI's streaming WebSocket API.
@@ -93,6 +93,7 @@ class ResembleAITTSService(AudioContextWordTTSService):
super().__init__(
sample_rate=sample_rate,
reuse_context_id_within_turn=False,
supports_word_timestamps=True,
**kwargs,
)

View File

@@ -33,7 +33,7 @@ from pipecat.frames.frames import (
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, is_given
from pipecat.services.tts_service import (
AudioContextWordTTSService,
AudioContextTTSService,
InterruptibleTTSService,
TTSService,
)
@@ -130,7 +130,7 @@ class RimeNonJsonTTSSettings(TTSSettings):
_aliases: ClassVar[Dict[str, str]] = {"speaker": "voice"}
class RimeTTSService(AudioContextWordTTSService):
class RimeTTSService(AudioContextTTSService):
"""Text-to-Speech service using Rime's websocket API.
Uses Rime's websocket JSON API to convert text to speech with word-level timing
@@ -207,6 +207,7 @@ class RimeTTSService(AudioContextWordTTSService):
push_text_frames=False,
push_stop_frames=True,
pause_frame_processing=True,
supports_word_timestamps=True,
append_trailing_space=True,
sample_rate=sample_rate,
**kwargs,

View File

@@ -128,6 +128,8 @@ class TTSService(AIService):
append_trailing_space: bool = False,
# TTS output sample rate
sample_rate: Optional[int] = None,
# If True, enables word-level timestamp tracking and synchronization.
supports_word_timestamps: bool = False,
# Text aggregator to aggregate incoming tokens and decide when to push to the TTS.
text_aggregator: Optional[BaseTextAggregator] = None,
# Types of text aggregations that should not be spoken.
@@ -160,6 +162,9 @@ class TTSService(AIService):
append_trailing_space: Whether to append a trailing space to text before sending to TTS.
This helps prevent some TTS services from vocalizing trailing punctuation (e.g., "dot").
sample_rate: Output sample rate for generated audio.
supports_word_timestamps: Whether this service supports word-level timestamp tracking.
When True, enables synchronization of audio with spoken words so only spoken words
are added to the conversation context.
text_aggregator: Custom text aggregator for processing incoming text.
.. deprecated:: 0.0.95
@@ -231,6 +236,13 @@ class TTSService(AIService):
self._processing_text: bool = False
self._tts_contexts: Dict[str, TTSContext] = {}
# Word timestamp state (active when supports_word_timestamps=True)
self._supports_word_timestamps: bool = supports_word_timestamps
self._initial_word_timestamp: int = -1
self._initial_word_times: List[Tuple[str, float, Optional[str]]] = []
self._words_task: Optional[asyncio.Task] = None
self._llm_response_started: bool = False
self._register_event_handler("on_connected")
self._register_event_handler("on_disconnected")
self._register_event_handler("on_connection_error")
@@ -366,6 +378,8 @@ class TTSService(AIService):
self._sample_rate = self._init_sample_rate or frame.audio_out_sample_rate
if self._push_stop_frames and not self._stop_frame_task:
self._stop_frame_task = self.create_task(self._stop_frame_handler())
if self._supports_word_timestamps:
self._create_words_task()
async def stop(self, frame: EndFrame):
"""Stop the TTS service.
@@ -377,6 +391,8 @@ class TTSService(AIService):
if self._stop_frame_task:
await self.cancel_task(self._stop_frame_task)
self._stop_frame_task = None
if self._words_task:
await self._stop_words_task()
async def cancel(self, frame: CancelFrame):
"""Cancel the TTS service.
@@ -388,6 +404,8 @@ class TTSService(AIService):
if self._stop_frame_task:
await self.cancel_task(self._stop_frame_task)
self._stop_frame_task = None
if self._words_task:
await self._stop_words_task()
def add_text_transformer(
self,
@@ -492,6 +510,9 @@ class TTSService(AIService):
elif isinstance(frame, InterruptionFrame):
await self._handle_interruption(frame, direction)
await self.push_frame(frame, direction)
elif isinstance(frame, LLMFullResponseStartFrame):
self._llm_response_started = True
await self.push_frame(frame, direction)
elif isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
# We pause processing incoming frames if the LLM response included
# text (it might be that it's only a function calling response). We
@@ -510,6 +531,9 @@ class TTSService(AIService):
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)
# Flush any pending audio so the TTS service closes the current context.
if self._supports_word_timestamps:
await self.flush_audio()
elif isinstance(frame, TTSSpeakFrame):
# Store if we were processing text or not so we can set it back.
processing_text = self._processing_text
@@ -648,6 +672,10 @@ class TTSService(AIService):
for filter in self._text_filters:
await filter.handle_interruption()
self._llm_response_started = False
if self._supports_word_timestamps:
await self.reset_word_timestamps()
async def _maybe_pause_frame_processing(self):
if self._processing_text and self._pause_frame_processing:
await self.pause_processing_frames()
@@ -786,25 +814,9 @@ class TTSService(AIService):
await self.push_frame(TTSStoppedFrame())
has_started = False
class WordTTSService(TTSService):
"""Base class for TTS services that support word timestamps.
Word timestamps are useful to synchronize audio with text of the spoken
words. This way only the spoken words are added to the conversation context.
"""
def __init__(self, **kwargs):
"""Initialize the Word TTS service.
Args:
**kwargs: Additional arguments passed to the parent TTSService.
"""
super().__init__(**kwargs)
self._initial_word_timestamp = -1
self._initial_word_times = []
self._words_task = None
self._llm_response_started: bool = False
#
# Word timestamp methods (active when supports_word_timestamps=True)
#
async def start_word_timestamps(self):
"""Start tracking word timestamps from the current time."""
@@ -839,55 +851,9 @@ class WordTTSService(TTSService):
else:
await self._add_word_timestamps(word_times_with_context)
async def start(self, frame: StartFrame):
"""Start the word TTS service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
self._create_words_task()
async def stop(self, frame: EndFrame):
"""Stop the word TTS service.
Args:
frame: The end frame.
"""
await super().stop(frame)
await self._stop_words_task()
async def cancel(self, frame: CancelFrame):
"""Cancel the word TTS service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
await self._stop_words_task()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames with word timestamp awareness.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, LLMFullResponseStartFrame):
self._llm_response_started = True
elif isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
await self.flush_audio()
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
self._llm_response_started = False
await self.reset_word_timestamps()
def _create_words_task(self):
if not self._words_task:
self._words_queue = asyncio.Queue()
self._words_queue: asyncio.Queue = asyncio.Queue()
self._words_task = self.create_task(self._words_task_handler())
async def _stop_words_task(self):
@@ -929,6 +895,23 @@ class WordTTSService(TTSService):
self._words_queue.task_done()
class WordTTSService(TTSService):
    """Deprecated. Use TTSService with supports_word_timestamps=True instead.

    This class is kept only as a backward-compatible shim: it forwards every
    keyword argument to TTSService and always passes
    ``supports_word_timestamps=True``.

    .. deprecated:: 0.0.104
        Word timestamp functionality has been moved to TTSService. Pass
        ``supports_word_timestamps=True`` to TTSService (or any subclass) instead.
    """

    def __init__(self, **kwargs):
        """Initialize the Word TTS service.

        Emits a DeprecationWarning, then delegates to TTSService with word
        timestamp support enabled.

        Args:
            **kwargs: Additional arguments passed to the parent TTSService.
        """
        # Imported locally so this shim does not rely on a module-level import.
        import warnings

        # stacklevel=2 attributes the warning to the code instantiating (or
        # subclassing) the deprecated class rather than to this constructor.
        warnings.warn(
            "WordTTSService is deprecated since 0.0.104; use TTSService with "
            "supports_word_timestamps=True instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(supports_word_timestamps=True, **kwargs)
class WebsocketTTSService(TTSService, WebsocketService):
"""Base class for websocket-based TTS services.
@@ -1001,10 +984,12 @@ class InterruptibleTTSService(WebsocketTTSService):
self._bot_speaking = False
class WebsocketWordTTSService(WordTTSService, WebsocketService):
"""Base class for websocket-based TTS services that support word timestamps.
class WebsocketWordTTSService(WebsocketTTSService):
"""Deprecated. Use WebsocketTTSService with supports_word_timestamps=True instead.
Combines word timestamp functionality with websocket connectivity.
.. deprecated:: 0.0.104
Word timestamp functionality has been moved to TTSService. Pass
``supports_word_timestamps=True`` to WebsocketTTSService instead.
"""
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
@@ -1014,53 +999,26 @@ class WebsocketWordTTSService(WordTTSService, WebsocketService):
reconnect_on_error: Whether to automatically reconnect on websocket errors.
**kwargs: Additional arguments passed to parent classes.
"""
WordTTSService.__init__(self, **kwargs)
WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
async def _report_error(self, error: ErrorFrame):
await self._call_event_handler("on_connection_error", error.error)
await self.push_error_frame(error)
super().__init__(
supports_word_timestamps=True, reconnect_on_error=reconnect_on_error, **kwargs
)
class InterruptibleWordTTSService(WebsocketWordTTSService):
"""Websocket-based TTS service with word timestamps that handles interruptions.
class InterruptibleWordTTSService(InterruptibleTTSService):
"""Deprecated. Use InterruptibleTTSService with supports_word_timestamps=True instead.
For TTS services that support word timestamps but can't correlate generated
audio with requested text. Handles interruptions by reconnecting when needed.
.. deprecated:: 0.0.104
Word timestamp functionality has been moved to TTSService. Pass
``supports_word_timestamps=True`` to InterruptibleTTSService instead.
"""
def __init__(self, **kwargs):
"""Initialize the Interruptible Word TTS service.
Args:
**kwargs: Additional arguments passed to the parent WebsocketWordTTSService.
**kwargs: Additional arguments passed to the parent InterruptibleTTSService.
"""
super().__init__(**kwargs)
# Indicates if the bot is speaking. If the bot is not speaking we don't
# need to reconnect when the user speaks. If the bot is speaking and the
# user interrupts we need to reconnect.
self._bot_speaking = False
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
if self._bot_speaking:
await self._disconnect()
await self._connect()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames with bot speaking state tracking.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, BotStartedSpeakingFrame):
self._bot_speaking = True
elif isinstance(frame, BotStoppedSpeakingFrame):
self._bot_speaking = False
super().__init__(supports_word_timestamps=True, **kwargs)
class AudioContextTTSService(WebsocketTTSService):
@@ -1299,15 +1257,12 @@ class AudioContextTTSService(WebsocketTTSService):
break
class AudioContextWordTTSService(AudioContextTTSService, WebsocketWordTTSService):
"""Websocket-based TTS service with word timestamps and audio context management.
class AudioContextWordTTSService(AudioContextTTSService):
"""Deprecated. Use AudioContextTTSService with supports_word_timestamps=True instead.
This is a base class for websocket-based TTS services that support word
timestamps and also allow correlating the generated audio with the requested
text through audio contexts.
Combines the audio context management capabilities of AudioContextTTSService
with the word timestamp functionality of WebsocketWordTTSService.
.. deprecated:: 0.0.104
Word timestamp functionality has been moved to TTSService. Pass
``supports_word_timestamps=True`` to AudioContextTTSService instead.
"""
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
@@ -1317,5 +1272,6 @@ class AudioContextWordTTSService(AudioContextTTSService, WebsocketWordTTSService
reconnect_on_error: Whether to automatically reconnect on websocket errors.
**kwargs: Additional arguments passed to parent classes.
"""
AudioContextTTSService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
WebsocketWordTTSService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
super().__init__(
supports_word_timestamps=True, reconnect_on_error=reconnect_on_error, **kwargs
)