Refactored all 25+ TTS service implementations to use the new push_start_frame=True pattern
This commit is contained in:
@@ -23,12 +23,11 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService
|
||||
from pipecat.services.tts_service import TextAggregationMode, TTSService, WebsocketTTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -80,7 +79,7 @@ class AsyncAITTSSettings(TTSSettings):
|
||||
pass
|
||||
|
||||
|
||||
class AsyncAITTSService(AudioContextTTSService):
|
||||
class AsyncAITTSService(WebsocketTTSService):
|
||||
"""Async TTS service with WebSocket streaming.
|
||||
|
||||
Provides text-to-speech using Async's streaming WebSocket API.
|
||||
@@ -183,8 +182,9 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
pause_frame_processing=True,
|
||||
push_stop_frames=True,
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -340,13 +340,18 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio."""
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if not context_id or not self._websocket:
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio.
|
||||
|
||||
Args:
|
||||
context_id: The specific context to flush. If None, falls back to the
|
||||
currently active context.
|
||||
"""
|
||||
flush_id = context_id or self.get_active_audio_context_id()
|
||||
if not flush_id or not self._websocket:
|
||||
return
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
msg = self._build_msg(text=" ", context_id=context_id, force=True)
|
||||
msg = self._build_msg(text=" ", context_id=flush_id, force=True)
|
||||
await self._websocket.send(msg)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
@@ -459,12 +464,6 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
if not self.has_active_audio_context():
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
if not self.audio_context_available(context_id):
|
||||
await self.create_audio_context(context_id)
|
||||
|
||||
msg = self._build_msg(text=text, force=True, context_id=context_id)
|
||||
await self._get_websocket().send(msg)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
@@ -574,6 +573,8 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -632,7 +633,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
|
||||
try:
|
||||
voice_config = {"mode": "id", "id": self._settings.voice}
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
payload = {
|
||||
"model_id": self._settings.model,
|
||||
"transcript": text,
|
||||
@@ -644,7 +645,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
},
|
||||
"language": self._settings.language,
|
||||
}
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
headers = {
|
||||
"version": self._api_version,
|
||||
"x-api-key": self._api_key,
|
||||
@@ -682,4 +683,3 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
@@ -22,8 +22,6 @@ from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import TTSService
|
||||
@@ -247,6 +245,8 @@ class AWSPollyTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -329,8 +329,6 @@ class AWSPollyTTSService(TTSService):
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Construct the parameters dictionary
|
||||
ssml = self._construct_ssml(text)
|
||||
|
||||
@@ -362,8 +360,6 @@ class AWSPollyTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
CHUNK_SIZE = self.chunk_size
|
||||
|
||||
for i in range(0, len(audio_data), CHUNK_SIZE):
|
||||
@@ -373,14 +369,10 @@ class AWSPollyTTSService(TTSService):
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1, context_id=context_id)
|
||||
yield frame
|
||||
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
except (BotoCoreError, ClientError) as error:
|
||||
error_message = f"AWS Polly TTS error: {str(error)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
finally:
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
|
||||
class PollyTTSService(AWSPollyTTSService):
|
||||
"""Deprecated alias for AWSPollyTTSService.
|
||||
|
||||
@@ -21,7 +21,6 @@ from pipecat.frames.frames import (
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
@@ -331,8 +330,8 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
push_text_frames=False, # We'll push text frames based on word timestamps
|
||||
push_stop_frames=True,
|
||||
push_start_frame=True,
|
||||
pause_frame_processing=True,
|
||||
supports_word_timestamps=True,
|
||||
sample_rate=sample_rate,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
@@ -346,7 +345,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
|
||||
self._audio_queue = asyncio.Queue()
|
||||
self._word_boundary_queue = asyncio.Queue()
|
||||
self._word_processor_task = None
|
||||
self._first_chunk = True
|
||||
self._cumulative_audio_offset: float = 0.0 # Cumulative audio duration in seconds
|
||||
self._current_sentence_base_offset: float = 0.0 # Base offset for current sentence
|
||||
self._current_sentence_duration: float = 0.0 # Duration from Azure callback
|
||||
@@ -619,7 +617,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
|
||||
|
||||
def _reset_state(self):
|
||||
"""Reset TTS state between turns."""
|
||||
self._first_chunk = True
|
||||
self._cumulative_audio_offset = 0.0
|
||||
self._current_sentence_base_offset = 0.0
|
||||
self._current_sentence_duration = 0.0
|
||||
@@ -628,7 +625,7 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
|
||||
self._last_timestamp = None
|
||||
self._current_context_id = None
|
||||
|
||||
async def flush_audio(self):
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio data."""
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
|
||||
@@ -694,9 +691,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
|
||||
return
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._first_chunk = True
|
||||
self._current_context_id = context_id
|
||||
|
||||
# Capture base offset BEFORE starting synthesis to avoid race conditions
|
||||
@@ -719,11 +713,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
|
||||
yield ErrorFrame(error=str(chunk))
|
||||
break
|
||||
|
||||
if self._first_chunk:
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.start_word_timestamps()
|
||||
self._first_chunk = False
|
||||
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=chunk,
|
||||
sample_rate=self.sample_rate,
|
||||
@@ -833,6 +822,8 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -887,8 +878,6 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService):
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
ssml = self._construct_ssml(text)
|
||||
|
||||
result = await asyncio.to_thread(self._speech_synthesizer.speak_ssml, ssml)
|
||||
@@ -896,7 +885,6 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService):
|
||||
if result.reason == ResultReason.SynthesizingAudioCompleted:
|
||||
await self.start_tts_usage_metrics(text)
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
# Azure always sends a 44-byte header. Strip it off.
|
||||
yield TTSAudioRawFrame(
|
||||
audio=result.audio_data[44:],
|
||||
@@ -904,7 +892,6 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService):
|
||||
num_channels=1,
|
||||
context_id=context_id,
|
||||
)
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
elif result.reason == ResultReason.Canceled:
|
||||
cancellation_details = result.cancellation_details
|
||||
logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")
|
||||
|
||||
@@ -29,8 +29,6 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import TTSService
|
||||
@@ -271,6 +269,8 @@ class CambTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -332,8 +332,6 @@ class CambTTSService(TTSService):
|
||||
text = text[:3000]
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Build SDK parameters
|
||||
tts_kwargs: Dict[str, Any] = {
|
||||
"text": text,
|
||||
@@ -348,7 +346,6 @@ class CambTTSService(TTSService):
|
||||
tts_kwargs["user_instructions"] = self._settings.user_instructions
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
assert self._client is not None, "Camb.ai TTS service not initialized"
|
||||
|
||||
@@ -384,5 +381,3 @@ class CambTTSService(TTSService):
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Camb.ai TTS error: {e}")
|
||||
finally:
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService
|
||||
from pipecat.services.tts_service import TextAggregationMode, TTSService, WebsocketTTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator
|
||||
@@ -203,7 +203,7 @@ class CartesiaTTSSettings(TTSSettings):
|
||||
pronunciation_dict_id: str | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
|
||||
|
||||
class CartesiaTTSService(AudioContextTTSService):
|
||||
class CartesiaTTSService(WebsocketTTSService):
|
||||
"""Cartesia TTS service with WebSocket streaming and word timestamps.
|
||||
|
||||
Provides text-to-speech using Cartesia's streaming WebSocket API.
|
||||
@@ -334,9 +334,9 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
push_text_frames=False,
|
||||
pause_frame_processing=True,
|
||||
supports_word_timestamps=True,
|
||||
pause_frame_processing=False,
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
text_aggregator=text_aggregator,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
@@ -452,7 +452,11 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
return list(zip(words, starts))
|
||||
|
||||
def _build_msg(
|
||||
self, text: str = "", continue_transcript: bool = True, add_timestamps: bool = True
|
||||
self,
|
||||
text: str = "",
|
||||
continue_transcript: bool = True,
|
||||
add_timestamps: bool = True,
|
||||
context_id: str = "",
|
||||
):
|
||||
voice_config = {}
|
||||
voice_config["mode"] = "id"
|
||||
@@ -461,7 +465,7 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
msg = {
|
||||
"transcript": text,
|
||||
"continue": continue_transcript,
|
||||
"context_id": self.get_active_audio_context_id(),
|
||||
"context_id": context_id,
|
||||
"model_id": self._settings.model,
|
||||
"voice": voice_config,
|
||||
"output_format": {
|
||||
@@ -580,15 +584,19 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio and finalize the current context."""
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if not context_id or not self._websocket:
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio and finalize the current context.
|
||||
|
||||
Args:
|
||||
context_id: The specific context to flush. If None, falls back to the
|
||||
currently active context.
|
||||
"""
|
||||
flush_id = context_id or self.get_active_audio_context_id()
|
||||
if not flush_id or not self._websocket:
|
||||
return
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
msg = self._build_msg(text="", continue_transcript=False)
|
||||
msg = self._build_msg(text="", continue_transcript=False, context_id=flush_id)
|
||||
await self._websocket.send(msg)
|
||||
self.reset_active_audio_context()
|
||||
|
||||
async def _process_messages(self):
|
||||
async for message in self._get_websocket():
|
||||
@@ -607,8 +615,6 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
)
|
||||
await self.add_word_timestamps(processed_timestamps, ctx_id)
|
||||
elif msg["type"] == "chunk":
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.start_word_timestamps()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["data"]),
|
||||
sample_rate=self.sample_rate,
|
||||
@@ -652,12 +658,7 @@ class CartesiaTTSService(AudioContextTTSService):
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
if not self.has_active_audio_context():
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
await self.create_audio_context(context_id)
|
||||
|
||||
msg = self._build_msg(text=text)
|
||||
msg = self._build_msg(text=text, context_id=context_id)
|
||||
|
||||
try:
|
||||
await self._get_websocket().send(msg)
|
||||
@@ -777,6 +778,8 @@ class CartesiaHttpTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -863,8 +866,6 @@ class CartesiaHttpTTSService(TTSService):
|
||||
try:
|
||||
voice_config = {"mode": "id", "id": self._settings.voice}
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
output_format = {
|
||||
"container": self._output_container,
|
||||
"encoding": self._output_encoding,
|
||||
@@ -889,8 +890,6 @@ class CartesiaHttpTTSService(TTSService):
|
||||
if self._settings.pronunciation_dict_id:
|
||||
payload["pronunciation_dict_id"] = self._settings.pronunciation_dict_id
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
headers = {
|
||||
"Cartesia-Version": self._cartesia_version,
|
||||
"X-API-Key": self._api_key,
|
||||
@@ -922,4 +921,3 @@ class CartesiaHttpTTSService(TTSService):
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
@@ -328,7 +328,7 @@ class DeepgramSageMakerTTSService(TTSService):
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending Clear message: {e}")
|
||||
|
||||
async def flush_audio(self):
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio synthesis by sending Flush command.
|
||||
|
||||
This should be called when the LLM finishes a complete response to force
|
||||
@@ -355,12 +355,12 @@ class DeepgramSageMakerTTSService(TTSService):
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
if not self._ttfb_started:
|
||||
await self.start_ttfb_metrics()
|
||||
self._ttfb_started = True
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
if not self.audio_context_available(context_id):
|
||||
await self.create_audio_context(context_id)
|
||||
if not self._ttfb_started:
|
||||
await self.start_ttfb_metrics()
|
||||
self._ttfb_started = True
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._context_id = context_id
|
||||
|
||||
await self._client.send_json({"type": "Speak", "text": text})
|
||||
|
||||
@@ -26,8 +26,6 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
|
||||
@@ -120,6 +118,7 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
sample_rate=sample_rate,
|
||||
pause_frame_processing=True,
|
||||
push_stop_frames=True,
|
||||
push_start_frame=True,
|
||||
append_trailing_space=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
@@ -130,7 +129,6 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
self._encoding = encoding
|
||||
|
||||
self._receive_task = None
|
||||
self._context_id: Optional[str] = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate metrics.
|
||||
@@ -267,7 +265,6 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
@@ -299,7 +296,9 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
if isinstance(message, bytes):
|
||||
# Binary message contains audio data
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(message, self.sample_rate, 1, context_id=self._context_id)
|
||||
frame = TTSAudioRawFrame(
|
||||
message, self.sample_rate, 1, context_id=self.get_active_audio_context_id()
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
elif isinstance(message, str):
|
||||
# Text message contains metadata or control messages
|
||||
@@ -326,7 +325,7 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Invalid JSON message: {message}")
|
||||
|
||||
async def flush_audio(self):
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio synthesis by sending Flush command.
|
||||
|
||||
This should be called when the LLM finishes a complete response to force
|
||||
@@ -357,13 +356,8 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
# Store context_id for use in _receive_messages
|
||||
self._context_id = context_id
|
||||
|
||||
# Send text message to Deepgram
|
||||
# Note: We don't send Flush here - that should only be sent when the
|
||||
# LLM finishes a complete response via flush_audio()
|
||||
@@ -435,6 +429,8 @@ class DeepgramHttpTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -492,7 +488,6 @@ class DeepgramHttpTTSService(TTSService):
|
||||
raise Exception(f"HTTP {response.status}: {error_text}")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
CHUNK_SIZE = self.chunk_size
|
||||
|
||||
@@ -510,7 +505,5 @@ class DeepgramHttpTTSService(TTSService):
|
||||
context_id=context_id,
|
||||
)
|
||||
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(f"Error getting audio: {str(e)}")
|
||||
|
||||
@@ -46,9 +46,9 @@ from pipecat.frames.frames import (
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import (
|
||||
AudioContextTTSService,
|
||||
TextAggregationMode,
|
||||
TTSService,
|
||||
WebsocketTTSService,
|
||||
)
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
@@ -308,7 +308,7 @@ def calculate_word_times(
|
||||
return (word_times, new_partial_word, new_partial_word_start_time)
|
||||
|
||||
|
||||
class ElevenLabsTTSService(AudioContextTTSService):
|
||||
class ElevenLabsTTSService(WebsocketTTSService):
|
||||
"""ElevenLabs WebSocket-based TTS service with word timestamps.
|
||||
|
||||
Provides real-time text-to-speech using ElevenLabs' WebSocket streaming API.
|
||||
@@ -479,7 +479,6 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
supports_word_timestamps=True,
|
||||
sample_rate=sample_rate,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
@@ -559,20 +558,15 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
)
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
elif voice_settings_changed and self.has_active_audio_context():
|
||||
elif voice_settings_changed:
|
||||
logger.debug(
|
||||
f"Voice settings changed ({changed.keys() & ElevenLabsTTSSettings.VOICE_SETTINGS_FIELDS}), "
|
||||
f"closing current context to apply changes"
|
||||
)
|
||||
context_id = self.get_active_audio_context_id()
|
||||
try:
|
||||
if self._websocket:
|
||||
await self._websocket.send(
|
||||
json.dumps({"context_id": context_id, "close_context": True})
|
||||
)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
self.reset_active_audio_context()
|
||||
audio_contexts = self.get_audio_contexts()
|
||||
if audio_contexts:
|
||||
for ctx_id in audio_contexts:
|
||||
await self._close_context(ctx_id)
|
||||
|
||||
if not url_changed:
|
||||
# Reconnect applies all settings; only warn about fields not handled
|
||||
@@ -610,13 +604,18 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio and finalize the current context."""
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if not context_id or not self._websocket:
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio and finalize the current context.
|
||||
|
||||
Args:
|
||||
context_id: The specific context to flush. If None, falls back to the
|
||||
currently active context.
|
||||
"""
|
||||
flush_id = context_id or self.get_active_audio_context_id()
|
||||
if not flush_id or not self._websocket:
|
||||
return
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
msg = {"context_id": context_id, "flush": True}
|
||||
msg = {"context_id": flush_id, "flush": True}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
@@ -703,9 +702,7 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from ElevenLabs")
|
||||
# Close all contexts and the socket
|
||||
if self.has_active_audio_context():
|
||||
await self._websocket.send(json.dumps({"close_socket": True}))
|
||||
await self._websocket.send(json.dumps({"close_socket": True}))
|
||||
await self._websocket.close()
|
||||
logger.debug("Disconnected from ElevenLabs")
|
||||
except Exception as e:
|
||||
@@ -737,6 +734,7 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
self._cumulative_time = 0.0
|
||||
self._partial_word = ""
|
||||
self._partial_word_start_time = 0.0
|
||||
|
||||
@@ -782,9 +780,6 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
continue
|
||||
|
||||
if msg.get("audio"):
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.start_word_timestamps()
|
||||
|
||||
audio = base64.b64decode(msg["audio"])
|
||||
frame = TTSAudioRawFrame(audio, self.sample_rate, 1, context_id=received_ctx_id)
|
||||
await self.append_to_audio_context(received_ctx_id, frame)
|
||||
@@ -845,9 +840,8 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
logger.warning(f"{self} keepalive error: {e}")
|
||||
break
|
||||
|
||||
async def _send_text(self, text: str):
|
||||
async def _send_text(self, text: str, context_id: str):
|
||||
"""Send text to the WebSocket for synthesis."""
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if self._websocket and context_id:
|
||||
msg = {"text": text, "context_id": context_id}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
@@ -870,16 +864,14 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
if not self.has_active_audio_context():
|
||||
if not self.audio_context_available(context_id):
|
||||
await self.create_audio_context(context_id)
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._cumulative_time = 0
|
||||
self._partial_word = ""
|
||||
self._partial_word_start_time = 0.0
|
||||
|
||||
if not self.audio_context_available(context_id):
|
||||
await self.create_audio_context(context_id)
|
||||
|
||||
# Initialize context with voice settings and pronunciation dictionaries
|
||||
msg = {"text": " ", "context_id": context_id}
|
||||
if self._voice_settings:
|
||||
@@ -892,7 +884,7 @@ class ElevenLabsTTSService(AudioContextTTSService):
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
logger.trace(f"Created new context {context_id}")
|
||||
|
||||
await self._send_text(text)
|
||||
await self._send_text(text, context_id)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
@@ -1046,7 +1038,7 @@ class ElevenLabsHttpTTSService(TTSService):
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
supports_word_timestamps=True,
|
||||
push_start_frame=True,
|
||||
sample_rate=sample_rate,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
@@ -1266,8 +1258,6 @@ class ElevenLabsHttpTTSService(TTSService):
|
||||
params["optimize_streaming_latency"] = self._settings.optimize_streaming_latency
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
async with self._session.post(
|
||||
url, json=payload, headers=headers, params=params
|
||||
) as response:
|
||||
@@ -1278,10 +1268,6 @@ class ElevenLabsHttpTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
# Start TTS sequence
|
||||
await self.start_word_timestamps()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
# Track the duration of this utterance based on the last character's end time
|
||||
utterance_duration = 0
|
||||
async for line in response.content:
|
||||
@@ -1347,4 +1333,3 @@ class ElevenLabsHttpTTSService(TTSService):
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
# Let the parent class handle TTSStoppedFrame
|
||||
|
||||
@@ -209,6 +209,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
|
||||
super().__init__(
|
||||
push_stop_frames=True,
|
||||
push_start_frame=True,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
settings=default_settings,
|
||||
@@ -219,7 +220,6 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
self._base_url = "wss://api.fish.audio/v1/tts/live"
|
||||
self._websocket = None
|
||||
self._receive_task = None
|
||||
self._request_id = None
|
||||
|
||||
# Init-only audio format config (not runtime-updatable).
|
||||
self._fish_sample_rate = 0 # Set in start()
|
||||
@@ -341,11 +341,10 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
self._request_id = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def flush_audio(self):
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any buffered audio by sending a flush event to Fish Audio."""
|
||||
logger.trace(f"{self}: Flushing audio buffers")
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
@@ -361,7 +360,6 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self.stop_all_metrics()
|
||||
self._request_id = None
|
||||
|
||||
async def _receive_messages(self):
|
||||
async for message in self._get_websocket():
|
||||
@@ -398,12 +396,6 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
if not self._request_id:
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._request_id = str(uuid.uuid4())
|
||||
|
||||
# Send the text
|
||||
text_message = {
|
||||
"event": "text",
|
||||
|
||||
@@ -34,8 +34,6 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import (
|
||||
NOT_GIVEN,
|
||||
@@ -655,6 +653,8 @@ class GoogleHttpTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -803,8 +803,6 @@ class GoogleHttpTTSService(TTSService):
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Check if the voice is a Chirp voice (including Chirp 3) or Journey voice
|
||||
is_chirp_voice = "chirp" in self._settings.voice.lower()
|
||||
is_journey_voice = "journey" in self._settings.voice.lower()
|
||||
@@ -840,8 +838,6 @@ class GoogleHttpTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
# Skip the first 44 bytes to remove the WAV header
|
||||
audio_content = response.audio_content[44:]
|
||||
|
||||
@@ -855,8 +851,6 @@ class GoogleHttpTTSService(TTSService):
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1, context_id=context_id)
|
||||
yield frame
|
||||
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
@@ -967,8 +961,6 @@ class GoogleBaseTTSService(TTSService):
|
||||
streaming_responses = await self._client.streaming_synthesize(request_generator())
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
audio_buffer = b""
|
||||
first_chunk_for_ttfb = False
|
||||
|
||||
@@ -992,8 +984,6 @@ class GoogleBaseTTSService(TTSService):
|
||||
if audio_buffer:
|
||||
yield TTSAudioRawFrame(audio_buffer, self.sample_rate, 1, context_id=context_id)
|
||||
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
|
||||
class GoogleTTSService(GoogleBaseTTSService):
|
||||
"""Google Cloud Text-to-Speech streaming service.
|
||||
@@ -1096,6 +1086,8 @@ class GoogleTTSService(GoogleBaseTTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -1135,8 +1127,6 @@ class GoogleTTSService(GoogleBaseTTSService):
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Build voice selection params
|
||||
if self._voice_cloning_key:
|
||||
voice_clone_params = texttospeech_v1.VoiceCloneParams(
|
||||
@@ -1352,6 +1342,8 @@ class GeminiTTSService(GoogleBaseTTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -1414,8 +1406,6 @@ class GeminiTTSService(GoogleBaseTTSService):
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Build voice selection params
|
||||
if self._settings.multi_speaker and self._settings.speaker_configs:
|
||||
# Multi-speaker mode
|
||||
|
||||
@@ -19,11 +19,10 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import AudioContextTTSService
|
||||
from pipecat.services.tts_service import WebsocketTTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
@@ -45,7 +44,7 @@ class GradiumTTSSettings(TTSSettings):
|
||||
pass
|
||||
|
||||
|
||||
class GradiumTTSService(AudioContextTTSService):
|
||||
class GradiumTTSService(WebsocketTTSService):
|
||||
"""Text-to-Speech service using Gradium's websocket API."""
|
||||
|
||||
_settings: GradiumTTSSettings
|
||||
@@ -125,9 +124,9 @@ class GradiumTTSService(AudioContextTTSService):
|
||||
|
||||
super().__init__(
|
||||
push_stop_frames=True,
|
||||
push_start_frame=True,
|
||||
push_text_frames=False,
|
||||
pause_frame_processing=True,
|
||||
supports_word_timestamps=True,
|
||||
sample_rate=SAMPLE_RATE,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
@@ -166,12 +165,9 @@ class GradiumTTSService(AudioContextTTSService):
|
||||
self._warn_unhandled_updated_settings(changed)
|
||||
return changed
|
||||
|
||||
def _build_msg(self, text: str = "") -> dict:
|
||||
def _build_msg(self, text: str = "", context_id: str = "") -> dict:
|
||||
"""Build JSON message for Gradium API."""
|
||||
msg = {"text": text, "type": "text"}
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if context_id:
|
||||
msg["client_req_id"] = context_id
|
||||
msg = {"text": text, "type": "text", "client_req_id": context_id}
|
||||
return msg
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
@@ -280,15 +276,14 @@ class GradiumTTSService(AudioContextTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def flush_audio(self):
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio synthesis."""
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if not context_id or not self._websocket:
|
||||
flush_id = context_id or self.get_active_audio_context_id()
|
||||
if not flush_id or not self._websocket:
|
||||
return
|
||||
try:
|
||||
msg = {"type": "end_of_stream", "client_req_id": context_id}
|
||||
msg = {"type": "end_of_stream", "client_req_id": flush_id}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
self.reset_active_audio_context()
|
||||
except ConnectionClosedOK:
|
||||
logger.debug(f"{self}: connection closed normally during flush")
|
||||
except Exception as e:
|
||||
@@ -326,8 +321,6 @@ class GradiumTTSService(AudioContextTTSService):
|
||||
if msg["type"] == "audio":
|
||||
if not ctx_id or not self.audio_context_available(ctx_id):
|
||||
continue
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.start_word_timestamps()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["audio"]),
|
||||
sample_rate=self.sample_rate,
|
||||
@@ -369,12 +362,7 @@ class GradiumTTSService(AudioContextTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
if not self.has_active_audio_context():
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
await self.create_audio_context(context_id)
|
||||
|
||||
msg = self._build_msg(text=text)
|
||||
msg = self._build_msg(text=text, context_id=context_id)
|
||||
await self._get_websocket().send(json.dumps(msg))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
|
||||
@@ -18,8 +18,6 @@ from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import TTSService
|
||||
@@ -140,6 +138,8 @@ class GroqTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
pause_frame_processing=True,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
sample_rate=sample_rate,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
@@ -171,9 +171,6 @@ class GroqTTSService(TTSService):
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
measuring_ttfb = True
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
try:
|
||||
response = await self._client.audio.speech.create(
|
||||
model=self._settings.model,
|
||||
@@ -198,5 +195,3 @@ class GroqTTSService(TTSService):
|
||||
yield TTSAudioRawFrame(bytes, frame_rate, channels, context_id=context_id)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
@@ -22,7 +22,6 @@ from pipecat.frames.frames import (
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
@@ -166,7 +165,7 @@ class HumeTTSService(TTSService):
|
||||
sample_rate=sample_rate,
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
supports_word_timestamps=True,
|
||||
push_start_frame=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -181,7 +180,6 @@ class HumeTTSService(TTSService):
|
||||
|
||||
# Track cumulative time for word timestamps across utterances
|
||||
self._cumulative_time = 0.0
|
||||
self._started = False
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Can generate metrics.
|
||||
@@ -203,7 +201,6 @@ class HumeTTSService(TTSService):
|
||||
def _reset_state(self):
|
||||
"""Reset internal state variables."""
|
||||
self._cumulative_time = 0.0
|
||||
self._started = False
|
||||
|
||||
async def stop(self, frame: EndFrame) -> None:
|
||||
"""Stop the service and cleanup resources.
|
||||
@@ -310,15 +307,8 @@ class HumeTTSService(TTSService):
|
||||
# Request raw PCM chunks in the streaming JSON
|
||||
pcm_fmt = FormatPcm(type="pcm")
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
# Start TTS sequence if not already started
|
||||
if not self._started:
|
||||
await self.start_word_timestamps()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._started = True
|
||||
|
||||
try:
|
||||
# Instant mode is always enabled here (not user-configurable)
|
||||
# Hume emits mono PCM at 48 kHz; downstream can resample if needed.
|
||||
@@ -395,4 +385,3 @@ class HumeTTSService(TTSService):
|
||||
finally:
|
||||
# Ensure TTFB timer is stopped even on early failures
|
||||
await self.stop_ttfb_metrics()
|
||||
# Let the parent class handle TTSStoppedFrame via push_stop_frames
|
||||
|
||||
@@ -62,7 +62,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService
|
||||
from pipecat.services.tts_service import TextAggregationMode, TTSService, WebsocketTTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
|
||||
@@ -212,7 +212,7 @@ class InworldHttpTTSService(TTSService):
|
||||
super().__init__(
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
supports_word_timestamps=True,
|
||||
push_start_frame=True,
|
||||
sample_rate=sample_rate,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
@@ -359,11 +359,6 @@ class InworldHttpTTSService(TTSService):
|
||||
}
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
await self.start_word_timestamps()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
async with self._session.post(
|
||||
self._base_url, json=payload, headers=headers
|
||||
) as response:
|
||||
@@ -514,7 +509,7 @@ class InworldHttpTTSService(TTSService):
|
||||
)
|
||||
|
||||
|
||||
class InworldTTSService(AudioContextTTSService):
|
||||
class InworldTTSService(WebsocketTTSService):
|
||||
"""Inworld AI WebSocket-based TTS service.
|
||||
|
||||
Uses bidirectional WebSocket for lower latency streaming. Supports multiple
|
||||
@@ -650,7 +645,6 @@ class InworldTTSService(AudioContextTTSService):
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
supports_word_timestamps=True,
|
||||
sample_rate=sample_rate,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
@@ -719,17 +713,17 @@ class InworldTTSService(AudioContextTTSService):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def flush_audio(self):
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio without closing the context.
|
||||
|
||||
This triggers synthesis of all accumulated text in the buffer while
|
||||
keeping the context open for subsequent text. The context is only
|
||||
closed on interruption, disconnect, or end of session.
|
||||
"""
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if context_id and self._websocket:
|
||||
logger.trace(f"Flushing audio for context {context_id}")
|
||||
await self._send_flush(context_id)
|
||||
flush_id = context_id or self.get_active_audio_context_id()
|
||||
if flush_id and self._websocket:
|
||||
logger.trace(f"Flushing audio for context {flush_id}")
|
||||
await self._send_flush(flush_id)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame and handle state changes.
|
||||
@@ -899,12 +893,10 @@ class InworldTTSService(AudioContextTTSService):
|
||||
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from Inworld WebSocket TTS")
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if context_id:
|
||||
try:
|
||||
await self._send_close_context(context_id)
|
||||
except Exception:
|
||||
pass
|
||||
audio_contexts = self.get_audio_contexts()
|
||||
if audio_contexts:
|
||||
for ctx_id in audio_contexts:
|
||||
await self._send_close_context(ctx_id)
|
||||
await self._websocket.close()
|
||||
logger.debug("Disconnected from Inworld WebSocket TTS")
|
||||
except Exception as e:
|
||||
@@ -934,10 +926,7 @@ class InworldTTSService(AudioContextTTSService):
|
||||
for k in ["contextCreated", "audioChunk", "flushCompleted", "contextClosed"]
|
||||
if k in result
|
||||
]
|
||||
logger.debug(
|
||||
f"{self}: Received message types={msg_types}, ctx_id={ctx_id}, "
|
||||
f"current_ctx={self.get_active_audio_context_id()}, available={self.audio_context_available(ctx_id) if ctx_id else 'N/A'}"
|
||||
)
|
||||
logger.debug(f"{self}: Received message types={msg_types}, ctx_id={ctx_id}")
|
||||
|
||||
# Check for errors
|
||||
status = result.get("status", {})
|
||||
@@ -948,9 +937,7 @@ class InworldTTSService(AudioContextTTSService):
|
||||
# Handle "Context not found" error (code 5)
|
||||
# This can happen when a keepalive message is sent but no context is available.
|
||||
if error_code == 5 and "not found" in error_msg.lower():
|
||||
logger.debug(
|
||||
f"{self}: Context {ctx_id or self.get_active_audio_context_id()} not found."
|
||||
)
|
||||
logger.debug(f"{self}: Context {ctx_id} not found.")
|
||||
continue
|
||||
|
||||
# For other errors, push error frame
|
||||
@@ -961,17 +948,10 @@ class InworldTTSService(AudioContextTTSService):
|
||||
await self.push_error(error_msg=str(msg["error"]))
|
||||
continue
|
||||
|
||||
# Check if this message belongs to an available context.
|
||||
# If the context isn't available but matches our current context ID,
|
||||
# recreate it (handles race conditions during interruption recovery).
|
||||
# If the context isn't available recreate it (handles race conditions during interruption recovery).
|
||||
if ctx_id and not self.audio_context_available(ctx_id):
|
||||
if self.get_active_audio_context_id() == ctx_id:
|
||||
logger.trace(f"{self}: Recreating audio context for current context: {ctx_id}")
|
||||
await self.create_audio_context(ctx_id)
|
||||
else:
|
||||
# This is a message from an old/closed context - skip it
|
||||
logger.trace(f"{self}: Skipping message from unavailable context: {ctx_id}")
|
||||
continue
|
||||
logger.trace(f"{self}: Recreating audio context for current context: {ctx_id}")
|
||||
await self.create_audio_context(ctx_id)
|
||||
|
||||
# Process audio chunk
|
||||
audio_chunk = result.get("audioChunk", {})
|
||||
@@ -979,8 +959,6 @@ class InworldTTSService(AudioContextTTSService):
|
||||
|
||||
if audio_b64:
|
||||
logger.trace(f"{self}: Processing audio chunk for context {ctx_id}")
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.start_word_timestamps()
|
||||
audio = base64.b64decode(audio_b64)
|
||||
if len(audio) > 44 and audio.startswith(b"RIFF"):
|
||||
audio = audio[44:]
|
||||
@@ -1012,12 +990,8 @@ class InworldTTSService(AudioContextTTSService):
|
||||
if "contextClosed" in result:
|
||||
logger.trace(f"{self}: Context closed on server: {ctx_id}")
|
||||
await self.stop_ttfb_metrics()
|
||||
# Only reset if this is our current context
|
||||
if ctx_id == self.get_active_audio_context_id():
|
||||
self.reset_active_audio_context()
|
||||
if ctx_id and self.audio_context_available(ctx_id):
|
||||
await self.remove_audio_context(ctx_id)
|
||||
await self.add_word_timestamps([("TTSStoppedFrame", 0), ("Reset", 0)], ctx_id)
|
||||
await self.remove_audio_context(ctx_id)
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Send periodic keepalive messages to maintain WebSocket connection."""
|
||||
@@ -1128,10 +1102,10 @@ class InworldTTSService(AudioContextTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
if not self.has_active_audio_context():
|
||||
if not self.audio_context_available(context_id):
|
||||
await self.create_audio_context(context_id)
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
await self.create_audio_context(context_id)
|
||||
await self._send_context(context_id)
|
||||
|
||||
await self._send_text(context_id, text)
|
||||
|
||||
@@ -20,8 +20,6 @@ from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import TTSService
|
||||
@@ -170,6 +168,8 @@ class KokoroTTSService(TTSService):
|
||||
default_settings.apply_update(settings)
|
||||
|
||||
super().__init__(
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -212,9 +212,7 @@ class KokoroTTSService(TTSService):
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
stream = self._kokoro.create_stream(
|
||||
text, voice=self._settings.voice, lang=self._settings.language, speed=1.0
|
||||
@@ -238,4 +236,3 @@ class KokoroTTSService(TTSService):
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
@@ -143,6 +143,7 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
|
||||
super().__init__(
|
||||
push_stop_frames=True,
|
||||
push_start_frame=True,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
settings=default_settings,
|
||||
@@ -152,7 +153,6 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
self._api_key = api_key
|
||||
self._output_format = "raw"
|
||||
self._receive_task = None
|
||||
self._context_id: Optional[str] = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -289,7 +289,6 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error disconnecting from LMNT: {e}", exception=e)
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
@@ -299,7 +298,7 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def flush_audio(self):
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio synthesis."""
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
return
|
||||
@@ -315,7 +314,7 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
audio=message,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=self._context_id,
|
||||
context_id=self.get_active_audio_context_id(),
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
@@ -347,11 +346,6 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
# Store context_id for use in _receive_messages
|
||||
self._context_id = context_id
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
# Send text to LMNT
|
||||
await self._get_websocket().send(json.dumps({"text": text}))
|
||||
# Force synthesis
|
||||
|
||||
@@ -23,8 +23,6 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import TTSService
|
||||
@@ -305,6 +303,8 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -402,8 +402,6 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
payload["language_boost"] = self._settings.language_boost
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
async with self._session.post(
|
||||
self._base_url, headers=headers, json=payload
|
||||
) as response:
|
||||
@@ -413,7 +411,6 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
return
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
# Process the streaming response
|
||||
buffer = bytearray()
|
||||
@@ -490,4 +487,3 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
@@ -180,6 +180,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
push_stop_frames=True,
|
||||
push_start_frame=True,
|
||||
stop_frame_timeout_s=2.0,
|
||||
sample_rate=sample_rate,
|
||||
settings=default_settings,
|
||||
@@ -188,12 +189,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
|
||||
self._cumulative_time = 0
|
||||
|
||||
self._receive_task = None
|
||||
self._keepalive_task = None
|
||||
self._context_id: Optional[str] = None
|
||||
self._encoding = encoding
|
||||
self._sampling_rate = sample_rate
|
||||
|
||||
@@ -252,7 +249,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def flush_audio(self):
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio synthesis by sending stop command."""
|
||||
if self._websocket:
|
||||
msg = {"text": "<STOP>"}
|
||||
@@ -358,7 +355,6 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
@@ -372,7 +368,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
|
||||
audio = base64.b64decode(msg["data"]["audio"])
|
||||
frame = TTSAudioRawFrame(
|
||||
audio, self.sample_rate, 1, context_id=self._context_id
|
||||
audio, self.sample_rate, 1, context_id=self.get_active_audio_context_id()
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
@@ -415,12 +411,6 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
# Store context_id for use in _receive_messages
|
||||
self._context_id = context_id
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._cumulative_time = 0
|
||||
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
@@ -523,6 +513,8 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_stop_frames=True,
|
||||
push_start_frame=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -559,7 +551,7 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
"""
|
||||
await super().start(frame)
|
||||
|
||||
async def flush_audio(self):
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio synthesis.
|
||||
|
||||
Note:
|
||||
@@ -633,8 +625,6 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
payload["voice_id"] = self._settings.voice
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
async with self._session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
@@ -643,7 +633,6 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
return
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
# Process SSE stream line by line
|
||||
async for line in response.content:
|
||||
@@ -681,4 +670,3 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
@@ -28,8 +28,6 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import TTSService
|
||||
@@ -145,6 +143,8 @@ class NvidiaTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -271,9 +271,6 @@ class NvidiaTTSService(TTSService):
|
||||
assert self._service is not None, "TTS service not initialized"
|
||||
assert self._config is not None, "Synthesis configuration not created"
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
responses = await asyncio.to_thread(read_audio_responses)
|
||||
@@ -289,7 +286,6 @@ class NvidiaTTSService(TTSService):
|
||||
yield frame
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
except asyncio.TimeoutError as e:
|
||||
logger.error(f"{self} timeout waiting for audio response")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -22,8 +22,6 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import TTSService
|
||||
@@ -194,6 +192,8 @@ class OpenAITTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -234,8 +234,6 @@ class OpenAITTSService(TTSService):
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Setup API parameters
|
||||
create_params = {
|
||||
"input": text,
|
||||
@@ -267,12 +265,10 @@ class OpenAITTSService(TTSService):
|
||||
|
||||
CHUNK_SIZE = self.chunk_size
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
async for chunk in r.iter_bytes(CHUNK_SIZE):
|
||||
if len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1, context_id=context_id)
|
||||
yield frame
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
except BadRequestError as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
|
||||
@@ -17,8 +17,6 @@ from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import TTSService
|
||||
@@ -91,6 +89,8 @@ class PiperTTSService(TTSService):
|
||||
default_settings.apply_update(settings)
|
||||
|
||||
super().__init__(
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -159,12 +159,8 @@ class PiperTTSService(TTSService):
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
async for frame in self._stream_audio_frames_from_iterator(
|
||||
async_iterator(self._voice.synthesize(text)),
|
||||
in_sample_rate=self._voice.config.sample_rate,
|
||||
@@ -178,7 +174,6 @@ class PiperTTSService(TTSService):
|
||||
finally:
|
||||
logger.debug(f"{self}: Finished TTS [{text}]")
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
|
||||
# This assumes a running TTS service running:
|
||||
@@ -244,6 +239,8 @@ class PiperHttpTTSService(TTSService):
|
||||
default_settings.apply_update(settings)
|
||||
|
||||
super().__init__(
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -279,8 +276,6 @@ class PiperHttpTTSService(TTSService):
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
data = {
|
||||
"text": text,
|
||||
"voice": self._settings.voice,
|
||||
@@ -296,8 +291,6 @@ class PiperHttpTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
CHUNK_SIZE = self.chunk_size
|
||||
|
||||
async for frame in self._stream_audio_frames_from_iterator(
|
||||
@@ -311,4 +304,3 @@ class PiperHttpTTSService(TTSService):
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
@@ -24,7 +24,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import AudioContextTTSService
|
||||
from pipecat.services.tts_service import WebsocketTTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
@@ -43,7 +43,7 @@ class ResembleAITTSSettings(TTSSettings):
|
||||
pass
|
||||
|
||||
|
||||
class ResembleAITTSService(AudioContextTTSService):
|
||||
class ResembleAITTSService(WebsocketTTSService):
|
||||
"""Resemble AI TTS service with WebSocket streaming and word timestamps.
|
||||
|
||||
Provides text-to-speech using Resemble AI's streaming WebSocket API.
|
||||
@@ -103,7 +103,6 @@ class ResembleAITTSService(AudioContextTTSService):
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
reuse_context_id_within_turn=False,
|
||||
supports_word_timestamps=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -268,7 +267,7 @@ class ResembleAITTSService(AudioContextTTSService):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def flush_audio(self):
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio and finalize the current context."""
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
# For Resemble AI, we just wait for the audio_end message
|
||||
@@ -297,9 +296,6 @@ class ResembleAITTSService(AudioContextTTSService):
|
||||
continue
|
||||
|
||||
if msg_type == "audio":
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.start_word_timestamps()
|
||||
|
||||
# Decode base64 audio content
|
||||
audio_content = msg.get("audio_content", "")
|
||||
if not audio_content:
|
||||
@@ -447,14 +443,14 @@ class ResembleAITTSService(AudioContextTTSService):
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
if not self.audio_context_available(context_id):
|
||||
await self.create_audio_context(context_id)
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
# Map request_id to context_id for tracking
|
||||
self._request_id_to_context[self._request_id_counter] = context_id
|
||||
|
||||
await self.create_audio_context(context_id)
|
||||
|
||||
msg = self._build_msg(text=text)
|
||||
|
||||
try:
|
||||
|
||||
@@ -33,10 +33,10 @@ from pipecat.frames.frames import (
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import (
|
||||
AudioContextTTSService,
|
||||
InterruptibleTTSService,
|
||||
TextAggregationMode,
|
||||
TTSService,
|
||||
WebsocketTTSService,
|
||||
)
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
@@ -123,7 +123,7 @@ class RimeNonJsonTTSSettings(TTSSettings):
|
||||
_aliases: ClassVar[Dict[str, str]] = {"speaker": "voice"}
|
||||
|
||||
|
||||
class RimeTTSService(AudioContextTTSService):
|
||||
class RimeTTSService(WebsocketTTSService):
|
||||
"""Text-to-Speech service using Rime's websocket API.
|
||||
|
||||
Uses Rime's websocket JSON API to convert text to speech with word-level timing
|
||||
@@ -276,7 +276,6 @@ class RimeTTSService(AudioContextTTSService):
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
supports_word_timestamps=True,
|
||||
append_trailing_space=True,
|
||||
sample_rate=sample_rate,
|
||||
settings=default_settings,
|
||||
@@ -408,9 +407,9 @@ class RimeTTSService(AudioContextTTSService):
|
||||
|
||||
return changed
|
||||
|
||||
def _build_msg(self, text: str = "") -> dict:
|
||||
def _build_msg(self, text: str = "", context_id: str = "") -> dict:
|
||||
"""Build JSON message for Rime API."""
|
||||
msg = {"text": text, "contextId": self.get_active_audio_context_id()}
|
||||
msg = {"text": text, "contextId": context_id}
|
||||
if self._extra_msg_fields:
|
||||
msg |= self._extra_msg_fields
|
||||
self._extra_msg_fields = {}
|
||||
@@ -557,15 +556,14 @@ class RimeTTSService(AudioContextTTSService):
|
||||
|
||||
return word_pairs
|
||||
|
||||
async def flush_audio(self):
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio synthesis."""
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if not context_id or not self._websocket:
|
||||
flush_id = context_id or self.get_active_audio_context_id()
|
||||
if not flush_id or not self._websocket:
|
||||
return
|
||||
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
await self._get_websocket().send(json.dumps({"operation": "flush"}))
|
||||
self.reset_active_audio_context()
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Process incoming websocket messages."""
|
||||
@@ -578,8 +576,6 @@ class RimeTTSService(AudioContextTTSService):
|
||||
context_id = msg["contextId"]
|
||||
if msg["type"] == "chunk":
|
||||
# Process audio chunk
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.start_word_timestamps()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["data"]),
|
||||
sample_rate=self.sample_rate,
|
||||
@@ -638,13 +634,13 @@ class RimeTTSService(AudioContextTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
if not self.has_active_audio_context():
|
||||
if not self.audio_context_available(context_id):
|
||||
await self.create_audio_context(context_id)
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._cumulative_time = 0
|
||||
await self.create_audio_context(context_id)
|
||||
|
||||
msg = self._build_msg(text=text)
|
||||
msg = self._build_msg(text=text, context_id=context_id)
|
||||
await self._get_websocket().send(json.dumps(msg))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
@@ -773,6 +769,8 @@ class RimeHttpTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_stop_frames=True,
|
||||
push_start_frame=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -844,8 +842,6 @@ class RimeHttpTTSService(TTSService):
|
||||
need_to_strip_wav_header = False
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
async with self._session.post(
|
||||
self._base_url, json=payload, headers=headers
|
||||
) as response:
|
||||
@@ -856,8 +852,6 @@ class RimeHttpTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
CHUNK_SIZE = self.chunk_size
|
||||
|
||||
async for frame in self._stream_audio_frames_from_iterator(
|
||||
@@ -872,7 +866,6 @@ class RimeHttpTTSService(TTSService):
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
|
||||
class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
@@ -1005,6 +998,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
push_stop_frames=True,
|
||||
push_start_frame=True,
|
||||
pause_frame_processing=True,
|
||||
append_trailing_space=True,
|
||||
settings=default_settings,
|
||||
@@ -1022,7 +1016,6 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
self._settings.extra.update(params.extra)
|
||||
|
||||
self._receive_task = None
|
||||
self._context_id: Optional[str] = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -1138,7 +1131,6 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
@@ -1148,7 +1140,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def flush_audio(self):
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio synthesis."""
|
||||
if not self._websocket:
|
||||
return
|
||||
@@ -1168,7 +1160,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
audio=message,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=self._context_id,
|
||||
context_id=self.get_active_audio_context_id(),
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
except Exception as e:
|
||||
@@ -1190,10 +1182,6 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
# Store context_id for use in _receive_messages
|
||||
self._context_id = context_id
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
# Send bare text (not JSON)
|
||||
await self._get_websocket().send(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
@@ -524,6 +524,8 @@ class SarvamHttpTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_stop_frames=True,
|
||||
push_start_frame=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -573,8 +575,6 @@ class SarvamHttpTTSService(TTSService):
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Build payload with common parameters
|
||||
payload = {
|
||||
"text": text,
|
||||
@@ -606,8 +606,6 @@ class SarvamHttpTTSService(TTSService):
|
||||
|
||||
url = f"{self._base_url}/text-to-speech"
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
async with self._session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
@@ -645,7 +643,6 @@ class SarvamHttpTTSService(TTSService):
|
||||
yield ErrorFrame(error=f"Error generating TTS: {e}", exception=e)
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
|
||||
class SarvamTTSService(InterruptibleTTSService):
|
||||
@@ -951,6 +948,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
push_text_frames=True,
|
||||
pause_frame_processing=True,
|
||||
push_stop_frames=True,
|
||||
push_start_frame=True,
|
||||
sample_rate=sample_rate,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
@@ -967,7 +965,6 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
|
||||
self._receive_task = None
|
||||
self._keepalive_task = None
|
||||
self._context_id: Optional[str] = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -1018,7 +1015,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def flush_audio(self):
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio synthesis by sending flush command."""
|
||||
try:
|
||||
if self._websocket:
|
||||
@@ -1151,7 +1148,6 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
@@ -1170,7 +1166,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
await self.stop_ttfb_metrics()
|
||||
audio = base64.b64decode(msg["data"]["audio"])
|
||||
frame = TTSAudioRawFrame(
|
||||
audio, self.sample_rate, 1, context_id=self._context_id
|
||||
audio, self.sample_rate, 1, context_id=self.get_active_audio_context_id()
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
elif msg.get("type") == "error":
|
||||
@@ -1224,10 +1220,6 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
# Store context_id for use in _receive_messages
|
||||
self._context_id = context_id
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
|
||||
@@ -19,8 +19,6 @@ from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import TTSService
|
||||
@@ -135,6 +133,8 @@ class SpeechmaticsTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -185,9 +185,6 @@ class SpeechmaticsTTSService(TTSService):
|
||||
url = _get_endpoint_url(self._base_url, self._settings.voice, self.sample_rate)
|
||||
|
||||
try:
|
||||
# Start TTS TTFB metrics
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Track attempt
|
||||
attempt = 0
|
||||
|
||||
@@ -238,9 +235,6 @@ class SpeechmaticsTTSService(TTSService):
|
||||
# Update Pipecat metrics
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
# Emit the TTS started frame
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
# Process the response in streaming chunks
|
||||
first_chunk = True
|
||||
buffer = b""
|
||||
@@ -277,8 +271,7 @@ class SpeechmaticsTTSService(TTSService):
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Error generating TTS: {e}")
|
||||
finally:
|
||||
# Emit the TTS stopped frame
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
|
||||
def _get_endpoint_url(base_url: str, voice: str, sample_rate: int) -> str:
|
||||
|
||||
@@ -11,7 +11,7 @@ text-to-speech synthesis using local Docker deployment.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator, Dict, Optional
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
@@ -22,8 +22,6 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
|
||||
from pipecat.services.tts_service import TTSService
|
||||
@@ -132,6 +130,8 @@ class XTTSService(TTSService):
|
||||
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -213,8 +213,6 @@ class XTTSService(TTSService):
|
||||
"stream_chunk_size": 20,
|
||||
}
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
async with self._aiohttp_session.post(url, json=payload) as r:
|
||||
if r.status != 200:
|
||||
text = await r.text()
|
||||
@@ -223,8 +221,6 @@ class XTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
CHUNK_SIZE = self.chunk_size
|
||||
|
||||
buffer = bytearray()
|
||||
@@ -262,5 +258,3 @@ class XTTSService(TTSService):
|
||||
resampled_audio, self.sample_rate, 1, context_id=context_id
|
||||
)
|
||||
yield frame
|
||||
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
|
||||
Reference in New Issue
Block a user