Refactored all 25+ TTS service implementations to use the new push_start_frame=True pattern

This commit is contained in:
filipi87
2026-03-06 16:15:59 -03:00
parent 24430d8d45
commit 88ff7c451b
26 changed files with 182 additions and 382 deletions

View File

@@ -23,12 +23,11 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService
from pipecat.services.tts_service import TextAggregationMode, TTSService, WebsocketTTSService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -80,7 +79,7 @@ class AsyncAITTSSettings(TTSSettings):
pass
class AsyncAITTSService(AudioContextTTSService):
class AsyncAITTSService(WebsocketTTSService):
"""Async TTS service with WebSocket streaming.
Provides text-to-speech using Async's streaming WebSocket API.
@@ -183,8 +182,9 @@ class AsyncAITTSService(AudioContextTTSService):
aggregate_sentences=aggregate_sentences,
text_aggregation_mode=text_aggregation_mode,
pause_frame_processing=True,
push_stop_frames=True,
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -340,13 +340,18 @@ class AsyncAITTSService(AudioContextTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def flush_audio(self):
"""Flush any pending audio."""
context_id = self.get_active_audio_context_id()
if not context_id or not self._websocket:
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio.
Args:
context_id: The specific context to flush. If None, falls back to the
currently active context.
"""
flush_id = context_id or self.get_active_audio_context_id()
if not flush_id or not self._websocket:
return
logger.trace(f"{self}: flushing audio")
msg = self._build_msg(text=" ", context_id=context_id, force=True)
msg = self._build_msg(text=" ", context_id=flush_id, force=True)
await self._websocket.send(msg)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
@@ -459,12 +464,6 @@ class AsyncAITTSService(AudioContextTTSService):
await self._connect()
try:
if not self.has_active_audio_context():
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
msg = self._build_msg(text=text, force=True, context_id=context_id)
await self._get_websocket().send(msg)
await self.start_tts_usage_metrics(text)
@@ -574,6 +573,8 @@ class AsyncAIHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -632,7 +633,7 @@ class AsyncAIHttpTTSService(TTSService):
try:
voice_config = {"mode": "id", "id": self._settings.voice}
await self.start_ttfb_metrics()
payload = {
"model_id": self._settings.model,
"transcript": text,
@@ -644,7 +645,7 @@ class AsyncAIHttpTTSService(TTSService):
},
"language": self._settings.language,
}
yield TTSStartedFrame(context_id=context_id)
headers = {
"version": self._api_version,
"x-api-key": self._api_key,
@@ -682,4 +683,3 @@ class AsyncAIHttpTTSService(TTSService):
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -22,8 +22,6 @@ from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -247,6 +245,8 @@ class AWSPollyTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -329,8 +329,6 @@ class AWSPollyTTSService(TTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
# Construct the parameters dictionary
ssml = self._construct_ssml(text)
@@ -362,8 +360,6 @@ class AWSPollyTTSService(TTSService):
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
CHUNK_SIZE = self.chunk_size
for i in range(0, len(audio_data), CHUNK_SIZE):
@@ -373,14 +369,10 @@ class AWSPollyTTSService(TTSService):
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1, context_id=context_id)
yield frame
yield TTSStoppedFrame(context_id=context_id)
except (BotoCoreError, ClientError) as error:
error_message = f"AWS Polly TTS error: {str(error)}"
yield ErrorFrame(error=error_message)
finally:
yield TTSStoppedFrame(context_id=context_id)
class PollyTTSService(AWSPollyTTSService):
"""Deprecated alias for AWSPollyTTSService.

View File

@@ -21,7 +21,6 @@ from pipecat.frames.frames import (
InterruptionFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
@@ -331,8 +330,8 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
text_aggregation_mode=text_aggregation_mode,
push_text_frames=False, # We'll push text frames based on word timestamps
push_stop_frames=True,
push_start_frame=True,
pause_frame_processing=True,
supports_word_timestamps=True,
sample_rate=sample_rate,
settings=default_settings,
**kwargs,
@@ -346,7 +345,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
self._audio_queue = asyncio.Queue()
self._word_boundary_queue = asyncio.Queue()
self._word_processor_task = None
self._first_chunk = True
self._cumulative_audio_offset: float = 0.0 # Cumulative audio duration in seconds
self._current_sentence_base_offset: float = 0.0 # Base offset for current sentence
self._current_sentence_duration: float = 0.0 # Duration from Azure callback
@@ -619,7 +617,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
def _reset_state(self):
"""Reset TTS state between turns."""
self._first_chunk = True
self._cumulative_audio_offset = 0.0
self._current_sentence_base_offset = 0.0
self._current_sentence_duration = 0.0
@@ -628,7 +625,7 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
self._last_timestamp = None
self._current_context_id = None
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio data."""
logger.trace(f"{self}: flushing audio")
@@ -694,9 +691,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
return
try:
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
self._first_chunk = True
self._current_context_id = context_id
# Capture base offset BEFORE starting synthesis to avoid race conditions
@@ -719,11 +713,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
yield ErrorFrame(error=str(chunk))
break
if self._first_chunk:
await self.stop_ttfb_metrics()
await self.start_word_timestamps()
self._first_chunk = False
frame = TTSAudioRawFrame(
audio=chunk,
sample_rate=self.sample_rate,
@@ -833,6 +822,8 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -887,8 +878,6 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService):
"""
logger.debug(f"{self}: Generating TTS [{text}]")
await self.start_ttfb_metrics()
ssml = self._construct_ssml(text)
result = await asyncio.to_thread(self._speech_synthesizer.speak_ssml, ssml)
@@ -896,7 +885,6 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService):
if result.reason == ResultReason.SynthesizingAudioCompleted:
await self.start_tts_usage_metrics(text)
await self.stop_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
# Azure always sends a 44-byte header. Strip it off.
yield TTSAudioRawFrame(
audio=result.audio_data[44:],
@@ -904,7 +892,6 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService):
num_channels=1,
context_id=context_id,
)
yield TTSStoppedFrame(context_id=context_id)
elif result.reason == ResultReason.Canceled:
cancellation_details = result.cancellation_details
logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")

View File

@@ -29,8 +29,6 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -271,6 +269,8 @@ class CambTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -332,8 +332,6 @@ class CambTTSService(TTSService):
text = text[:3000]
try:
await self.start_ttfb_metrics()
# Build SDK parameters
tts_kwargs: Dict[str, Any] = {
"text": text,
@@ -348,7 +346,6 @@ class CambTTSService(TTSService):
tts_kwargs["user_instructions"] = self._settings.user_instructions
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
assert self._client is not None, "Camb.ai TTS service not initialized"
@@ -384,5 +381,3 @@ class CambTTSService(TTSService):
except Exception as e:
yield ErrorFrame(error=f"Camb.ai TTS error: {e}")
finally:
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService
from pipecat.services.tts_service import TextAggregationMode, TTSService, WebsocketTTSService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator
@@ -203,7 +203,7 @@ class CartesiaTTSSettings(TTSSettings):
pronunciation_dict_id: str | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
class CartesiaTTSService(AudioContextTTSService):
class CartesiaTTSService(WebsocketTTSService):
"""Cartesia TTS service with WebSocket streaming and word timestamps.
Provides text-to-speech using Cartesia's streaming WebSocket API.
@@ -334,9 +334,9 @@ class CartesiaTTSService(AudioContextTTSService):
text_aggregation_mode=text_aggregation_mode,
aggregate_sentences=aggregate_sentences,
push_text_frames=False,
pause_frame_processing=True,
supports_word_timestamps=True,
pause_frame_processing=False,
sample_rate=sample_rate,
push_start_frame=True,
text_aggregator=text_aggregator,
settings=default_settings,
**kwargs,
@@ -452,7 +452,11 @@ class CartesiaTTSService(AudioContextTTSService):
return list(zip(words, starts))
def _build_msg(
self, text: str = "", continue_transcript: bool = True, add_timestamps: bool = True
self,
text: str = "",
continue_transcript: bool = True,
add_timestamps: bool = True,
context_id: str = "",
):
voice_config = {}
voice_config["mode"] = "id"
@@ -461,7 +465,7 @@ class CartesiaTTSService(AudioContextTTSService):
msg = {
"transcript": text,
"continue": continue_transcript,
"context_id": self.get_active_audio_context_id(),
"context_id": context_id,
"model_id": self._settings.model,
"voice": voice_config,
"output_format": {
@@ -580,15 +584,19 @@ class CartesiaTTSService(AudioContextTTSService):
"""
pass
async def flush_audio(self):
"""Flush any pending audio and finalize the current context."""
context_id = self.get_active_audio_context_id()
if not context_id or not self._websocket:
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio and finalize the current context.
Args:
context_id: The specific context to flush. If None, falls back to the
currently active context.
"""
flush_id = context_id or self.get_active_audio_context_id()
if not flush_id or not self._websocket:
return
logger.trace(f"{self}: flushing audio")
msg = self._build_msg(text="", continue_transcript=False)
msg = self._build_msg(text="", continue_transcript=False, context_id=flush_id)
await self._websocket.send(msg)
self.reset_active_audio_context()
async def _process_messages(self):
async for message in self._get_websocket():
@@ -607,8 +615,6 @@ class CartesiaTTSService(AudioContextTTSService):
)
await self.add_word_timestamps(processed_timestamps, ctx_id)
elif msg["type"] == "chunk":
await self.stop_ttfb_metrics()
await self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["data"]),
sample_rate=self.sample_rate,
@@ -652,12 +658,7 @@ class CartesiaTTSService(AudioContextTTSService):
if not self._websocket or self._websocket.state is State.CLOSED:
await self._connect()
if not self.has_active_audio_context():
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
await self.create_audio_context(context_id)
msg = self._build_msg(text=text)
msg = self._build_msg(text=text, context_id=context_id)
try:
await self._get_websocket().send(msg)
@@ -777,6 +778,8 @@ class CartesiaHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -863,8 +866,6 @@ class CartesiaHttpTTSService(TTSService):
try:
voice_config = {"mode": "id", "id": self._settings.voice}
await self.start_ttfb_metrics()
output_format = {
"container": self._output_container,
"encoding": self._output_encoding,
@@ -889,8 +890,6 @@ class CartesiaHttpTTSService(TTSService):
if self._settings.pronunciation_dict_id:
payload["pronunciation_dict_id"] = self._settings.pronunciation_dict_id
yield TTSStartedFrame(context_id=context_id)
headers = {
"Cartesia-Version": self._cartesia_version,
"X-API-Key": self._api_key,
@@ -922,4 +921,3 @@ class CartesiaHttpTTSService(TTSService):
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -328,7 +328,7 @@ class DeepgramSageMakerTTSService(TTSService):
except Exception as e:
logger.error(f"{self} error sending Clear message: {e}")
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis by sending Flush command.
This should be called when the LLM finishes a complete response to force
@@ -355,12 +355,12 @@ class DeepgramSageMakerTTSService(TTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
if not self._ttfb_started:
await self.start_ttfb_metrics()
self._ttfb_started = True
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
if not self._ttfb_started:
await self.start_ttfb_metrics()
self._ttfb_started = True
yield TTSStartedFrame(context_id=context_id)
self._context_id = context_id
await self._client.send_json({"type": "Speak", "text": text})

View File

@@ -26,8 +26,6 @@ from pipecat.frames.frames import (
LLMFullResponseEndFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
@@ -120,6 +118,7 @@ class DeepgramTTSService(WebsocketTTSService):
sample_rate=sample_rate,
pause_frame_processing=True,
push_stop_frames=True,
push_start_frame=True,
append_trailing_space=True,
settings=default_settings,
**kwargs,
@@ -130,7 +129,6 @@ class DeepgramTTSService(WebsocketTTSService):
self._encoding = encoding
self._receive_task = None
self._context_id: Optional[str] = None
def can_generate_metrics(self) -> bool:
"""Check if the service can generate metrics.
@@ -267,7 +265,6 @@ class DeepgramTTSService(WebsocketTTSService):
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
finally:
self._context_id = None
self._websocket = None
await self._call_event_handler("on_disconnected")
@@ -299,7 +296,9 @@ class DeepgramTTSService(WebsocketTTSService):
if isinstance(message, bytes):
# Binary message contains audio data
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(message, self.sample_rate, 1, context_id=self._context_id)
frame = TTSAudioRawFrame(
message, self.sample_rate, 1, context_id=self.get_active_audio_context_id()
)
await self.push_frame(frame)
elif isinstance(message, str):
# Text message contains metadata or control messages
@@ -326,7 +325,7 @@ class DeepgramTTSService(WebsocketTTSService):
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis by sending Flush command.
This should be called when the LLM finishes a complete response to force
@@ -357,13 +356,8 @@ class DeepgramTTSService(WebsocketTTSService):
if not self._websocket or self._websocket.state is State.CLOSED:
await self._connect()
await self.start_ttfb_metrics()
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
# Store context_id for use in _receive_messages
self._context_id = context_id
# Send text message to Deepgram
# Note: We don't send Flush here - that should only be sent when the
# LLM finishes a complete response via flush_audio()
@@ -435,6 +429,8 @@ class DeepgramHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -492,7 +488,6 @@ class DeepgramHttpTTSService(TTSService):
raise Exception(f"HTTP {response.status}: {error_text}")
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
CHUNK_SIZE = self.chunk_size
@@ -510,7 +505,5 @@ class DeepgramHttpTTSService(TTSService):
context_id=context_id,
)
yield TTSStoppedFrame(context_id=context_id)
except Exception as e:
yield ErrorFrame(f"Error getting audio: {str(e)}")

View File

@@ -46,9 +46,9 @@ from pipecat.frames.frames import (
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import (
AudioContextTTSService,
TextAggregationMode,
TTSService,
WebsocketTTSService,
)
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -308,7 +308,7 @@ def calculate_word_times(
return (word_times, new_partial_word, new_partial_word_start_time)
class ElevenLabsTTSService(AudioContextTTSService):
class ElevenLabsTTSService(WebsocketTTSService):
"""ElevenLabs WebSocket-based TTS service with word timestamps.
Provides real-time text-to-speech using ElevenLabs' WebSocket streaming API.
@@ -479,7 +479,6 @@ class ElevenLabsTTSService(AudioContextTTSService):
push_text_frames=False,
push_stop_frames=True,
pause_frame_processing=True,
supports_word_timestamps=True,
sample_rate=sample_rate,
settings=default_settings,
**kwargs,
@@ -559,20 +558,15 @@ class ElevenLabsTTSService(AudioContextTTSService):
)
await self._disconnect()
await self._connect()
elif voice_settings_changed and self.has_active_audio_context():
elif voice_settings_changed:
logger.debug(
f"Voice settings changed ({changed.keys() & ElevenLabsTTSSettings.VOICE_SETTINGS_FIELDS}), "
f"closing current context to apply changes"
)
context_id = self.get_active_audio_context_id()
try:
if self._websocket:
await self._websocket.send(
json.dumps({"context_id": context_id, "close_context": True})
)
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self.reset_active_audio_context()
audio_contexts = self.get_audio_contexts()
if audio_contexts:
for ctx_id in audio_contexts:
await self._close_context(ctx_id)
if not url_changed:
# Reconnect applies all settings; only warn about fields not handled
@@ -610,13 +604,18 @@ class ElevenLabsTTSService(AudioContextTTSService):
await super().cancel(frame)
await self._disconnect()
async def flush_audio(self):
"""Flush any pending audio and finalize the current context."""
context_id = self.get_active_audio_context_id()
if not context_id or not self._websocket:
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio and finalize the current context.
Args:
context_id: The specific context to flush. If None, falls back to the
currently active context.
"""
flush_id = context_id or self.get_active_audio_context_id()
if not flush_id or not self._websocket:
return
logger.trace(f"{self}: flushing audio")
msg = {"context_id": context_id, "flush": True}
msg = {"context_id": flush_id, "flush": True}
await self._websocket.send(json.dumps(msg))
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
@@ -703,9 +702,7 @@ class ElevenLabsTTSService(AudioContextTTSService):
if self._websocket:
logger.debug("Disconnecting from ElevenLabs")
# Close all contexts and the socket
if self.has_active_audio_context():
await self._websocket.send(json.dumps({"close_socket": True}))
await self._websocket.send(json.dumps({"close_socket": True}))
await self._websocket.close()
logger.debug("Disconnected from ElevenLabs")
except Exception as e:
@@ -737,6 +734,7 @@ class ElevenLabsTTSService(AudioContextTTSService):
)
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._cumulative_time = 0.0
self._partial_word = ""
self._partial_word_start_time = 0.0
@@ -782,9 +780,6 @@ class ElevenLabsTTSService(AudioContextTTSService):
continue
if msg.get("audio"):
await self.stop_ttfb_metrics()
await self.start_word_timestamps()
audio = base64.b64decode(msg["audio"])
frame = TTSAudioRawFrame(audio, self.sample_rate, 1, context_id=received_ctx_id)
await self.append_to_audio_context(received_ctx_id, frame)
@@ -845,9 +840,8 @@ class ElevenLabsTTSService(AudioContextTTSService):
logger.warning(f"{self} keepalive error: {e}")
break
async def _send_text(self, text: str):
async def _send_text(self, text: str, context_id: str):
"""Send text to the WebSocket for synthesis."""
context_id = self.get_active_audio_context_id()
if self._websocket and context_id:
msg = {"text": text, "context_id": context_id}
await self._websocket.send(json.dumps(msg))
@@ -870,16 +864,14 @@ class ElevenLabsTTSService(AudioContextTTSService):
await self._connect()
try:
if not self.has_active_audio_context():
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
self._cumulative_time = 0
self._partial_word = ""
self._partial_word_start_time = 0.0
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
# Initialize context with voice settings and pronunciation dictionaries
msg = {"text": " ", "context_id": context_id}
if self._voice_settings:
@@ -892,7 +884,7 @@ class ElevenLabsTTSService(AudioContextTTSService):
await self._websocket.send(json.dumps(msg))
logger.trace(f"Created new context {context_id}")
await self._send_text(text)
await self._send_text(text, context_id)
await self.start_tts_usage_metrics(text)
except Exception as e:
yield TTSStoppedFrame(context_id=context_id)
@@ -1046,7 +1038,7 @@ class ElevenLabsHttpTTSService(TTSService):
aggregate_sentences=aggregate_sentences,
push_text_frames=False,
push_stop_frames=True,
supports_word_timestamps=True,
push_start_frame=True,
sample_rate=sample_rate,
settings=default_settings,
**kwargs,
@@ -1266,8 +1258,6 @@ class ElevenLabsHttpTTSService(TTSService):
params["optimize_streaming_latency"] = self._settings.optimize_streaming_latency
try:
await self.start_ttfb_metrics()
async with self._session.post(
url, json=payload, headers=headers, params=params
) as response:
@@ -1278,10 +1268,6 @@ class ElevenLabsHttpTTSService(TTSService):
await self.start_tts_usage_metrics(text)
# Start TTS sequence
await self.start_word_timestamps()
yield TTSStartedFrame(context_id=context_id)
# Track the duration of this utterance based on the last character's end time
utterance_duration = 0
async for line in response.content:
@@ -1347,4 +1333,3 @@ class ElevenLabsHttpTTSService(TTSService):
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
# Let the parent class handle TTSStoppedFrame

View File

@@ -209,6 +209,7 @@ class FishAudioTTSService(InterruptibleTTSService):
super().__init__(
push_stop_frames=True,
push_start_frame=True,
pause_frame_processing=True,
sample_rate=sample_rate,
settings=default_settings,
@@ -219,7 +220,6 @@ class FishAudioTTSService(InterruptibleTTSService):
self._base_url = "wss://api.fish.audio/v1/tts/live"
self._websocket = None
self._receive_task = None
self._request_id = None
# Init-only audio format config (not runtime-updatable).
self._fish_sample_rate = 0 # Set in start()
@@ -341,11 +341,10 @@ class FishAudioTTSService(InterruptibleTTSService):
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._request_id = None
self._websocket = None
await self._call_event_handler("on_disconnected")
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any buffered audio by sending a flush event to Fish Audio."""
logger.trace(f"{self}: Flushing audio buffers")
if not self._websocket or self._websocket.state is State.CLOSED:
@@ -361,7 +360,6 @@ class FishAudioTTSService(InterruptibleTTSService):
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
await self.stop_all_metrics()
self._request_id = None
async def _receive_messages(self):
async for message in self._get_websocket():
@@ -398,12 +396,6 @@ class FishAudioTTSService(InterruptibleTTSService):
if not self._websocket or self._websocket.state is State.CLOSED:
await self._connect()
if not self._request_id:
await self.start_ttfb_metrics()
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
self._request_id = str(uuid.uuid4())
# Send the text
text_message = {
"event": "text",

View File

@@ -34,8 +34,6 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import (
NOT_GIVEN,
@@ -655,6 +653,8 @@ class GoogleHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -803,8 +803,6 @@ class GoogleHttpTTSService(TTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
# Check if the voice is a Chirp voice (including Chirp 3) or Journey voice
is_chirp_voice = "chirp" in self._settings.voice.lower()
is_journey_voice = "journey" in self._settings.voice.lower()
@@ -840,8 +838,6 @@ class GoogleHttpTTSService(TTSService):
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
# Skip the first 44 bytes to remove the WAV header
audio_content = response.audio_content[44:]
@@ -855,8 +851,6 @@ class GoogleHttpTTSService(TTSService):
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1, context_id=context_id)
yield frame
yield TTSStoppedFrame(context_id=context_id)
except Exception as e:
error_message = f"TTS generation error: {str(e)}"
yield ErrorFrame(error=error_message)
@@ -967,8 +961,6 @@ class GoogleBaseTTSService(TTSService):
streaming_responses = await self._client.streaming_synthesize(request_generator())
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
audio_buffer = b""
first_chunk_for_ttfb = False
@@ -992,8 +984,6 @@ class GoogleBaseTTSService(TTSService):
if audio_buffer:
yield TTSAudioRawFrame(audio_buffer, self.sample_rate, 1, context_id=context_id)
yield TTSStoppedFrame(context_id=context_id)
class GoogleTTSService(GoogleBaseTTSService):
"""Google Cloud Text-to-Speech streaming service.
@@ -1096,6 +1086,8 @@ class GoogleTTSService(GoogleBaseTTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -1135,8 +1127,6 @@ class GoogleTTSService(GoogleBaseTTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
# Build voice selection params
if self._voice_cloning_key:
voice_clone_params = texttospeech_v1.VoiceCloneParams(
@@ -1352,6 +1342,8 @@ class GeminiTTSService(GoogleBaseTTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -1414,8 +1406,6 @@ class GeminiTTSService(GoogleBaseTTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
# Build voice selection params
if self._settings.multi_speaker and self._settings.speaker_configs:
# Multi-speaker mode

View File

@@ -19,11 +19,10 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
from pipecat.services.tts_service import AudioContextTTSService
from pipecat.services.tts_service import WebsocketTTSService
from pipecat.utils.tracing.service_decorators import traced_tts
try:
@@ -45,7 +44,7 @@ class GradiumTTSSettings(TTSSettings):
pass
class GradiumTTSService(AudioContextTTSService):
class GradiumTTSService(WebsocketTTSService):
"""Text-to-Speech service using Gradium's websocket API."""
_settings: GradiumTTSSettings
@@ -125,9 +124,9 @@ class GradiumTTSService(AudioContextTTSService):
super().__init__(
push_stop_frames=True,
push_start_frame=True,
push_text_frames=False,
pause_frame_processing=True,
supports_word_timestamps=True,
sample_rate=SAMPLE_RATE,
settings=default_settings,
**kwargs,
@@ -166,12 +165,9 @@ class GradiumTTSService(AudioContextTTSService):
self._warn_unhandled_updated_settings(changed)
return changed
def _build_msg(self, text: str = "") -> dict:
def _build_msg(self, text: str = "", context_id: str = "") -> dict:
"""Build JSON message for Gradium API."""
msg = {"text": text, "type": "text"}
context_id = self.get_active_audio_context_id()
if context_id:
msg["client_req_id"] = context_id
msg = {"text": text, "type": "text", "client_req_id": context_id}
return msg
async def start(self, frame: StartFrame):
@@ -280,15 +276,14 @@ class GradiumTTSService(AudioContextTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis."""
context_id = self.get_active_audio_context_id()
if not context_id or not self._websocket:
flush_id = context_id or self.get_active_audio_context_id()
if not flush_id or not self._websocket:
return
try:
msg = {"type": "end_of_stream", "client_req_id": context_id}
msg = {"type": "end_of_stream", "client_req_id": flush_id}
await self._websocket.send(json.dumps(msg))
self.reset_active_audio_context()
except ConnectionClosedOK:
logger.debug(f"{self}: connection closed normally during flush")
except Exception as e:
@@ -326,8 +321,6 @@ class GradiumTTSService(AudioContextTTSService):
if msg["type"] == "audio":
if not ctx_id or not self.audio_context_available(ctx_id):
continue
await self.stop_ttfb_metrics()
await self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["audio"]),
sample_rate=self.sample_rate,
@@ -369,12 +362,7 @@ class GradiumTTSService(AudioContextTTSService):
await self._connect()
try:
if not self.has_active_audio_context():
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
await self.create_audio_context(context_id)
msg = self._build_msg(text=text)
msg = self._build_msg(text=text, context_id=context_id)
await self._get_websocket().send(json.dumps(msg))
await self.start_tts_usage_metrics(text)
except Exception as e:

View File

@@ -18,8 +18,6 @@ from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -140,6 +138,8 @@ class GroqTTSService(TTSService):
super().__init__(
pause_frame_processing=True,
push_start_frame=True,
push_stop_frames=True,
sample_rate=sample_rate,
settings=default_settings,
**kwargs,
@@ -171,9 +171,6 @@ class GroqTTSService(TTSService):
"""
logger.debug(f"{self}: Generating TTS [{text}]")
measuring_ttfb = True
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
try:
response = await self._client.audio.speech.create(
model=self._settings.model,
@@ -198,5 +195,3 @@ class GroqTTSService(TTSService):
yield TTSAudioRawFrame(bytes, frame_rate, channels, context_id=context_id)
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -22,7 +22,6 @@ from pipecat.frames.frames import (
InterruptionFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
@@ -166,7 +165,7 @@ class HumeTTSService(TTSService):
sample_rate=sample_rate,
push_text_frames=False,
push_stop_frames=True,
supports_word_timestamps=True,
push_start_frame=True,
settings=default_settings,
**kwargs,
)
@@ -181,7 +180,6 @@ class HumeTTSService(TTSService):
# Track cumulative time for word timestamps across utterances
self._cumulative_time = 0.0
self._started = False
def can_generate_metrics(self) -> bool:
"""Can generate metrics.
@@ -203,7 +201,6 @@ class HumeTTSService(TTSService):
def _reset_state(self):
"""Reset internal state variables."""
self._cumulative_time = 0.0
self._started = False
async def stop(self, frame: EndFrame) -> None:
"""Stop the service and cleanup resources.
@@ -310,15 +307,8 @@ class HumeTTSService(TTSService):
# Request raw PCM chunks in the streaming JSON
pcm_fmt = FormatPcm(type="pcm")
await self.start_ttfb_metrics()
await self.start_tts_usage_metrics(text)
# Start TTS sequence if not already started
if not self._started:
await self.start_word_timestamps()
yield TTSStartedFrame(context_id=context_id)
self._started = True
try:
# Instant mode is always enabled here (not user-configurable)
# Hume emits mono PCM at 48 kHz; downstream can resample if needed.
@@ -395,4 +385,3 @@ class HumeTTSService(TTSService):
finally:
# Ensure TTFB timer is stopped even on early failures
await self.stop_ttfb_metrics()
# Let the parent class handle TTSStoppedFrame via push_stop_frames

View File

@@ -62,7 +62,7 @@ from pipecat.frames.frames import (
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService
from pipecat.services.tts_service import TextAggregationMode, TTSService, WebsocketTTSService
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -212,7 +212,7 @@ class InworldHttpTTSService(TTSService):
super().__init__(
push_text_frames=False,
push_stop_frames=True,
supports_word_timestamps=True,
push_start_frame=True,
sample_rate=sample_rate,
settings=default_settings,
**kwargs,
@@ -359,11 +359,6 @@ class InworldHttpTTSService(TTSService):
}
try:
await self.start_ttfb_metrics()
await self.start_word_timestamps()
yield TTSStartedFrame(context_id=context_id)
async with self._session.post(
self._base_url, json=payload, headers=headers
) as response:
@@ -514,7 +509,7 @@ class InworldHttpTTSService(TTSService):
)
class InworldTTSService(AudioContextTTSService):
class InworldTTSService(WebsocketTTSService):
"""Inworld AI WebSocket-based TTS service.
Uses bidirectional WebSocket for lower latency streaming. Supports multiple
@@ -650,7 +645,6 @@ class InworldTTSService(AudioContextTTSService):
push_text_frames=False,
push_stop_frames=True,
pause_frame_processing=True,
supports_word_timestamps=True,
sample_rate=sample_rate,
aggregate_sentences=aggregate_sentences,
text_aggregation_mode=text_aggregation_mode,
@@ -719,17 +713,17 @@ class InworldTTSService(AudioContextTTSService):
await super().cancel(frame)
await self._disconnect()
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio without closing the context.
This triggers synthesis of all accumulated text in the buffer while
keeping the context open for subsequent text. The context is only
closed on interruption, disconnect, or end of session.
"""
context_id = self.get_active_audio_context_id()
if context_id and self._websocket:
logger.trace(f"Flushing audio for context {context_id}")
await self._send_flush(context_id)
flush_id = context_id or self.get_active_audio_context_id()
if flush_id and self._websocket:
logger.trace(f"Flushing audio for context {flush_id}")
await self._send_flush(flush_id)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Push a frame and handle state changes.
@@ -899,12 +893,10 @@ class InworldTTSService(AudioContextTTSService):
if self._websocket:
logger.debug("Disconnecting from Inworld WebSocket TTS")
context_id = self.get_active_audio_context_id()
if context_id:
try:
await self._send_close_context(context_id)
except Exception:
pass
audio_contexts = self.get_audio_contexts()
if audio_contexts:
for ctx_id in audio_contexts:
await self._send_close_context(ctx_id)
await self._websocket.close()
logger.debug("Disconnected from Inworld WebSocket TTS")
except Exception as e:
@@ -934,10 +926,7 @@ class InworldTTSService(AudioContextTTSService):
for k in ["contextCreated", "audioChunk", "flushCompleted", "contextClosed"]
if k in result
]
logger.debug(
f"{self}: Received message types={msg_types}, ctx_id={ctx_id}, "
f"current_ctx={self.get_active_audio_context_id()}, available={self.audio_context_available(ctx_id) if ctx_id else 'N/A'}"
)
logger.debug(f"{self}: Received message types={msg_types}, ctx_id={ctx_id}")
# Check for errors
status = result.get("status", {})
@@ -948,9 +937,7 @@ class InworldTTSService(AudioContextTTSService):
# Handle "Context not found" error (code 5)
# This can happen when a keepalive message is sent but no context is available.
if error_code == 5 and "not found" in error_msg.lower():
logger.debug(
f"{self}: Context {ctx_id or self.get_active_audio_context_id()} not found."
)
logger.debug(f"{self}: Context {ctx_id} not found.")
continue
# For other errors, push error frame
@@ -961,17 +948,10 @@ class InworldTTSService(AudioContextTTSService):
await self.push_error(error_msg=str(msg["error"]))
continue
# Check if this message belongs to an available context.
# If the context isn't available but matches our current context ID,
# recreate it (handles race conditions during interruption recovery).
# If the context isn't available recreate it (handles race conditions during interruption recovery).
if ctx_id and not self.audio_context_available(ctx_id):
if self.get_active_audio_context_id() == ctx_id:
logger.trace(f"{self}: Recreating audio context for current context: {ctx_id}")
await self.create_audio_context(ctx_id)
else:
# This is a message from an old/closed context - skip it
logger.trace(f"{self}: Skipping message from unavailable context: {ctx_id}")
continue
logger.trace(f"{self}: Recreating audio context for current context: {ctx_id}")
await self.create_audio_context(ctx_id)
# Process audio chunk
audio_chunk = result.get("audioChunk", {})
@@ -979,8 +959,6 @@ class InworldTTSService(AudioContextTTSService):
if audio_b64:
logger.trace(f"{self}: Processing audio chunk for context {ctx_id}")
await self.stop_ttfb_metrics()
await self.start_word_timestamps()
audio = base64.b64decode(audio_b64)
if len(audio) > 44 and audio.startswith(b"RIFF"):
audio = audio[44:]
@@ -1012,12 +990,8 @@ class InworldTTSService(AudioContextTTSService):
if "contextClosed" in result:
logger.trace(f"{self}: Context closed on server: {ctx_id}")
await self.stop_ttfb_metrics()
# Only reset if this is our current context
if ctx_id == self.get_active_audio_context_id():
self.reset_active_audio_context()
if ctx_id and self.audio_context_available(ctx_id):
await self.remove_audio_context(ctx_id)
await self.add_word_timestamps([("TTSStoppedFrame", 0), ("Reset", 0)], ctx_id)
await self.remove_audio_context(ctx_id)
async def _keepalive_task_handler(self):
"""Send periodic keepalive messages to maintain WebSocket connection."""
@@ -1128,10 +1102,10 @@ class InworldTTSService(AudioContextTTSService):
await self._connect()
try:
if not self.has_active_audio_context():
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
await self.create_audio_context(context_id)
await self._send_context(context_id)
await self._send_text(context_id, text)

View File

@@ -20,8 +20,6 @@ from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -170,6 +168,8 @@ class KokoroTTSService(TTSService):
default_settings.apply_update(settings)
super().__init__(
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -212,9 +212,7 @@ class KokoroTTSService(TTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
stream = self._kokoro.create_stream(
text, voice=self._settings.voice, lang=self._settings.language, speed=1.0
@@ -238,4 +236,3 @@ class KokoroTTSService(TTSService):
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -143,6 +143,7 @@ class LmntTTSService(InterruptibleTTSService):
super().__init__(
push_stop_frames=True,
push_start_frame=True,
pause_frame_processing=True,
sample_rate=sample_rate,
settings=default_settings,
@@ -152,7 +153,6 @@ class LmntTTSService(InterruptibleTTSService):
self._api_key = api_key
self._output_format = "raw"
self._receive_task = None
self._context_id: Optional[str] = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -289,7 +289,6 @@ class LmntTTSService(InterruptibleTTSService):
except Exception as e:
await self.push_error(error_msg=f"Error disconnecting from LMNT: {e}", exception=e)
finally:
self._context_id = None
self._websocket = None
await self._call_event_handler("on_disconnected")
@@ -299,7 +298,7 @@ class LmntTTSService(InterruptibleTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis."""
if not self._websocket or self._websocket.state is State.CLOSED:
return
@@ -315,7 +314,7 @@ class LmntTTSService(InterruptibleTTSService):
audio=message,
sample_rate=self.sample_rate,
num_channels=1,
context_id=self._context_id,
context_id=self.get_active_audio_context_id(),
)
await self.push_frame(frame)
else:
@@ -347,11 +346,6 @@ class LmntTTSService(InterruptibleTTSService):
await self._connect()
try:
await self.start_ttfb_metrics()
# Store context_id for use in _receive_messages
self._context_id = context_id
yield TTSStartedFrame(context_id=context_id)
# Send text to LMNT
await self._get_websocket().send(json.dumps({"text": text}))
# Force synthesis

View File

@@ -23,8 +23,6 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -305,6 +303,8 @@ class MiniMaxHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -402,8 +402,6 @@ class MiniMaxHttpTTSService(TTSService):
payload["language_boost"] = self._settings.language_boost
try:
await self.start_ttfb_metrics()
async with self._session.post(
self._base_url, headers=headers, json=payload
) as response:
@@ -413,7 +411,6 @@ class MiniMaxHttpTTSService(TTSService):
return
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
# Process the streaming response
buffer = bytearray()
@@ -490,4 +487,3 @@ class MiniMaxHttpTTSService(TTSService):
yield ErrorFrame(error=f"Unknown error occurred: {e}", exception=e)
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -180,6 +180,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
aggregate_sentences=aggregate_sentences,
text_aggregation_mode=text_aggregation_mode,
push_stop_frames=True,
push_start_frame=True,
stop_frame_timeout_s=2.0,
sample_rate=sample_rate,
settings=default_settings,
@@ -188,12 +189,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
self._api_key = api_key
self._url = url
self._cumulative_time = 0
self._receive_task = None
self._keepalive_task = None
self._context_id: Optional[str] = None
self._encoding = encoding
self._sampling_rate = sample_rate
@@ -252,7 +249,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
await super().cancel(frame)
await self._disconnect()
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis by sending stop command."""
if self._websocket:
msg = {"text": "<STOP>"}
@@ -358,7 +355,6 @@ class NeuphonicTTSService(InterruptibleTTSService):
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._context_id = None
self._websocket = None
await self._call_event_handler("on_disconnected")
@@ -372,7 +368,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
audio = base64.b64decode(msg["data"]["audio"])
frame = TTSAudioRawFrame(
audio, self.sample_rate, 1, context_id=self._context_id
audio, self.sample_rate, 1, context_id=self.get_active_audio_context_id()
)
await self.push_frame(frame)
@@ -415,12 +411,6 @@ class NeuphonicTTSService(InterruptibleTTSService):
await self._connect()
try:
await self.start_ttfb_metrics()
# Store context_id for use in _receive_messages
self._context_id = context_id
yield TTSStartedFrame(context_id=context_id)
self._cumulative_time = 0
await self._send_text(text)
await self.start_tts_usage_metrics(text)
except Exception as e:
@@ -523,6 +513,8 @@ class NeuphonicHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_stop_frames=True,
push_start_frame=True,
settings=default_settings,
**kwargs,
)
@@ -559,7 +551,7 @@ class NeuphonicHttpTTSService(TTSService):
"""
await super().start(frame)
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis.
Note:
@@ -633,8 +625,6 @@ class NeuphonicHttpTTSService(TTSService):
payload["voice_id"] = self._settings.voice
try:
await self.start_ttfb_metrics()
async with self._session.post(url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
@@ -643,7 +633,6 @@ class NeuphonicHttpTTSService(TTSService):
return
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
# Process SSE stream line by line
async for line in response.content:
@@ -681,4 +670,3 @@ class NeuphonicHttpTTSService(TTSService):
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -28,8 +28,6 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -145,6 +143,8 @@ class NvidiaTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -271,9 +271,6 @@ class NvidiaTTSService(TTSService):
assert self._service is not None, "TTS service not initialized"
assert self._config is not None, "Synthesis configuration not created"
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
logger.debug(f"{self}: Generating TTS [{text}]")
responses = await asyncio.to_thread(read_audio_responses)
@@ -289,7 +286,6 @@ class NvidiaTTSService(TTSService):
yield frame
await self.start_tts_usage_metrics(text)
yield TTSStoppedFrame(context_id=context_id)
except asyncio.TimeoutError as e:
logger.error(f"{self} timeout waiting for audio response")
yield ErrorFrame(error=f"{self} error: {e}")

View File

@@ -22,8 +22,6 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -194,6 +192,8 @@ class OpenAITTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -234,8 +234,6 @@ class OpenAITTSService(TTSService):
"""
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
# Setup API parameters
create_params = {
"input": text,
@@ -267,12 +265,10 @@ class OpenAITTSService(TTSService):
CHUNK_SIZE = self.chunk_size
yield TTSStartedFrame(context_id=context_id)
async for chunk in r.iter_bytes(CHUNK_SIZE):
if len(chunk) > 0:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1, context_id=context_id)
yield frame
yield TTSStoppedFrame(context_id=context_id)
except BadRequestError as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")

View File

@@ -17,8 +17,6 @@ from loguru import logger
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -91,6 +89,8 @@ class PiperTTSService(TTSService):
default_settings.apply_update(settings)
super().__init__(
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -159,12 +159,8 @@ class PiperTTSService(TTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
async for frame in self._stream_audio_frames_from_iterator(
async_iterator(self._voice.synthesize(text)),
in_sample_rate=self._voice.config.sample_rate,
@@ -178,7 +174,6 @@ class PiperTTSService(TTSService):
finally:
logger.debug(f"{self}: Finished TTS [{text}]")
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)
# This assumes a running TTS service running:
@@ -244,6 +239,8 @@ class PiperHttpTTSService(TTSService):
default_settings.apply_update(settings)
super().__init__(
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -279,8 +276,6 @@ class PiperHttpTTSService(TTSService):
"Content-Type": "application/json",
}
try:
await self.start_ttfb_metrics()
data = {
"text": text,
"voice": self._settings.voice,
@@ -296,8 +291,6 @@ class PiperHttpTTSService(TTSService):
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
CHUNK_SIZE = self.chunk_size
async for frame in self._stream_audio_frames_from_iterator(
@@ -311,4 +304,3 @@ class PiperHttpTTSService(TTSService):
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -24,7 +24,7 @@ from pipecat.frames.frames import (
TTSStoppedFrame,
)
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
from pipecat.services.tts_service import AudioContextTTSService
from pipecat.services.tts_service import WebsocketTTSService
from pipecat.utils.tracing.service_decorators import traced_tts
try:
@@ -43,7 +43,7 @@ class ResembleAITTSSettings(TTSSettings):
pass
class ResembleAITTSService(AudioContextTTSService):
class ResembleAITTSService(WebsocketTTSService):
"""Resemble AI TTS service with WebSocket streaming and word timestamps.
Provides text-to-speech using Resemble AI's streaming WebSocket API.
@@ -103,7 +103,6 @@ class ResembleAITTSService(AudioContextTTSService):
super().__init__(
sample_rate=sample_rate,
reuse_context_id_within_turn=False,
supports_word_timestamps=True,
settings=default_settings,
**kwargs,
)
@@ -268,7 +267,7 @@ class ResembleAITTSService(AudioContextTTSService):
"""
pass
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio and finalize the current context."""
logger.trace(f"{self}: flushing audio")
# For Resemble AI, we just wait for the audio_end message
@@ -297,9 +296,6 @@ class ResembleAITTSService(AudioContextTTSService):
continue
if msg_type == "audio":
await self.stop_ttfb_metrics()
await self.start_word_timestamps()
# Decode base64 audio content
audio_content = msg.get("audio_content", "")
if not audio_content:
@@ -447,14 +443,14 @@ class ResembleAITTSService(AudioContextTTSService):
if not self._websocket or self._websocket.state is State.CLOSED:
await self._connect()
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
# Map request_id to context_id for tracking
self._request_id_to_context[self._request_id_counter] = context_id
await self.create_audio_context(context_id)
msg = self._build_msg(text=text)
try:

View File

@@ -33,10 +33,10 @@ from pipecat.frames.frames import (
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import (
AudioContextTTSService,
InterruptibleTTSService,
TextAggregationMode,
TTSService,
WebsocketTTSService,
)
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
@@ -123,7 +123,7 @@ class RimeNonJsonTTSSettings(TTSSettings):
_aliases: ClassVar[Dict[str, str]] = {"speaker": "voice"}
class RimeTTSService(AudioContextTTSService):
class RimeTTSService(WebsocketTTSService):
"""Text-to-Speech service using Rime's websocket API.
Uses Rime's websocket JSON API to convert text to speech with word-level timing
@@ -276,7 +276,6 @@ class RimeTTSService(AudioContextTTSService):
push_text_frames=False,
push_stop_frames=True,
pause_frame_processing=True,
supports_word_timestamps=True,
append_trailing_space=True,
sample_rate=sample_rate,
settings=default_settings,
@@ -408,9 +407,9 @@ class RimeTTSService(AudioContextTTSService):
return changed
def _build_msg(self, text: str = "") -> dict:
def _build_msg(self, text: str = "", context_id: str = "") -> dict:
"""Build JSON message for Rime API."""
msg = {"text": text, "contextId": self.get_active_audio_context_id()}
msg = {"text": text, "contextId": context_id}
if self._extra_msg_fields:
msg |= self._extra_msg_fields
self._extra_msg_fields = {}
@@ -557,15 +556,14 @@ class RimeTTSService(AudioContextTTSService):
return word_pairs
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis."""
context_id = self.get_active_audio_context_id()
if not context_id or not self._websocket:
flush_id = context_id or self.get_active_audio_context_id()
if not flush_id or not self._websocket:
return
logger.trace(f"{self}: flushing audio")
await self._get_websocket().send(json.dumps({"operation": "flush"}))
self.reset_active_audio_context()
async def _receive_messages(self):
"""Process incoming websocket messages."""
@@ -578,8 +576,6 @@ class RimeTTSService(AudioContextTTSService):
context_id = msg["contextId"]
if msg["type"] == "chunk":
# Process audio chunk
await self.stop_ttfb_metrics()
await self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["data"]),
sample_rate=self.sample_rate,
@@ -638,13 +634,13 @@ class RimeTTSService(AudioContextTTSService):
await self._connect()
try:
if not self.has_active_audio_context():
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
self._cumulative_time = 0
await self.create_audio_context(context_id)
msg = self._build_msg(text=text)
msg = self._build_msg(text=text, context_id=context_id)
await self._get_websocket().send(json.dumps(msg))
await self.start_tts_usage_metrics(text)
except Exception as e:
@@ -773,6 +769,8 @@ class RimeHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_stop_frames=True,
push_start_frame=True,
settings=default_settings,
**kwargs,
)
@@ -844,8 +842,6 @@ class RimeHttpTTSService(TTSService):
need_to_strip_wav_header = False
try:
await self.start_ttfb_metrics()
async with self._session.post(
self._base_url, json=payload, headers=headers
) as response:
@@ -856,8 +852,6 @@ class RimeHttpTTSService(TTSService):
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
CHUNK_SIZE = self.chunk_size
async for frame in self._stream_audio_frames_from_iterator(
@@ -872,7 +866,6 @@ class RimeHttpTTSService(TTSService):
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)
class RimeNonJsonTTSService(InterruptibleTTSService):
@@ -1005,6 +998,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
aggregate_sentences=aggregate_sentences,
text_aggregation_mode=text_aggregation_mode,
push_stop_frames=True,
push_start_frame=True,
pause_frame_processing=True,
append_trailing_space=True,
settings=default_settings,
@@ -1022,7 +1016,6 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
self._settings.extra.update(params.extra)
self._receive_task = None
self._context_id: Optional[str] = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -1138,7 +1131,6 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._context_id = None
self._websocket = None
await self._call_event_handler("on_disconnected")
@@ -1148,7 +1140,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis."""
if not self._websocket:
return
@@ -1168,7 +1160,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
audio=message,
sample_rate=self.sample_rate,
num_channels=1,
context_id=self._context_id,
context_id=self.get_active_audio_context_id(),
)
await self.push_frame(frame)
except Exception as e:
@@ -1190,10 +1182,6 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
if not self._websocket or self._websocket.state is State.CLOSED:
await self._connect()
try:
await self.start_ttfb_metrics()
# Store context_id for use in _receive_messages
self._context_id = context_id
yield TTSStartedFrame(context_id=context_id)
# Send bare text (not JSON)
await self._get_websocket().send(text)
await self.start_tts_usage_metrics(text)

View File

@@ -524,6 +524,8 @@ class SarvamHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_stop_frames=True,
push_start_frame=True,
settings=default_settings,
**kwargs,
)
@@ -573,8 +575,6 @@ class SarvamHttpTTSService(TTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
# Build payload with common parameters
payload = {
"text": text,
@@ -606,8 +606,6 @@ class SarvamHttpTTSService(TTSService):
url = f"{self._base_url}/text-to-speech"
yield TTSStartedFrame(context_id=context_id)
async with self._session.post(url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
@@ -645,7 +643,6 @@ class SarvamHttpTTSService(TTSService):
yield ErrorFrame(error=f"Error generating TTS: {e}", exception=e)
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)
class SarvamTTSService(InterruptibleTTSService):
@@ -951,6 +948,7 @@ class SarvamTTSService(InterruptibleTTSService):
push_text_frames=True,
pause_frame_processing=True,
push_stop_frames=True,
push_start_frame=True,
sample_rate=sample_rate,
settings=default_settings,
**kwargs,
@@ -967,7 +965,6 @@ class SarvamTTSService(InterruptibleTTSService):
self._receive_task = None
self._keepalive_task = None
self._context_id: Optional[str] = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -1018,7 +1015,7 @@ class SarvamTTSService(InterruptibleTTSService):
await super().cancel(frame)
await self._disconnect()
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis by sending flush command."""
try:
if self._websocket:
@@ -1151,7 +1148,6 @@ class SarvamTTSService(InterruptibleTTSService):
except Exception as e:
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
finally:
self._context_id = None
self._websocket = None
await self._call_event_handler("on_disconnected")
@@ -1170,7 +1166,7 @@ class SarvamTTSService(InterruptibleTTSService):
await self.stop_ttfb_metrics()
audio = base64.b64decode(msg["data"]["audio"])
frame = TTSAudioRawFrame(
audio, self.sample_rate, 1, context_id=self._context_id
audio, self.sample_rate, 1, context_id=self.get_active_audio_context_id()
)
await self.push_frame(frame)
elif msg.get("type") == "error":
@@ -1224,10 +1220,6 @@ class SarvamTTSService(InterruptibleTTSService):
await self._connect()
try:
await self.start_ttfb_metrics()
# Store context_id for use in _receive_messages
self._context_id = context_id
yield TTSStartedFrame(context_id=context_id)
await self._send_text(text)
await self.start_tts_usage_metrics(text)
except Exception as e:

View File

@@ -19,8 +19,6 @@ from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -135,6 +133,8 @@ class SpeechmaticsTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -185,9 +185,6 @@ class SpeechmaticsTTSService(TTSService):
url = _get_endpoint_url(self._base_url, self._settings.voice, self.sample_rate)
try:
# Start TTS TTFB metrics
await self.start_ttfb_metrics()
# Track attempt
attempt = 0
@@ -238,9 +235,6 @@ class SpeechmaticsTTSService(TTSService):
# Update Pipecat metrics
await self.start_tts_usage_metrics(text)
# Emit the TTS started frame
yield TTSStartedFrame(context_id=context_id)
# Process the response in streaming chunks
first_chunk = True
buffer = b""
@@ -277,8 +271,7 @@ class SpeechmaticsTTSService(TTSService):
except Exception as e:
yield ErrorFrame(error=f"Error generating TTS: {e}")
finally:
# Emit the TTS stopped frame
yield TTSStoppedFrame(context_id=context_id)
await self.stop_ttfb_metrics()
def _get_endpoint_url(base_url: str, voice: str, sample_rate: int) -> str:

View File

@@ -11,7 +11,7 @@ text-to-speech synthesis using local Docker deployment.
"""
from dataclasses import dataclass
from typing import AsyncGenerator, Dict, Optional
from typing import Any, AsyncGenerator, Dict, Optional
import aiohttp
from loguru import logger
@@ -22,8 +22,6 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -132,6 +130,8 @@ class XTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -213,8 +213,6 @@ class XTTSService(TTSService):
"stream_chunk_size": 20,
}
await self.start_ttfb_metrics()
async with self._aiohttp_session.post(url, json=payload) as r:
if r.status != 200:
text = await r.text()
@@ -223,8 +221,6 @@ class XTTSService(TTSService):
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
CHUNK_SIZE = self.chunk_size
buffer = bytearray()
@@ -262,5 +258,3 @@ class XTTSService(TTSService):
resampled_audio, self.sample_rate, 1, context_id=context_id
)
yield frame
yield TTSStoppedFrame(context_id=context_id)