From 88ff7c451ba757dc25e898c626b0b476f4d9aa5b Mon Sep 17 00:00:00 2001 From: filipi87 Date: Fri, 6 Mar 2026 16:15:59 -0300 Subject: [PATCH] Refactored all 25+ TTS service implementations to use the new push_start_frame=True pattern --- src/pipecat/services/asyncai/tts.py | 36 +++++----- src/pipecat/services/aws/tts.py | 12 +--- src/pipecat/services/azure/tts.py | 21 ++---- src/pipecat/services/camb/tts.py | 9 +-- src/pipecat/services/cartesia/tts.py | 48 +++++++------- .../services/deepgram/sagemaker/tts.py | 14 ++-- src/pipecat/services/deepgram/tts.py | 21 ++---- src/pipecat/services/elevenlabs/tts.py | 63 +++++++----------- src/pipecat/services/fish/tts.py | 12 +--- src/pipecat/services/google/tts.py | 22 ++----- src/pipecat/services/gradium/tts.py | 32 +++------ src/pipecat/services/groq/tts.py | 9 +-- src/pipecat/services/hume/tts.py | 13 +--- src/pipecat/services/inworld/tts.py | 66 ++++++------------- src/pipecat/services/kokoro/tts.py | 7 +- src/pipecat/services/lmnt/tts.py | 12 +--- src/pipecat/services/minimax/tts.py | 8 +-- src/pipecat/services/neuphonic/tts.py | 24 ++----- src/pipecat/services/nvidia/tts.py | 8 +-- src/pipecat/services/openai/tts.py | 8 +-- src/pipecat/services/piper/tts.py | 16 ++--- src/pipecat/services/resembleai/tts.py | 18 ++--- src/pipecat/services/rime/tts.py | 42 +++++------- src/pipecat/services/sarvam/tts.py | 18 ++--- src/pipecat/services/speechmatics/tts.py | 13 +--- src/pipecat/services/xtts/tts.py | 12 +--- 26 files changed, 182 insertions(+), 382 deletions(-) diff --git a/src/pipecat/services/asyncai/tts.py b/src/pipecat/services/asyncai/tts.py index 3b1296752..56a7d0e82 100644 --- a/src/pipecat/services/asyncai/tts.py +++ b/src/pipecat/services/asyncai/tts.py @@ -23,12 +23,11 @@ from pipecat.frames.frames import ( Frame, StartFrame, TTSAudioRawFrame, - TTSStartedFrame, TTSStoppedFrame, ) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.settings import TTSSettings, _warn_deprecated_param -from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService +from pipecat.services.tts_service import TextAggregationMode, TTSService, WebsocketTTSService from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.tracing.service_decorators import traced_tts @@ -80,7 +79,7 @@ class AsyncAITTSSettings(TTSSettings): pass -class AsyncAITTSService(AudioContextTTSService): +class AsyncAITTSService(WebsocketTTSService): """Async TTS service with WebSocket streaming. Provides text-to-speech using Async's streaming WebSocket API. @@ -183,8 +182,9 @@ class AsyncAITTSService(AudioContextTTSService): aggregate_sentences=aggregate_sentences, text_aggregation_mode=text_aggregation_mode, pause_frame_processing=True, - push_stop_frames=True, sample_rate=sample_rate, + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -340,13 +340,18 @@ class AsyncAITTSService(AudioContextTTSService): return self._websocket raise Exception("Websocket not connected") - async def flush_audio(self): - """Flush any pending audio.""" - context_id = self.get_active_audio_context_id() - if not context_id or not self._websocket: + async def flush_audio(self, context_id: Optional[str] = None): + """Flush any pending audio. + + Args: + context_id: The specific context to flush. If None, falls back to the + currently active context. + """ + flush_id = context_id or self.get_active_audio_context_id() + if not flush_id or not self._websocket: return logger.trace(f"{self}: flushing audio") - msg = self._build_msg(text=" ", context_id=context_id, force=True) + msg = self._build_msg(text=" ", context_id=flush_id, force=True) await self._websocket.send(msg) async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): @@ -459,12 +464,6 @@ class AsyncAITTSService(AudioContextTTSService): await self._connect() try: - if not self.has_active_audio_context(): - await self.start_ttfb_metrics() - yield TTSStartedFrame(context_id=context_id) - if not self.audio_context_available(context_id): - await self.create_audio_context(context_id) - msg = self._build_msg(text=text, force=True, context_id=context_id) await self._get_websocket().send(msg) await self.start_tts_usage_metrics(text) @@ -574,6 +573,8 @@ class AsyncAIHttpTTSService(TTSService): super().__init__( sample_rate=sample_rate, + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -632,7 +633,7 @@ class AsyncAIHttpTTSService(TTSService): try: voice_config = {"mode": "id", "id": self._settings.voice} - await self.start_ttfb_metrics() + payload = { "model_id": self._settings.model, "transcript": text, @@ -644,7 +645,7 @@ class AsyncAIHttpTTSService(TTSService): }, "language": self._settings.language, } - yield TTSStartedFrame(context_id=context_id) + headers = { "version": self._api_version, "x-api-key": self._api_key, @@ -682,4 +683,3 @@ class AsyncAIHttpTTSService(TTSService): await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) finally: await self.stop_ttfb_metrics() - yield TTSStoppedFrame(context_id=context_id) diff --git a/src/pipecat/services/aws/tts.py b/src/pipecat/services/aws/tts.py index 043fc264e..285026bca 100644 --- a/src/pipecat/services/aws/tts.py +++ b/src/pipecat/services/aws/tts.py @@ -22,8 +22,6 @@ from pipecat.frames.frames import ( ErrorFrame, Frame, TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, ) from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param from pipecat.services.tts_service import TTSService @@ -247,6 +245,8 @@ class AWSPollyTTSService(TTSService): super().__init__( sample_rate=sample_rate, + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -329,8 +329,6 @@ class AWSPollyTTSService(TTSService): logger.debug(f"{self}: Generating TTS [{text}]") try: - await self.start_ttfb_metrics() - # Construct the parameters dictionary ssml = self._construct_ssml(text) @@ -362,8 +360,6 @@ class AWSPollyTTSService(TTSService): await self.start_tts_usage_metrics(text) - yield TTSStartedFrame(context_id=context_id) - CHUNK_SIZE = self.chunk_size for i in range(0, len(audio_data), CHUNK_SIZE): @@ -373,14 +369,10 @@ class AWSPollyTTSService(TTSService): frame = TTSAudioRawFrame(chunk, self.sample_rate, 1, context_id=context_id) yield frame - yield TTSStoppedFrame(context_id=context_id) except (BotoCoreError, ClientError) as error: error_message = f"AWS Polly TTS error: {str(error)}" yield ErrorFrame(error=error_message) - finally: - yield TTSStoppedFrame(context_id=context_id) - class PollyTTSService(AWSPollyTTSService): """Deprecated alias for AWSPollyTTSService. diff --git a/src/pipecat/services/azure/tts.py b/src/pipecat/services/azure/tts.py index 112459c5a..f710482e9 100644 --- a/src/pipecat/services/azure/tts.py +++ b/src/pipecat/services/azure/tts.py @@ -21,7 +21,6 @@ from pipecat.frames.frames import ( InterruptionFrame, StartFrame, TTSAudioRawFrame, - TTSStartedFrame, TTSStoppedFrame, ) from pipecat.processors.frame_processor import FrameDirection @@ -331,8 +330,8 @@ class AzureTTSService(TTSService, AzureBaseTTSService): text_aggregation_mode=text_aggregation_mode, push_text_frames=False, # We'll push text frames based on word timestamps push_stop_frames=True, + push_start_frame=True, pause_frame_processing=True, - supports_word_timestamps=True, sample_rate=sample_rate, settings=default_settings, **kwargs, @@ -346,7 +345,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService): self._audio_queue = asyncio.Queue() self._word_boundary_queue = asyncio.Queue() self._word_processor_task = None - self._first_chunk = True self._cumulative_audio_offset: float = 0.0 # Cumulative audio duration in seconds self._current_sentence_base_offset: float = 0.0 # Base offset for current sentence self._current_sentence_duration: float = 0.0 # Duration from Azure callback @@ -619,7 +617,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService): def _reset_state(self): """Reset TTS state between turns.""" - self._first_chunk = True self._cumulative_audio_offset = 0.0 self._current_sentence_base_offset = 0.0 self._current_sentence_duration = 0.0 @@ -628,7 +625,7 @@ class AzureTTSService(TTSService, AzureBaseTTSService): self._last_timestamp = None self._current_context_id = None - async def flush_audio(self): + async def flush_audio(self, context_id: Optional[str] = None): """Flush any pending audio data.""" logger.trace(f"{self}: flushing audio") @@ -694,9 +691,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService): return try: - await self.start_ttfb_metrics() - yield TTSStartedFrame(context_id=context_id) - self._first_chunk = True self._current_context_id = context_id # Capture base offset BEFORE starting synthesis to avoid race conditions @@ -719,11 +713,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService): yield ErrorFrame(error=str(chunk)) break - if self._first_chunk: - await self.stop_ttfb_metrics() - await self.start_word_timestamps() - self._first_chunk = False - frame = TTSAudioRawFrame( audio=chunk, sample_rate=self.sample_rate, @@ -833,6 +822,8 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService): super().__init__( sample_rate=sample_rate, + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -887,8 +878,6 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService): """ logger.debug(f"{self}: Generating TTS [{text}]") - await self.start_ttfb_metrics() - ssml = self._construct_ssml(text) result = await asyncio.to_thread(self._speech_synthesizer.speak_ssml, ssml) @@ -896,7 +885,6 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService): if result.reason == ResultReason.SynthesizingAudioCompleted: await self.start_tts_usage_metrics(text) await self.stop_ttfb_metrics() - yield TTSStartedFrame(context_id=context_id) # Azure always sends a 44-byte header. Strip it off. yield TTSAudioRawFrame( audio=result.audio_data[44:], @@ -904,7 +892,6 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService): num_channels=1, context_id=context_id, ) - yield TTSStoppedFrame(context_id=context_id) elif result.reason == ResultReason.Canceled: cancellation_details = result.cancellation_details logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}") diff --git a/src/pipecat/services/camb/tts.py b/src/pipecat/services/camb/tts.py index ef007181e..b30726fda 100644 --- a/src/pipecat/services/camb/tts.py +++ b/src/pipecat/services/camb/tts.py @@ -29,8 +29,6 @@ from pipecat.frames.frames import ( Frame, StartFrame, TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, ) from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param from pipecat.services.tts_service import TTSService @@ -271,6 +269,8 @@ class CambTTSService(TTSService): super().__init__( sample_rate=sample_rate, + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -332,8 +332,6 @@ class CambTTSService(TTSService): text = text[:3000] try: - await self.start_ttfb_metrics() - # Build SDK parameters tts_kwargs: Dict[str, Any] = { "text": text, @@ -348,7 +346,6 @@ class CambTTSService(TTSService): tts_kwargs["user_instructions"] = self._settings.user_instructions await self.start_tts_usage_metrics(text) - yield TTSStartedFrame(context_id=context_id) assert self._client is not None, "Camb.ai TTS service not initialized" @@ -384,5 +381,3 @@ class CambTTSService(TTSService): except Exception as e: yield ErrorFrame(error=f"Camb.ai TTS error: {e}") - finally: - yield TTSStoppedFrame(context_id=context_id) diff --git a/src/pipecat/services/cartesia/tts.py b/src/pipecat/services/cartesia/tts.py index 166aa70af..3f708d06f 100644 --- a/src/pipecat/services/cartesia/tts.py +++ b/src/pipecat/services/cartesia/tts.py @@ -27,7 +27,7 @@ from pipecat.frames.frames import ( TTSStoppedFrame, ) from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param -from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService +from pipecat.services.tts_service import TextAggregationMode, TTSService, WebsocketTTSService from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.text.base_text_aggregator import BaseTextAggregator from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator @@ -203,7 +203,7 @@ class CartesiaTTSSettings(TTSSettings): pronunciation_dict_id: str | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN) -class CartesiaTTSService(AudioContextTTSService): +class CartesiaTTSService(WebsocketTTSService): """Cartesia TTS service with WebSocket streaming and word timestamps. Provides text-to-speech using Cartesia's streaming WebSocket API. @@ -334,9 +334,9 @@ class CartesiaTTSService(AudioContextTTSService): text_aggregation_mode=text_aggregation_mode, aggregate_sentences=aggregate_sentences, push_text_frames=False, - pause_frame_processing=True, - supports_word_timestamps=True, + pause_frame_processing=False, sample_rate=sample_rate, + push_start_frame=True, text_aggregator=text_aggregator, settings=default_settings, **kwargs, @@ -452,7 +452,11 @@ class CartesiaTTSService(AudioContextTTSService): return list(zip(words, starts)) def _build_msg( - self, text: str = "", continue_transcript: bool = True, add_timestamps: bool = True + self, + text: str = "", + continue_transcript: bool = True, + add_timestamps: bool = True, + context_id: str = "", ): voice_config = {} voice_config["mode"] = "id" @@ -461,7 +465,7 @@ class CartesiaTTSService(AudioContextTTSService): msg = { "transcript": text, "continue": continue_transcript, - "context_id": self.get_active_audio_context_id(), + "context_id": context_id, "model_id": self._settings.model, "voice": voice_config, "output_format": { @@ -580,15 +584,19 @@ class CartesiaTTSService(AudioContextTTSService): """ pass - async def flush_audio(self): - """Flush any pending audio and finalize the current context.""" - context_id = self.get_active_audio_context_id() - if not context_id or not self._websocket: + async def flush_audio(self, context_id: Optional[str] = None): + """Flush any pending audio and finalize the current context. + + Args: + context_id: The specific context to flush. If None, falls back to the + currently active context. + """ + flush_id = context_id or self.get_active_audio_context_id() + if not flush_id or not self._websocket: return logger.trace(f"{self}: flushing audio") - msg = self._build_msg(text="", continue_transcript=False) + msg = self._build_msg(text="", continue_transcript=False, context_id=flush_id) await self._websocket.send(msg) - self.reset_active_audio_context() async def _process_messages(self): async for message in self._get_websocket(): @@ -607,8 +615,6 @@ class CartesiaTTSService(AudioContextTTSService): ) await self.add_word_timestamps(processed_timestamps, ctx_id) elif msg["type"] == "chunk": - await self.stop_ttfb_metrics() - await self.start_word_timestamps() frame = TTSAudioRawFrame( audio=base64.b64decode(msg["data"]), sample_rate=self.sample_rate, @@ -652,12 +658,7 @@ class CartesiaTTSService(AudioContextTTSService): if not self._websocket or self._websocket.state is State.CLOSED: await self._connect() - if not self.has_active_audio_context(): - await self.start_ttfb_metrics() - yield TTSStartedFrame(context_id=context_id) - await self.create_audio_context(context_id) - - msg = self._build_msg(text=text) + msg = self._build_msg(text=text, context_id=context_id) try: await self._get_websocket().send(msg) @@ -777,6 +778,8 @@ class CartesiaHttpTTSService(TTSService): super().__init__( sample_rate=sample_rate, + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -863,8 +866,6 @@ class CartesiaHttpTTSService(TTSService): try: voice_config = {"mode": "id", "id": self._settings.voice} - await self.start_ttfb_metrics() - output_format = { "container": self._output_container, "encoding": self._output_encoding, @@ -889,8 +890,6 @@ class CartesiaHttpTTSService(TTSService): if self._settings.pronunciation_dict_id: payload["pronunciation_dict_id"] = self._settings.pronunciation_dict_id - yield TTSStartedFrame(context_id=context_id) - headers = { "Cartesia-Version": self._cartesia_version, "X-API-Key": self._api_key, @@ -922,4 +921,3 @@ class CartesiaHttpTTSService(TTSService): yield ErrorFrame(error=f"Unknown error occurred: {e}") finally: await self.stop_ttfb_metrics() - yield TTSStoppedFrame(context_id=context_id) diff --git a/src/pipecat/services/deepgram/sagemaker/tts.py b/src/pipecat/services/deepgram/sagemaker/tts.py index 3693178a1..9e8c30ad7 100644 --- a/src/pipecat/services/deepgram/sagemaker/tts.py +++ b/src/pipecat/services/deepgram/sagemaker/tts.py @@ -328,7 +328,7 @@ class DeepgramSageMakerTTSService(TTSService): except Exception as e: logger.error(f"{self} error sending Clear message: {e}") - async def flush_audio(self): + async def flush_audio(self, context_id: Optional[str] = None): """Flush any pending audio synthesis by sending Flush command. This should be called when the LLM finishes a complete response to force @@ -355,12 +355,12 @@ class DeepgramSageMakerTTSService(TTSService): logger.debug(f"{self}: Generating TTS [{text}]") try: - if not self._ttfb_started: - await self.start_ttfb_metrics() - self._ttfb_started = True - await self.start_tts_usage_metrics(text) - - yield TTSStartedFrame(context_id=context_id) + if not self.audio_context_available(context_id): + await self.create_audio_context(context_id) + if not self._ttfb_started: + await self.start_ttfb_metrics() + self._ttfb_started = True + yield TTSStartedFrame(context_id=context_id) self._context_id = context_id await self._client.send_json({"type": "Speak", "text": text}) diff --git a/src/pipecat/services/deepgram/tts.py b/src/pipecat/services/deepgram/tts.py index 95db998e2..6c8685bee 100644 --- a/src/pipecat/services/deepgram/tts.py +++ b/src/pipecat/services/deepgram/tts.py @@ -26,8 +26,6 @@ from pipecat.frames.frames import ( LLMFullResponseEndFrame, StartFrame, TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, ) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.settings import TTSSettings, _warn_deprecated_param @@ -120,6 +118,7 @@ class DeepgramTTSService(WebsocketTTSService): sample_rate=sample_rate, pause_frame_processing=True, push_stop_frames=True, + push_start_frame=True, append_trailing_space=True, settings=default_settings, **kwargs, @@ -130,7 +129,6 @@ class DeepgramTTSService(WebsocketTTSService): self._encoding = encoding self._receive_task = None - self._context_id: Optional[str] = None def can_generate_metrics(self) -> bool: """Check if the service can generate metrics. @@ -267,7 +265,6 @@ class DeepgramTTSService(WebsocketTTSService): logger.error(f"{self} exception: {e}") await self.push_error(ErrorFrame(error=f"{self} error: {e}")) finally: - self._context_id = None self._websocket = None await self._call_event_handler("on_disconnected") @@ -299,7 +296,9 @@ class DeepgramTTSService(WebsocketTTSService): if isinstance(message, bytes): # Binary message contains audio data await self.stop_ttfb_metrics() - frame = TTSAudioRawFrame(message, self.sample_rate, 1, context_id=self._context_id) + frame = TTSAudioRawFrame( + message, self.sample_rate, 1, context_id=self.get_active_audio_context_id() + ) await self.push_frame(frame) elif isinstance(message, str): # Text message contains metadata or control messages @@ -326,7 +325,7 @@ class DeepgramTTSService(WebsocketTTSService): except json.JSONDecodeError: logger.error(f"Invalid JSON message: {message}") - async def flush_audio(self): + async def flush_audio(self, context_id: Optional[str] = None): """Flush any pending audio synthesis by sending Flush command. This should be called when the LLM finishes a complete response to force @@ -357,13 +356,8 @@ class DeepgramTTSService(WebsocketTTSService): if not self._websocket or self._websocket.state is State.CLOSED: await self._connect() - await self.start_ttfb_metrics() await self.start_tts_usage_metrics(text) - yield TTSStartedFrame(context_id=context_id) - # Store context_id for use in _receive_messages - self._context_id = context_id - # Send text message to Deepgram # Note: We don't send Flush here - that should only be sent when the # LLM finishes a complete response via flush_audio() @@ -435,6 +429,8 @@ class DeepgramHttpTTSService(TTSService): super().__init__( sample_rate=sample_rate, + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -492,7 +488,6 @@ class DeepgramHttpTTSService(TTSService): raise Exception(f"HTTP {response.status}: {error_text}") await self.start_tts_usage_metrics(text) - yield TTSStartedFrame(context_id=context_id) CHUNK_SIZE = self.chunk_size @@ -510,7 +505,5 @@ class DeepgramHttpTTSService(TTSService): context_id=context_id, ) - yield TTSStoppedFrame(context_id=context_id) - except Exception as e: yield ErrorFrame(f"Error getting audio: {str(e)}") diff --git a/src/pipecat/services/elevenlabs/tts.py b/src/pipecat/services/elevenlabs/tts.py index cfb08eb6d..de930d1f2 100644 --- a/src/pipecat/services/elevenlabs/tts.py +++ b/src/pipecat/services/elevenlabs/tts.py @@ -46,9 +46,9 @@ from pipecat.frames.frames import ( from pipecat.processors.frame_processor import FrameDirection from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param from pipecat.services.tts_service import ( - AudioContextTTSService, TextAggregationMode, TTSService, + WebsocketTTSService, ) from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.tracing.service_decorators import traced_tts @@ -308,7 +308,7 @@ def calculate_word_times( return (word_times, new_partial_word, new_partial_word_start_time) -class ElevenLabsTTSService(AudioContextTTSService): +class ElevenLabsTTSService(WebsocketTTSService): """ElevenLabs WebSocket-based TTS service with word timestamps. Provides real-time text-to-speech using ElevenLabs' WebSocket streaming API. @@ -479,7 +479,6 @@ class ElevenLabsTTSService(AudioContextTTSService): push_text_frames=False, push_stop_frames=True, pause_frame_processing=True, - supports_word_timestamps=True, sample_rate=sample_rate, settings=default_settings, **kwargs, @@ -559,20 +558,15 @@ class ElevenLabsTTSService(AudioContextTTSService): ) await self._disconnect() await self._connect() - elif voice_settings_changed and self.has_active_audio_context(): + elif voice_settings_changed: logger.debug( f"Voice settings changed ({changed.keys() & ElevenLabsTTSSettings.VOICE_SETTINGS_FIELDS}), " f"closing current context to apply changes" ) - context_id = self.get_active_audio_context_id() - try: - if self._websocket: - await self._websocket.send( - json.dumps({"context_id": context_id, "close_context": True}) - ) - except Exception as e: - await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) - self.reset_active_audio_context() + audio_contexts = self.get_audio_contexts() + if audio_contexts: + for ctx_id in audio_contexts: + await self._close_context(ctx_id) if not url_changed: # Reconnect applies all settings; only warn about fields not handled @@ -610,13 +604,18 @@ class ElevenLabsTTSService(AudioContextTTSService): await super().cancel(frame) await self._disconnect() - async def flush_audio(self): - """Flush any pending audio and finalize the current context.""" - context_id = self.get_active_audio_context_id() - if not context_id or not self._websocket: + async def flush_audio(self, context_id: Optional[str] = None): + """Flush any pending audio and finalize the current context. + + Args: + context_id: The specific context to flush. If None, falls back to the + currently active context. + """ + flush_id = context_id or self.get_active_audio_context_id() + if not flush_id or not self._websocket: return logger.trace(f"{self}: flushing audio") - msg = {"context_id": context_id, "flush": True} + msg = {"context_id": flush_id, "flush": True} await self._websocket.send(json.dumps(msg)) async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): @@ -703,9 +702,7 @@ class ElevenLabsTTSService(AudioContextTTSService): if self._websocket: logger.debug("Disconnecting from ElevenLabs") - # Close all contexts and the socket - if self.has_active_audio_context(): - await self._websocket.send(json.dumps({"close_socket": True})) + await self._websocket.send(json.dumps({"close_socket": True})) await self._websocket.close() logger.debug("Disconnected from ElevenLabs") except Exception as e: @@ -737,6 +734,7 @@ class ElevenLabsTTSService(AudioContextTTSService): ) except Exception as e: await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) + self._cumulative_time = 0.0 self._partial_word = "" self._partial_word_start_time = 0.0 @@ -782,9 +780,6 @@ class ElevenLabsTTSService(AudioContextTTSService): continue if msg.get("audio"): - await self.stop_ttfb_metrics() - await self.start_word_timestamps() - audio = base64.b64decode(msg["audio"]) frame = TTSAudioRawFrame(audio, self.sample_rate, 1, context_id=received_ctx_id) await self.append_to_audio_context(received_ctx_id, frame) @@ -845,9 +840,8 @@ class ElevenLabsTTSService(AudioContextTTSService): logger.warning(f"{self} keepalive error: {e}") break - async def _send_text(self, text: str): + async def _send_text(self, text: str, context_id: str): """Send text to the WebSocket for synthesis.""" - context_id = self.get_active_audio_context_id() if self._websocket and context_id: msg = {"text": text, "context_id": context_id} await self._websocket.send(json.dumps(msg)) @@ -870,16 +864,14 @@ class ElevenLabsTTSService(AudioContextTTSService): await self._connect() try: - if not self.has_active_audio_context(): + if not self.audio_context_available(context_id): + await self.create_audio_context(context_id) await self.start_ttfb_metrics() yield TTSStartedFrame(context_id=context_id) self._cumulative_time = 0 self._partial_word = "" self._partial_word_start_time = 0.0 - if not self.audio_context_available(context_id): - await self.create_audio_context(context_id) - # Initialize context with voice settings and pronunciation dictionaries msg = {"text": " ", "context_id": context_id} if self._voice_settings: @@ -892,7 +884,7 @@ class ElevenLabsTTSService(AudioContextTTSService): await self._websocket.send(json.dumps(msg)) logger.trace(f"Created new context {context_id}") - await self._send_text(text) + await self._send_text(text, context_id) await self.start_tts_usage_metrics(text) except Exception as e: yield TTSStoppedFrame(context_id=context_id) @@ -1046,7 +1038,7 @@ class ElevenLabsHttpTTSService(TTSService): aggregate_sentences=aggregate_sentences, push_text_frames=False, push_stop_frames=True, - supports_word_timestamps=True, + push_start_frame=True, sample_rate=sample_rate, settings=default_settings, **kwargs, @@ -1266,8 +1258,6 @@ class ElevenLabsHttpTTSService(TTSService): params["optimize_streaming_latency"] = self._settings.optimize_streaming_latency try: - await self.start_ttfb_metrics() - async with self._session.post( url, json=payload, headers=headers, params=params ) as response: @@ -1278,10 +1268,6 @@ class ElevenLabsHttpTTSService(TTSService): await self.start_tts_usage_metrics(text) - # Start TTS sequence - await self.start_word_timestamps() - yield TTSStartedFrame(context_id=context_id) - # Track the duration of this utterance based on the last character's end time utterance_duration = 0 async for line in response.content: @@ -1347,4 +1333,3 @@ class ElevenLabsHttpTTSService(TTSService): yield ErrorFrame(error=f"Unknown error occurred: {e}") finally: await self.stop_ttfb_metrics() - # Let the parent class handle TTSStoppedFrame diff --git a/src/pipecat/services/fish/tts.py b/src/pipecat/services/fish/tts.py index 9ea749546..64c3bccd9 100644 --- a/src/pipecat/services/fish/tts.py +++ b/src/pipecat/services/fish/tts.py @@ -209,6 +209,7 @@ class FishAudioTTSService(InterruptibleTTSService): super().__init__( push_stop_frames=True, + push_start_frame=True, pause_frame_processing=True, sample_rate=sample_rate, settings=default_settings, @@ -219,7 +220,6 @@ class FishAudioTTSService(InterruptibleTTSService): self._base_url = "wss://api.fish.audio/v1/tts/live" self._websocket = None self._receive_task = None - self._request_id = None # Init-only audio format config (not runtime-updatable). self._fish_sample_rate = 0 # Set in start() @@ -341,11 +341,10 @@ class FishAudioTTSService(InterruptibleTTSService): except Exception as e: await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) finally: - self._request_id = None self._websocket = None await self._call_event_handler("on_disconnected") - async def flush_audio(self): + async def flush_audio(self, context_id: Optional[str] = None): """Flush any buffered audio by sending a flush event to Fish Audio.""" logger.trace(f"{self}: Flushing audio buffers") if not self._websocket or self._websocket.state is State.CLOSED: @@ -361,7 +360,6 @@ class FishAudioTTSService(InterruptibleTTSService): async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection): await super()._handle_interruption(frame, direction) await self.stop_all_metrics() - self._request_id = None async def _receive_messages(self): async for message in self._get_websocket(): @@ -398,12 +396,6 @@ class FishAudioTTSService(InterruptibleTTSService): if not self._websocket or self._websocket.state is State.CLOSED: await self._connect() - if not self._request_id: - await self.start_ttfb_metrics() - await self.start_tts_usage_metrics(text) - yield TTSStartedFrame(context_id=context_id) - self._request_id = str(uuid.uuid4()) - # Send the text text_message = { "event": "text", diff --git a/src/pipecat/services/google/tts.py b/src/pipecat/services/google/tts.py index 7cf2ee996..071d731b1 100644 --- a/src/pipecat/services/google/tts.py +++ b/src/pipecat/services/google/tts.py @@ -34,8 +34,6 @@ from pipecat.frames.frames import ( Frame, StartFrame, TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, ) from pipecat.services.settings import ( NOT_GIVEN, @@ -655,6 +653,8 @@ class GoogleHttpTTSService(TTSService): super().__init__( sample_rate=sample_rate, + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -803,8 +803,6 @@ class GoogleHttpTTSService(TTSService): logger.debug(f"{self}: Generating TTS [{text}]") try: - await self.start_ttfb_metrics() - # Check if the voice is a Chirp voice (including Chirp 3) or Journey voice is_chirp_voice = "chirp" in self._settings.voice.lower() is_journey_voice = "journey" in self._settings.voice.lower() @@ -840,8 +838,6 @@ class GoogleHttpTTSService(TTSService): await self.start_tts_usage_metrics(text) - yield TTSStartedFrame(context_id=context_id) - # Skip the first 44 bytes to remove the WAV header audio_content = response.audio_content[44:] @@ -855,8 +851,6 @@ class GoogleHttpTTSService(TTSService): frame = TTSAudioRawFrame(chunk, self.sample_rate, 1, context_id=context_id) yield frame - yield TTSStoppedFrame(context_id=context_id) - except Exception as e: error_message = f"TTS generation error: {str(e)}" yield ErrorFrame(error=error_message) @@ -967,8 +961,6 @@ class GoogleBaseTTSService(TTSService): streaming_responses = await self._client.streaming_synthesize(request_generator()) await self.start_tts_usage_metrics(text) - yield TTSStartedFrame(context_id=context_id) - audio_buffer = b"" first_chunk_for_ttfb = False @@ -992,8 +984,6 @@ class GoogleBaseTTSService(TTSService): if audio_buffer: yield TTSAudioRawFrame(audio_buffer, self.sample_rate, 1, context_id=context_id) - yield TTSStoppedFrame(context_id=context_id) - class GoogleTTSService(GoogleBaseTTSService): """Google Cloud Text-to-Speech streaming service. @@ -1096,6 +1086,8 @@ class GoogleTTSService(GoogleBaseTTSService): super().__init__( sample_rate=sample_rate, + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -1135,8 +1127,6 @@ class GoogleTTSService(GoogleBaseTTSService): logger.debug(f"{self}: Generating TTS [{text}]") try: - await self.start_ttfb_metrics() - # Build voice selection params if self._voice_cloning_key: voice_clone_params = texttospeech_v1.VoiceCloneParams( @@ -1352,6 +1342,8 @@ class GeminiTTSService(GoogleBaseTTSService): super().__init__( sample_rate=sample_rate, + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -1414,8 +1406,6 @@ class GeminiTTSService(GoogleBaseTTSService): logger.debug(f"{self}: Generating TTS [{text}]") try: - await self.start_ttfb_metrics() - # Build voice selection params if self._settings.multi_speaker and self._settings.speaker_configs: # Multi-speaker mode diff --git a/src/pipecat/services/gradium/tts.py b/src/pipecat/services/gradium/tts.py index 2ade14367..745a77f56 100644 --- a/src/pipecat/services/gradium/tts.py +++ b/src/pipecat/services/gradium/tts.py @@ -19,11 +19,10 @@ from pipecat.frames.frames import ( Frame, StartFrame, TTSAudioRawFrame, - TTSStartedFrame, TTSStoppedFrame, ) from pipecat.services.settings import TTSSettings, _warn_deprecated_param -from pipecat.services.tts_service import AudioContextTTSService +from pipecat.services.tts_service import WebsocketTTSService from pipecat.utils.tracing.service_decorators import traced_tts try: @@ -45,7 +44,7 @@ class GradiumTTSSettings(TTSSettings): pass -class GradiumTTSService(AudioContextTTSService): +class GradiumTTSService(WebsocketTTSService): """Text-to-Speech service using Gradium's websocket API.""" _settings: GradiumTTSSettings @@ -125,9 +124,9 @@ class GradiumTTSService(AudioContextTTSService): super().__init__( push_stop_frames=True, + push_start_frame=True, push_text_frames=False, pause_frame_processing=True, - supports_word_timestamps=True, sample_rate=SAMPLE_RATE, settings=default_settings, **kwargs, @@ -166,12 +165,9 @@ class GradiumTTSService(AudioContextTTSService): self._warn_unhandled_updated_settings(changed) return changed - def _build_msg(self, text: str = "") -> dict: + def _build_msg(self, text: str = "", context_id: str = "") -> dict: """Build JSON message for Gradium API.""" - msg = {"text": text, "type": "text"} - context_id = self.get_active_audio_context_id() - if context_id: - msg["client_req_id"] = context_id + msg = {"text": text, "type": "text", "client_req_id": context_id} return msg async def start(self, frame: StartFrame): @@ -280,15 +276,14 @@ class GradiumTTSService(AudioContextTTSService): return self._websocket raise Exception("Websocket not connected") - async def flush_audio(self): + async def flush_audio(self, context_id: Optional[str] = None): """Flush any pending audio synthesis.""" - context_id = self.get_active_audio_context_id() - if not context_id or not self._websocket: + flush_id = context_id or self.get_active_audio_context_id() + if not flush_id or not self._websocket: return try: - msg = {"type": "end_of_stream", "client_req_id": context_id} + msg = {"type": "end_of_stream", "client_req_id": flush_id} await self._websocket.send(json.dumps(msg)) - self.reset_active_audio_context() except ConnectionClosedOK: logger.debug(f"{self}: connection closed normally during flush") except Exception as e: @@ -326,8 +321,6 @@ class GradiumTTSService(AudioContextTTSService): if msg["type"] == "audio": if not ctx_id or not self.audio_context_available(ctx_id): continue - await self.stop_ttfb_metrics() - await self.start_word_timestamps() frame = TTSAudioRawFrame( audio=base64.b64decode(msg["audio"]), sample_rate=self.sample_rate, @@ -369,12 +362,7 @@ class GradiumTTSService(AudioContextTTSService): await self._connect() try: - if not self.has_active_audio_context(): - await self.start_ttfb_metrics() - yield TTSStartedFrame(context_id=context_id) - await self.create_audio_context(context_id) - - msg = self._build_msg(text=text) + msg = self._build_msg(text=text, context_id=context_id) await self._get_websocket().send(json.dumps(msg)) await self.start_tts_usage_metrics(text) except Exception as e: diff --git a/src/pipecat/services/groq/tts.py b/src/pipecat/services/groq/tts.py index 18b623fc8..139816834 100644 --- a/src/pipecat/services/groq/tts.py +++ b/src/pipecat/services/groq/tts.py @@ -18,8 +18,6 @@ from pipecat.frames.frames import ( ErrorFrame, Frame, TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, ) from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param from pipecat.services.tts_service import TTSService @@ -140,6 +138,8 @@ class GroqTTSService(TTSService): super().__init__( pause_frame_processing=True, + push_start_frame=True, + push_stop_frames=True, sample_rate=sample_rate, settings=default_settings, **kwargs, @@ -171,9 +171,6 @@ class GroqTTSService(TTSService): """ logger.debug(f"{self}: Generating TTS [{text}]") measuring_ttfb = True - await self.start_ttfb_metrics() - yield TTSStartedFrame(context_id=context_id) - try: response = await self._client.audio.speech.create( model=self._settings.model, @@ -198,5 +195,3 @@ class GroqTTSService(TTSService): yield TTSAudioRawFrame(bytes, frame_rate, channels, context_id=context_id) except Exception as e: yield ErrorFrame(error=f"Unknown error occurred: {e}") - - yield TTSStoppedFrame(context_id=context_id) diff --git a/src/pipecat/services/hume/tts.py b/src/pipecat/services/hume/tts.py index 052d1cd0a..ff5eb7522 100644 --- a/src/pipecat/services/hume/tts.py +++ b/src/pipecat/services/hume/tts.py @@ -22,7 +22,6 @@ from pipecat.frames.frames import ( InterruptionFrame, StartFrame, TTSAudioRawFrame, - TTSStartedFrame, TTSStoppedFrame, ) from pipecat.processors.frame_processor import FrameDirection @@ -166,7 +165,7 @@ class HumeTTSService(TTSService): sample_rate=sample_rate, push_text_frames=False, push_stop_frames=True, - supports_word_timestamps=True, + push_start_frame=True, settings=default_settings, **kwargs, ) @@ -181,7 +180,6 @@ class HumeTTSService(TTSService): # Track cumulative time for word timestamps across utterances self._cumulative_time = 0.0 - self._started = False def can_generate_metrics(self) -> bool: """Can generate metrics. @@ -203,7 +201,6 @@ class HumeTTSService(TTSService): def _reset_state(self): """Reset internal state variables.""" self._cumulative_time = 0.0 - self._started = False async def stop(self, frame: EndFrame) -> None: """Stop the service and cleanup resources. @@ -310,15 +307,8 @@ class HumeTTSService(TTSService): # Request raw PCM chunks in the streaming JSON pcm_fmt = FormatPcm(type="pcm") - await self.start_ttfb_metrics() await self.start_tts_usage_metrics(text) - # Start TTS sequence if not already started - if not self._started: - await self.start_word_timestamps() - yield TTSStartedFrame(context_id=context_id) - self._started = True - try: # Instant mode is always enabled here (not user-configurable) # Hume emits mono PCM at 48 kHz; downstream can resample if needed. @@ -395,4 +385,3 @@ class HumeTTSService(TTSService): finally: # Ensure TTFB timer is stopped even on early failures await self.stop_ttfb_metrics() - # Let the parent class handle TTSStoppedFrame via push_stop_frames diff --git a/src/pipecat/services/inworld/tts.py b/src/pipecat/services/inworld/tts.py index a1421b2ab..d602efb52 100644 --- a/src/pipecat/services/inworld/tts.py +++ b/src/pipecat/services/inworld/tts.py @@ -62,7 +62,7 @@ from pipecat.frames.frames import ( TTSStoppedFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService +from pipecat.services.tts_service import TextAggregationMode, TTSService, WebsocketTTSService from pipecat.utils.tracing.service_decorators import traced_tts @@ -212,7 +212,7 @@ class InworldHttpTTSService(TTSService): super().__init__( push_text_frames=False, push_stop_frames=True, - supports_word_timestamps=True, + push_start_frame=True, sample_rate=sample_rate, settings=default_settings, **kwargs, @@ -359,11 +359,6 @@ class InworldHttpTTSService(TTSService): } try: - await self.start_ttfb_metrics() - - await self.start_word_timestamps() - yield TTSStartedFrame(context_id=context_id) - async with self._session.post( self._base_url, json=payload, headers=headers ) as response: @@ -514,7 +509,7 @@ class InworldHttpTTSService(TTSService): ) -class InworldTTSService(AudioContextTTSService): +class InworldTTSService(WebsocketTTSService): """Inworld AI WebSocket-based TTS service. Uses bidirectional WebSocket for lower latency streaming. Supports multiple @@ -650,7 +645,6 @@ class InworldTTSService(AudioContextTTSService): push_text_frames=False, push_stop_frames=True, pause_frame_processing=True, - supports_word_timestamps=True, sample_rate=sample_rate, aggregate_sentences=aggregate_sentences, text_aggregation_mode=text_aggregation_mode, @@ -719,17 +713,17 @@ class InworldTTSService(AudioContextTTSService): await super().cancel(frame) await self._disconnect() - async def flush_audio(self): + async def flush_audio(self, context_id: Optional[str] = None): """Flush any pending audio without closing the context. This triggers synthesis of all accumulated text in the buffer while keeping the context open for subsequent text. The context is only closed on interruption, disconnect, or end of session. """ - context_id = self.get_active_audio_context_id() - if context_id and self._websocket: - logger.trace(f"Flushing audio for context {context_id}") - await self._send_flush(context_id) + flush_id = context_id or self.get_active_audio_context_id() + if flush_id and self._websocket: + logger.trace(f"Flushing audio for context {flush_id}") + await self._send_flush(flush_id) async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): """Push a frame and handle state changes. @@ -899,12 +893,10 @@ class InworldTTSService(AudioContextTTSService): if self._websocket: logger.debug("Disconnecting from Inworld WebSocket TTS") - context_id = self.get_active_audio_context_id() - if context_id: - try: - await self._send_close_context(context_id) - except Exception: - pass + audio_contexts = self.get_audio_contexts() + if audio_contexts: + for ctx_id in audio_contexts: + await self._send_close_context(ctx_id) await self._websocket.close() logger.debug("Disconnected from Inworld WebSocket TTS") except Exception as e: @@ -934,10 +926,7 @@ class InworldTTSService(AudioContextTTSService): for k in ["contextCreated", "audioChunk", "flushCompleted", "contextClosed"] if k in result ] - logger.debug( - f"{self}: Received message types={msg_types}, ctx_id={ctx_id}, " - f"current_ctx={self.get_active_audio_context_id()}, available={self.audio_context_available(ctx_id) if ctx_id else 'N/A'}" - ) + logger.debug(f"{self}: Received message types={msg_types}, ctx_id={ctx_id}") # Check for errors status = result.get("status", {}) @@ -948,9 +937,7 @@ class InworldTTSService(AudioContextTTSService): # Handle "Context not found" error (code 5) # This can happen when a keepalive message is sent but no context is available. if error_code == 5 and "not found" in error_msg.lower(): - logger.debug( - f"{self}: Context {ctx_id or self.get_active_audio_context_id()} not found." - ) + logger.debug(f"{self}: Context {ctx_id} not found.") continue # For other errors, push error frame @@ -961,17 +948,10 @@ class InworldTTSService(AudioContextTTSService): await self.push_error(error_msg=str(msg["error"])) continue - # Check if this message belongs to an available context. - # If the context isn't available but matches our current context ID, - # recreate it (handles race conditions during interruption recovery). + # If the context isn't available recreate it (handles race conditions during interruption recovery). if ctx_id and not self.audio_context_available(ctx_id): - if self.get_active_audio_context_id() == ctx_id: - logger.trace(f"{self}: Recreating audio context for current context: {ctx_id}") - await self.create_audio_context(ctx_id) - else: - # This is a message from an old/closed context - skip it - logger.trace(f"{self}: Skipping message from unavailable context: {ctx_id}") - continue + logger.trace(f"{self}: Recreating audio context for current context: {ctx_id}") + await self.create_audio_context(ctx_id) # Process audio chunk audio_chunk = result.get("audioChunk", {}) @@ -979,8 +959,6 @@ class InworldTTSService(AudioContextTTSService): if audio_b64: logger.trace(f"{self}: Processing audio chunk for context {ctx_id}") - await self.stop_ttfb_metrics() - await self.start_word_timestamps() audio = base64.b64decode(audio_b64) if len(audio) > 44 and audio.startswith(b"RIFF"): audio = audio[44:] @@ -1012,12 +990,8 @@ class InworldTTSService(AudioContextTTSService): if "contextClosed" in result: logger.trace(f"{self}: Context closed on server: {ctx_id}") await self.stop_ttfb_metrics() - # Only reset if this is our current context - if ctx_id == self.get_active_audio_context_id(): - self.reset_active_audio_context() - if ctx_id and self.audio_context_available(ctx_id): - await self.remove_audio_context(ctx_id) await self.add_word_timestamps([("TTSStoppedFrame", 0), ("Reset", 0)], ctx_id) + await self.remove_audio_context(ctx_id) async def _keepalive_task_handler(self): """Send periodic keepalive messages to maintain WebSocket connection.""" @@ -1128,10 +1102,10 @@ class InworldTTSService(AudioContextTTSService): await self._connect() try: - if not self.has_active_audio_context(): + if not self.audio_context_available(context_id): + await self.create_audio_context(context_id) await self.start_ttfb_metrics() yield TTSStartedFrame(context_id=context_id) - await self.create_audio_context(context_id) await self._send_context(context_id) await self._send_text(context_id, text) diff --git a/src/pipecat/services/kokoro/tts.py b/src/pipecat/services/kokoro/tts.py index 0646923f3..e69ef7a67 100644 --- a/src/pipecat/services/kokoro/tts.py +++ b/src/pipecat/services/kokoro/tts.py @@ -20,8 +20,6 @@ from pipecat.frames.frames import ( ErrorFrame, Frame, TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, ) from pipecat.services.settings import TTSSettings, _warn_deprecated_param from pipecat.services.tts_service import TTSService @@ -170,6 +168,8 @@ class KokoroTTSService(TTSService): default_settings.apply_update(settings) super().__init__( + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -212,9 +212,7 @@ class KokoroTTSService(TTSService): logger.debug(f"{self}: Generating TTS [{text}]") try: - await self.start_ttfb_metrics() await self.start_tts_usage_metrics(text) - yield TTSStartedFrame(context_id=context_id) stream = self._kokoro.create_stream( text, voice=self._settings.voice, lang=self._settings.language, speed=1.0 @@ -238,4 +236,3 @@ class KokoroTTSService(TTSService): yield ErrorFrame(error=f"Unknown error occurred: {e}") finally: await self.stop_ttfb_metrics() - yield TTSStoppedFrame(context_id=context_id) diff --git a/src/pipecat/services/lmnt/tts.py b/src/pipecat/services/lmnt/tts.py index 9ee6d8d60..c8bfcaf55 100644 --- a/src/pipecat/services/lmnt/tts.py +++ b/src/pipecat/services/lmnt/tts.py @@ -143,6 +143,7 @@ class LmntTTSService(InterruptibleTTSService): super().__init__( push_stop_frames=True, + push_start_frame=True, pause_frame_processing=True, sample_rate=sample_rate, settings=default_settings, @@ -152,7 +153,6 @@ class LmntTTSService(InterruptibleTTSService): self._api_key = api_key self._output_format = "raw" self._receive_task = None - self._context_id: Optional[str] = None def can_generate_metrics(self) -> bool: """Check if this service can generate processing metrics. @@ -289,7 +289,6 @@ class LmntTTSService(InterruptibleTTSService): except Exception as e: await self.push_error(error_msg=f"Error disconnecting from LMNT: {e}", exception=e) finally: - self._context_id = None self._websocket = None await self._call_event_handler("on_disconnected") @@ -299,7 +298,7 @@ class LmntTTSService(InterruptibleTTSService): return self._websocket raise Exception("Websocket not connected") - async def flush_audio(self): + async def flush_audio(self, context_id: Optional[str] = None): """Flush any pending audio synthesis.""" if not self._websocket or self._websocket.state is State.CLOSED: return @@ -315,7 +314,7 @@ class LmntTTSService(InterruptibleTTSService): audio=message, sample_rate=self.sample_rate, num_channels=1, - context_id=self._context_id, + context_id=self.get_active_audio_context_id(), ) await self.push_frame(frame) else: @@ -347,11 +346,6 @@ class LmntTTSService(InterruptibleTTSService): await self._connect() try: - await self.start_ttfb_metrics() - # Store context_id for use in _receive_messages - self._context_id = context_id - yield TTSStartedFrame(context_id=context_id) - # Send text to LMNT await self._get_websocket().send(json.dumps({"text": text})) # Force synthesis diff --git a/src/pipecat/services/minimax/tts.py b/src/pipecat/services/minimax/tts.py index efe8c1fd9..33e0669e1 100644 --- a/src/pipecat/services/minimax/tts.py +++ b/src/pipecat/services/minimax/tts.py @@ -23,8 +23,6 @@ from pipecat.frames.frames import ( Frame, StartFrame, TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, ) from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param from pipecat.services.tts_service import TTSService @@ -305,6 +303,8 @@ class MiniMaxHttpTTSService(TTSService): super().__init__( sample_rate=sample_rate, + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -402,8 +402,6 @@ class MiniMaxHttpTTSService(TTSService): payload["language_boost"] = self._settings.language_boost try: - await self.start_ttfb_metrics() - async with self._session.post( self._base_url, headers=headers, json=payload ) as response: @@ -413,7 +411,6 @@ class MiniMaxHttpTTSService(TTSService): return await self.start_tts_usage_metrics(text) - yield TTSStartedFrame(context_id=context_id) # Process the streaming response buffer = bytearray() @@ -490,4 +487,3 @@ class MiniMaxHttpTTSService(TTSService): yield ErrorFrame(error=f"Unknown error occurred: {e}", exception=e) finally: await self.stop_ttfb_metrics() - yield TTSStoppedFrame(context_id=context_id) diff --git a/src/pipecat/services/neuphonic/tts.py b/src/pipecat/services/neuphonic/tts.py index da14ec2e5..c1916414e 100644 --- a/src/pipecat/services/neuphonic/tts.py +++ b/src/pipecat/services/neuphonic/tts.py @@ -180,6 +180,7 @@ class NeuphonicTTSService(InterruptibleTTSService): aggregate_sentences=aggregate_sentences, text_aggregation_mode=text_aggregation_mode, push_stop_frames=True, + push_start_frame=True, stop_frame_timeout_s=2.0, sample_rate=sample_rate, settings=default_settings, @@ -188,12 +189,8 @@ class NeuphonicTTSService(InterruptibleTTSService): self._api_key = api_key self._url = url - - self._cumulative_time = 0 - self._receive_task = None self._keepalive_task = None - self._context_id: Optional[str] = None self._encoding = encoding self._sampling_rate = sample_rate @@ -252,7 +249,7 @@ class NeuphonicTTSService(InterruptibleTTSService): await super().cancel(frame) await self._disconnect() - async def flush_audio(self): + async def flush_audio(self, context_id: Optional[str] = None): """Flush any pending audio synthesis by sending stop command.""" if self._websocket: msg = {"text": ""} @@ -358,7 +355,6 @@ class NeuphonicTTSService(InterruptibleTTSService): except Exception as e: await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) finally: - self._context_id = None self._websocket = None await self._call_event_handler("on_disconnected") @@ -372,7 +368,7 @@ class NeuphonicTTSService(InterruptibleTTSService): audio = base64.b64decode(msg["data"]["audio"]) frame = TTSAudioRawFrame( - audio, self.sample_rate, 1, context_id=self._context_id + audio, self.sample_rate, 1, context_id=self.get_active_audio_context_id() ) await self.push_frame(frame) @@ -415,12 +411,6 @@ class NeuphonicTTSService(InterruptibleTTSService): await self._connect() try: - await self.start_ttfb_metrics() - # Store context_id for use in _receive_messages - self._context_id = context_id - yield TTSStartedFrame(context_id=context_id) - self._cumulative_time = 0 - await self._send_text(text) await self.start_tts_usage_metrics(text) except Exception as e: @@ -523,6 +513,8 @@ class NeuphonicHttpTTSService(TTSService): super().__init__( sample_rate=sample_rate, + push_stop_frames=True, + push_start_frame=True, settings=default_settings, **kwargs, ) @@ -559,7 +551,7 @@ class NeuphonicHttpTTSService(TTSService): """ await super().start(frame) - async def flush_audio(self): + async def flush_audio(self, context_id: Optional[str] = None): """Flush any pending audio synthesis. Note: @@ -633,8 +625,6 @@ class NeuphonicHttpTTSService(TTSService): payload["voice_id"] = self._settings.voice try: - await self.start_ttfb_metrics() - async with self._session.post(url, json=payload, headers=headers) as response: if response.status != 200: error_text = await response.text() @@ -643,7 +633,6 @@ class NeuphonicHttpTTSService(TTSService): return await self.start_tts_usage_metrics(text) - yield TTSStartedFrame(context_id=context_id) # Process SSE stream line by line async for line in response.content: @@ -681,4 +670,3 @@ class NeuphonicHttpTTSService(TTSService): yield ErrorFrame(error=f"Unknown error occurred: {e}") finally: await self.stop_ttfb_metrics() - yield TTSStoppedFrame(context_id=context_id) diff --git a/src/pipecat/services/nvidia/tts.py b/src/pipecat/services/nvidia/tts.py index 6e3298a75..7f7638f5f 100644 --- a/src/pipecat/services/nvidia/tts.py +++ b/src/pipecat/services/nvidia/tts.py @@ -28,8 +28,6 @@ from pipecat.frames.frames import ( Frame, StartFrame, TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, ) from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param from pipecat.services.tts_service import TTSService @@ -145,6 +143,8 @@ class NvidiaTTSService(TTSService): super().__init__( sample_rate=sample_rate, + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -271,9 +271,6 @@ class NvidiaTTSService(TTSService): assert self._service is not None, "TTS service not initialized" assert self._config is not None, "Synthesis configuration not created" - await self.start_ttfb_metrics() - yield TTSStartedFrame(context_id=context_id) - logger.debug(f"{self}: Generating TTS [{text}]") responses = await asyncio.to_thread(read_audio_responses) @@ -289,7 +286,6 @@ class NvidiaTTSService(TTSService): yield frame await self.start_tts_usage_metrics(text) - yield TTSStoppedFrame(context_id=context_id) except asyncio.TimeoutError as e: logger.error(f"{self} timeout waiting for audio response") yield ErrorFrame(error=f"{self} error: {e}") diff --git a/src/pipecat/services/openai/tts.py b/src/pipecat/services/openai/tts.py index b4933ae9f..a50129349 100644 --- a/src/pipecat/services/openai/tts.py +++ b/src/pipecat/services/openai/tts.py @@ -22,8 +22,6 @@ from pipecat.frames.frames import ( Frame, StartFrame, TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, ) from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param from pipecat.services.tts_service import TTSService @@ -194,6 +192,8 @@ class OpenAITTSService(TTSService): super().__init__( sample_rate=sample_rate, + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -234,8 +234,6 @@ class OpenAITTSService(TTSService): """ logger.debug(f"{self}: Generating TTS [{text}]") try: - await self.start_ttfb_metrics() - # Setup API parameters create_params = { "input": text, @@ -267,12 +265,10 @@ class OpenAITTSService(TTSService): CHUNK_SIZE = self.chunk_size - yield TTSStartedFrame(context_id=context_id) async for chunk in r.iter_bytes(CHUNK_SIZE): if len(chunk) > 0: await self.stop_ttfb_metrics() frame = TTSAudioRawFrame(chunk, self.sample_rate, 1, context_id=context_id) yield frame - yield TTSStoppedFrame(context_id=context_id) except BadRequestError as e: yield ErrorFrame(error=f"Unknown error occurred: {e}") diff --git a/src/pipecat/services/piper/tts.py b/src/pipecat/services/piper/tts.py index 46bf98cd1..f0343947b 100644 --- a/src/pipecat/services/piper/tts.py +++ b/src/pipecat/services/piper/tts.py @@ -17,8 +17,6 @@ from loguru import logger from pipecat.frames.frames import ( ErrorFrame, Frame, - TTSStartedFrame, - TTSStoppedFrame, ) from pipecat.services.settings import TTSSettings, _warn_deprecated_param from pipecat.services.tts_service import TTSService @@ -91,6 +89,8 @@ class PiperTTSService(TTSService): default_settings.apply_update(settings) super().__init__( + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -159,12 +159,8 @@ class PiperTTSService(TTSService): logger.debug(f"{self}: Generating TTS [{text}]") try: - await self.start_ttfb_metrics() - await self.start_tts_usage_metrics(text) - yield TTSStartedFrame(context_id=context_id) - async for frame in self._stream_audio_frames_from_iterator( async_iterator(self._voice.synthesize(text)), in_sample_rate=self._voice.config.sample_rate, @@ -178,7 +174,6 @@ class PiperTTSService(TTSService): finally: logger.debug(f"{self}: Finished TTS [{text}]") await self.stop_ttfb_metrics() - yield TTSStoppedFrame(context_id=context_id) # This assumes a running TTS service running: @@ -244,6 +239,8 @@ class PiperHttpTTSService(TTSService): default_settings.apply_update(settings) super().__init__( + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -279,8 +276,6 @@ class PiperHttpTTSService(TTSService): "Content-Type": "application/json", } try: - await self.start_ttfb_metrics() - data = { "text": text, "voice": self._settings.voice, @@ -296,8 +291,6 @@ class PiperHttpTTSService(TTSService): await self.start_tts_usage_metrics(text) - yield TTSStartedFrame(context_id=context_id) - CHUNK_SIZE = self.chunk_size async for frame in self._stream_audio_frames_from_iterator( @@ -311,4 +304,3 @@ class PiperHttpTTSService(TTSService): yield ErrorFrame(error=f"Unknown error occurred: {e}") finally: await self.stop_ttfb_metrics() - yield TTSStoppedFrame(context_id=context_id) diff --git a/src/pipecat/services/resembleai/tts.py b/src/pipecat/services/resembleai/tts.py index e1c1cf68a..45f8fc229 100644 --- a/src/pipecat/services/resembleai/tts.py +++ b/src/pipecat/services/resembleai/tts.py @@ -24,7 +24,7 @@ from pipecat.frames.frames import ( TTSStoppedFrame, ) from pipecat.services.settings import TTSSettings, _warn_deprecated_param -from pipecat.services.tts_service import AudioContextTTSService +from pipecat.services.tts_service import WebsocketTTSService from pipecat.utils.tracing.service_decorators import traced_tts try: @@ -43,7 +43,7 @@ class ResembleAITTSSettings(TTSSettings): pass -class ResembleAITTSService(AudioContextTTSService): +class ResembleAITTSService(WebsocketTTSService): """Resemble AI TTS service with WebSocket streaming and word timestamps. Provides text-to-speech using Resemble AI's streaming WebSocket API. @@ -103,7 +103,6 @@ class ResembleAITTSService(AudioContextTTSService): super().__init__( sample_rate=sample_rate, reuse_context_id_within_turn=False, - supports_word_timestamps=True, settings=default_settings, **kwargs, ) @@ -268,7 +267,7 @@ class ResembleAITTSService(AudioContextTTSService): """ pass - async def flush_audio(self): + async def flush_audio(self, context_id: Optional[str] = None): """Flush any pending audio and finalize the current context.""" logger.trace(f"{self}: flushing audio") # For Resemble AI, we just wait for the audio_end message @@ -297,9 +296,6 @@ class ResembleAITTSService(AudioContextTTSService): continue if msg_type == "audio": - await self.stop_ttfb_metrics() - await self.start_word_timestamps() - # Decode base64 audio content audio_content = msg.get("audio_content", "") if not audio_content: @@ -447,14 +443,14 @@ class ResembleAITTSService(AudioContextTTSService): if not self._websocket or self._websocket.state is State.CLOSED: await self._connect() - await self.start_ttfb_metrics() - yield TTSStartedFrame(context_id=context_id) + if not self.audio_context_available(context_id): + await self.create_audio_context(context_id) + await self.start_ttfb_metrics() + yield TTSStartedFrame(context_id=context_id) # Map request_id to context_id for tracking self._request_id_to_context[self._request_id_counter] = context_id - await self.create_audio_context(context_id) - msg = self._build_msg(text=text) try: diff --git a/src/pipecat/services/rime/tts.py b/src/pipecat/services/rime/tts.py index 04ebcd5fc..27580504b 100644 --- a/src/pipecat/services/rime/tts.py +++ b/src/pipecat/services/rime/tts.py @@ -33,10 +33,10 @@ from pipecat.frames.frames import ( from pipecat.processors.frame_processor import FrameDirection from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param from pipecat.services.tts_service import ( - AudioContextTTSService, InterruptibleTTSService, TextAggregationMode, TTSService, + WebsocketTTSService, ) from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.text.base_text_aggregator import BaseTextAggregator @@ -123,7 +123,7 @@ class RimeNonJsonTTSSettings(TTSSettings): _aliases: ClassVar[Dict[str, str]] = {"speaker": "voice"} -class RimeTTSService(AudioContextTTSService): +class RimeTTSService(WebsocketTTSService): """Text-to-Speech service using Rime's websocket API. Uses Rime's websocket JSON API to convert text to speech with word-level timing @@ -276,7 +276,6 @@ class RimeTTSService(AudioContextTTSService): push_text_frames=False, push_stop_frames=True, pause_frame_processing=True, - supports_word_timestamps=True, append_trailing_space=True, sample_rate=sample_rate, settings=default_settings, @@ -408,9 +407,9 @@ class RimeTTSService(AudioContextTTSService): return changed - def _build_msg(self, text: str = "") -> dict: + def _build_msg(self, text: str = "", context_id: str = "") -> dict: """Build JSON message for Rime API.""" - msg = {"text": text, "contextId": self.get_active_audio_context_id()} + msg = {"text": text, "contextId": context_id} if self._extra_msg_fields: msg |= self._extra_msg_fields self._extra_msg_fields = {} @@ -557,15 +556,14 @@ class RimeTTSService(AudioContextTTSService): return word_pairs - async def flush_audio(self): + async def flush_audio(self, context_id: Optional[str] = None): """Flush any pending audio synthesis.""" - context_id = self.get_active_audio_context_id() - if not context_id or not self._websocket: + flush_id = context_id or self.get_active_audio_context_id() + if not flush_id or not self._websocket: return logger.trace(f"{self}: flushing audio") await self._get_websocket().send(json.dumps({"operation": "flush"})) - self.reset_active_audio_context() async def _receive_messages(self): """Process incoming websocket messages.""" @@ -578,8 +576,6 @@ class RimeTTSService(AudioContextTTSService): context_id = msg["contextId"] if msg["type"] == "chunk": # Process audio chunk - await self.stop_ttfb_metrics() - await self.start_word_timestamps() frame = TTSAudioRawFrame( audio=base64.b64decode(msg["data"]), sample_rate=self.sample_rate, @@ -638,13 +634,13 @@ class RimeTTSService(AudioContextTTSService): await self._connect() try: - if not self.has_active_audio_context(): + if not self.audio_context_available(context_id): + await self.create_audio_context(context_id) await self.start_ttfb_metrics() yield TTSStartedFrame(context_id=context_id) self._cumulative_time = 0 - await self.create_audio_context(context_id) - msg = self._build_msg(text=text) + msg = self._build_msg(text=text, context_id=context_id) await self._get_websocket().send(json.dumps(msg)) await self.start_tts_usage_metrics(text) except Exception as e: @@ -773,6 +769,8 @@ class RimeHttpTTSService(TTSService): super().__init__( sample_rate=sample_rate, + push_stop_frames=True, + push_start_frame=True, settings=default_settings, **kwargs, ) @@ -844,8 +842,6 @@ class RimeHttpTTSService(TTSService): need_to_strip_wav_header = False try: - await self.start_ttfb_metrics() - async with self._session.post( self._base_url, json=payload, headers=headers ) as response: @@ -856,8 +852,6 @@ class RimeHttpTTSService(TTSService): await self.start_tts_usage_metrics(text) - yield TTSStartedFrame(context_id=context_id) - CHUNK_SIZE = self.chunk_size async for frame in self._stream_audio_frames_from_iterator( @@ -872,7 +866,6 @@ class RimeHttpTTSService(TTSService): yield ErrorFrame(error=f"Unknown error occurred: {e}") finally: await self.stop_ttfb_metrics() - yield TTSStoppedFrame(context_id=context_id) class RimeNonJsonTTSService(InterruptibleTTSService): @@ -1005,6 +998,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService): aggregate_sentences=aggregate_sentences, text_aggregation_mode=text_aggregation_mode, push_stop_frames=True, + push_start_frame=True, pause_frame_processing=True, append_trailing_space=True, settings=default_settings, @@ -1022,7 +1016,6 @@ class RimeNonJsonTTSService(InterruptibleTTSService): self._settings.extra.update(params.extra) self._receive_task = None - self._context_id: Optional[str] = None def can_generate_metrics(self) -> bool: """Check if this service can generate processing metrics. @@ -1138,7 +1131,6 @@ class RimeNonJsonTTSService(InterruptibleTTSService): except Exception as e: await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) finally: - self._context_id = None self._websocket = None await self._call_event_handler("on_disconnected") @@ -1148,7 +1140,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService): return self._websocket raise Exception("Websocket not connected") - async def flush_audio(self): + async def flush_audio(self, context_id: Optional[str] = None): """Flush any pending audio synthesis.""" if not self._websocket: return @@ -1168,7 +1160,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService): audio=message, sample_rate=self.sample_rate, num_channels=1, - context_id=self._context_id, + context_id=self.get_active_audio_context_id(), ) await self.push_frame(frame) except Exception as e: @@ -1190,10 +1182,6 @@ class RimeNonJsonTTSService(InterruptibleTTSService): if not self._websocket or self._websocket.state is State.CLOSED: await self._connect() try: - await self.start_ttfb_metrics() - # Store context_id for use in _receive_messages - self._context_id = context_id - yield TTSStartedFrame(context_id=context_id) # Send bare text (not JSON) await self._get_websocket().send(text) await self.start_tts_usage_metrics(text) diff --git a/src/pipecat/services/sarvam/tts.py b/src/pipecat/services/sarvam/tts.py index 0d605fc6b..6bf38bb24 100644 --- a/src/pipecat/services/sarvam/tts.py +++ b/src/pipecat/services/sarvam/tts.py @@ -524,6 +524,8 @@ class SarvamHttpTTSService(TTSService): super().__init__( sample_rate=sample_rate, + push_stop_frames=True, + push_start_frame=True, settings=default_settings, **kwargs, ) @@ -573,8 +575,6 @@ class SarvamHttpTTSService(TTSService): logger.debug(f"{self}: Generating TTS [{text}]") try: - await self.start_ttfb_metrics() - # Build payload with common parameters payload = { "text": text, @@ -606,8 +606,6 @@ class SarvamHttpTTSService(TTSService): url = f"{self._base_url}/text-to-speech" - yield TTSStartedFrame(context_id=context_id) - async with self._session.post(url, json=payload, headers=headers) as response: if response.status != 200: error_text = await response.text() @@ -645,7 +643,6 @@ class SarvamHttpTTSService(TTSService): yield ErrorFrame(error=f"Error generating TTS: {e}", exception=e) finally: await self.stop_ttfb_metrics() - yield TTSStoppedFrame(context_id=context_id) class SarvamTTSService(InterruptibleTTSService): @@ -951,6 +948,7 @@ class SarvamTTSService(InterruptibleTTSService): push_text_frames=True, pause_frame_processing=True, push_stop_frames=True, + push_start_frame=True, sample_rate=sample_rate, settings=default_settings, **kwargs, @@ -967,7 +965,6 @@ class SarvamTTSService(InterruptibleTTSService): self._receive_task = None self._keepalive_task = None - self._context_id: Optional[str] = None def can_generate_metrics(self) -> bool: """Check if this service can generate processing metrics. @@ -1018,7 +1015,7 @@ class SarvamTTSService(InterruptibleTTSService): await super().cancel(frame) await self._disconnect() - async def flush_audio(self): + async def flush_audio(self, context_id: Optional[str] = None): """Flush any pending audio synthesis by sending flush command.""" try: if self._websocket: @@ -1151,7 +1148,6 @@ class SarvamTTSService(InterruptibleTTSService): except Exception as e: await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e) finally: - self._context_id = None self._websocket = None await self._call_event_handler("on_disconnected") @@ -1170,7 +1166,7 @@ class SarvamTTSService(InterruptibleTTSService): await self.stop_ttfb_metrics() audio = base64.b64decode(msg["data"]["audio"]) frame = TTSAudioRawFrame( - audio, self.sample_rate, 1, context_id=self._context_id + audio, self.sample_rate, 1, context_id=self.get_active_audio_context_id() ) await self.push_frame(frame) elif msg.get("type") == "error": @@ -1224,10 +1220,6 @@ class SarvamTTSService(InterruptibleTTSService): await self._connect() try: - await self.start_ttfb_metrics() - # Store context_id for use in _receive_messages - self._context_id = context_id - yield TTSStartedFrame(context_id=context_id) await self._send_text(text) await self.start_tts_usage_metrics(text) except Exception as e: diff --git a/src/pipecat/services/speechmatics/tts.py b/src/pipecat/services/speechmatics/tts.py index 55ded437e..22b47f3fc 100644 --- a/src/pipecat/services/speechmatics/tts.py +++ b/src/pipecat/services/speechmatics/tts.py @@ -19,8 +19,6 @@ from pipecat.frames.frames import ( ErrorFrame, Frame, TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, ) from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param from pipecat.services.tts_service import TTSService @@ -135,6 +133,8 @@ class SpeechmaticsTTSService(TTSService): super().__init__( sample_rate=sample_rate, + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -185,9 +185,6 @@ class SpeechmaticsTTSService(TTSService): url = _get_endpoint_url(self._base_url, self._settings.voice, self.sample_rate) try: - # Start TTS TTFB metrics - await self.start_ttfb_metrics() - # Track attempt attempt = 0 @@ -238,9 +235,6 @@ class SpeechmaticsTTSService(TTSService): # Update Pipecat metrics await self.start_tts_usage_metrics(text) - # Emit the TTS started frame - yield TTSStartedFrame(context_id=context_id) - # Process the response in streaming chunks first_chunk = True buffer = b"" @@ -277,8 +271,7 @@ class SpeechmaticsTTSService(TTSService): except Exception as e: yield ErrorFrame(error=f"Error generating TTS: {e}") finally: - # Emit the TTS stopped frame - yield TTSStoppedFrame(context_id=context_id) + await self.stop_ttfb_metrics() def _get_endpoint_url(base_url: str, voice: str, sample_rate: int) -> str: diff --git a/src/pipecat/services/xtts/tts.py b/src/pipecat/services/xtts/tts.py index ea7cb0b8b..539ddc88c 100644 --- a/src/pipecat/services/xtts/tts.py +++ b/src/pipecat/services/xtts/tts.py @@ -11,7 +11,7 @@ text-to-speech synthesis using local Docker deployment. """ from dataclasses import dataclass -from typing import AsyncGenerator, Dict, Optional +from typing import Any, AsyncGenerator, Dict, Optional import aiohttp from loguru import logger @@ -22,8 +22,6 @@ from pipecat.frames.frames import ( Frame, StartFrame, TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, ) from pipecat.services.settings import TTSSettings, _warn_deprecated_param from pipecat.services.tts_service import TTSService @@ -132,6 +130,8 @@ class XTTSService(TTSService): super().__init__( sample_rate=sample_rate, + push_start_frame=True, + push_stop_frames=True, settings=default_settings, **kwargs, ) @@ -213,8 +213,6 @@ class XTTSService(TTSService): "stream_chunk_size": 20, } - await self.start_ttfb_metrics() - async with self._aiohttp_session.post(url, json=payload) as r: if r.status != 200: text = await r.text() @@ -223,8 +221,6 @@ class XTTSService(TTSService): await self.start_tts_usage_metrics(text) - yield TTSStartedFrame(context_id=context_id) - CHUNK_SIZE = self.chunk_size buffer = bytearray() @@ -262,5 +258,3 @@ class XTTSService(TTSService): resampled_audio, self.sample_rate, 1, context_id=context_id ) yield frame - - yield TTSStoppedFrame(context_id=context_id)