Merge pull request #3804 from pipecat-ai/filipi/concurrent_audio_contexts

Allowing concurrent audio contexts
This commit is contained in:
Filipi da Silva Fuchter
2026-03-06 14:49:57 -05:00
committed by GitHub
35 changed files with 698 additions and 765 deletions

1
changelog/3804.added.md Normal file
View File

@@ -0,0 +1 @@
- Added concurrent audio context support: `CartesiaTTSService` can now synthesize the next sentence while the previous one is still playing, by setting `pause_frame_processing=False` and routing each sentence through its own audio context queue.

View File

@@ -0,0 +1 @@
- Audio context management (previously in `AudioContextTTSService`) is now built into `TTSService`. All WebSocket providers (`cartesia`, `elevenlabs`, `asyncai`, `inworld`, `rime`, `gradium`, `resembleai`) now inherit from `WebsocketTTSService` directly. Word-timestamp baseline is set automatically on the first audio chunk of each context instead of requiring each provider to call `start_word_timestamps()` in their receive loop.

View File

@@ -0,0 +1,2 @@
- Deprecated `AudioContextTTSService` and `AudioContextWordTTSService`. Subclass `WebsocketTTSService` directly instead; audio context management is now part of the base `TTSService`.
- Deprecated `WordTTSService`, `WebsocketWordTTSService`, and `InterruptibleWordTTSService`. Word timestamp logic is now always active in `TTSService` and no longer needs to be opted into via a subclass.

View File

@@ -0,0 +1 @@
- ⚠️ Removed `supports_word_timestamps` parameter from `TTSService.__init__()`. Word timestamp logic is now always active. Remove this argument from any custom subclass `super().__init__()` calls.

View File

@@ -10,8 +10,7 @@ from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMRunFrame, TTSTextFrame
from pipecat.observers.loggers.debug_log_observer import DebugLogObserver, FrameEndpoint
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
@@ -25,7 +24,6 @@ from pipecat.runner.utils import create_transport
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.inworld.tts import InworldTTSService, InworldTTSSettings
from pipecat.services.openai.llm import OpenAILLMService, OpenAILLMSettings
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
@@ -94,13 +92,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
enable_metrics=True,
enable_usage_metrics=True,
),
observers=[
DebugLogObserver(
frame_types={
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
}
),
],
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)

View File

@@ -23,12 +23,11 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService
from pipecat.services.tts_service import TextAggregationMode, TTSService, WebsocketTTSService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -80,7 +79,7 @@ class AsyncAITTSSettings(TTSSettings):
pass
class AsyncAITTSService(AudioContextTTSService):
class AsyncAITTSService(WebsocketTTSService):
"""Async TTS service with WebSocket streaming.
Provides text-to-speech using Async's streaming WebSocket API.
@@ -183,8 +182,9 @@ class AsyncAITTSService(AudioContextTTSService):
aggregate_sentences=aggregate_sentences,
text_aggregation_mode=text_aggregation_mode,
pause_frame_processing=True,
push_stop_frames=True,
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -340,13 +340,18 @@ class AsyncAITTSService(AudioContextTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def flush_audio(self):
"""Flush any pending audio."""
context_id = self.get_active_audio_context_id()
if not context_id or not self._websocket:
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio.
Args:
context_id: The specific context to flush. If None, falls back to the
currently active context.
"""
flush_id = context_id or self.get_active_audio_context_id()
if not flush_id or not self._websocket:
return
logger.trace(f"{self}: flushing audio")
msg = self._build_msg(text=" ", context_id=context_id, force=True)
msg = self._build_msg(text=" ", context_id=flush_id, force=True)
await self._websocket.send(msg)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
@@ -459,12 +464,6 @@ class AsyncAITTSService(AudioContextTTSService):
await self._connect()
try:
if not self.has_active_audio_context():
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
msg = self._build_msg(text=text, force=True, context_id=context_id)
await self._get_websocket().send(msg)
await self.start_tts_usage_metrics(text)
@@ -574,6 +573,8 @@ class AsyncAIHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -632,7 +633,7 @@ class AsyncAIHttpTTSService(TTSService):
try:
voice_config = {"mode": "id", "id": self._settings.voice}
await self.start_ttfb_metrics()
payload = {
"model_id": self._settings.model,
"transcript": text,
@@ -644,7 +645,7 @@ class AsyncAIHttpTTSService(TTSService):
},
"language": self._settings.language,
}
yield TTSStartedFrame(context_id=context_id)
headers = {
"version": self._api_version,
"x-api-key": self._api_key,
@@ -682,4 +683,3 @@ class AsyncAIHttpTTSService(TTSService):
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -22,8 +22,6 @@ from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -247,6 +245,8 @@ class AWSPollyTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -329,8 +329,6 @@ class AWSPollyTTSService(TTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
# Construct the parameters dictionary
ssml = self._construct_ssml(text)
@@ -362,8 +360,6 @@ class AWSPollyTTSService(TTSService):
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
CHUNK_SIZE = self.chunk_size
for i in range(0, len(audio_data), CHUNK_SIZE):
@@ -373,14 +369,10 @@ class AWSPollyTTSService(TTSService):
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1, context_id=context_id)
yield frame
yield TTSStoppedFrame(context_id=context_id)
except (BotoCoreError, ClientError) as error:
error_message = f"AWS Polly TTS error: {str(error)}"
yield ErrorFrame(error=error_message)
finally:
yield TTSStoppedFrame(context_id=context_id)
class PollyTTSService(AWSPollyTTSService):
"""Deprecated alias for AWSPollyTTSService.

View File

@@ -21,7 +21,6 @@ from pipecat.frames.frames import (
InterruptionFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
@@ -331,8 +330,8 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
text_aggregation_mode=text_aggregation_mode,
push_text_frames=False, # We'll push text frames based on word timestamps
push_stop_frames=True,
push_start_frame=True,
pause_frame_processing=True,
supports_word_timestamps=True,
sample_rate=sample_rate,
settings=default_settings,
**kwargs,
@@ -346,7 +345,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
self._audio_queue = asyncio.Queue()
self._word_boundary_queue = asyncio.Queue()
self._word_processor_task = None
self._first_chunk = True
self._cumulative_audio_offset: float = 0.0 # Cumulative audio duration in seconds
self._current_sentence_base_offset: float = 0.0 # Base offset for current sentence
self._current_sentence_duration: float = 0.0 # Duration from Azure callback
@@ -619,7 +617,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
def _reset_state(self):
"""Reset TTS state between turns."""
self._first_chunk = True
self._cumulative_audio_offset = 0.0
self._current_sentence_base_offset = 0.0
self._current_sentence_duration = 0.0
@@ -628,7 +625,7 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
self._last_timestamp = None
self._current_context_id = None
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio data."""
logger.trace(f"{self}: flushing audio")
@@ -694,9 +691,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
return
try:
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
self._first_chunk = True
self._current_context_id = context_id
# Capture base offset BEFORE starting synthesis to avoid race conditions
@@ -719,11 +713,6 @@ class AzureTTSService(TTSService, AzureBaseTTSService):
yield ErrorFrame(error=str(chunk))
break
if self._first_chunk:
await self.stop_ttfb_metrics()
await self.start_word_timestamps()
self._first_chunk = False
frame = TTSAudioRawFrame(
audio=chunk,
sample_rate=self.sample_rate,
@@ -833,6 +822,8 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -887,8 +878,6 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService):
"""
logger.debug(f"{self}: Generating TTS [{text}]")
await self.start_ttfb_metrics()
ssml = self._construct_ssml(text)
result = await asyncio.to_thread(self._speech_synthesizer.speak_ssml, ssml)
@@ -896,7 +885,6 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService):
if result.reason == ResultReason.SynthesizingAudioCompleted:
await self.start_tts_usage_metrics(text)
await self.stop_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
# Azure always sends a 44-byte header. Strip it off.
yield TTSAudioRawFrame(
audio=result.audio_data[44:],
@@ -904,7 +892,6 @@ class AzureHttpTTSService(TTSService, AzureBaseTTSService):
num_channels=1,
context_id=context_id,
)
yield TTSStoppedFrame(context_id=context_id)
elif result.reason == ResultReason.Canceled:
cancellation_details = result.cancellation_details
logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")

View File

@@ -29,8 +29,6 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -271,6 +269,8 @@ class CambTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -332,8 +332,6 @@ class CambTTSService(TTSService):
text = text[:3000]
try:
await self.start_ttfb_metrics()
# Build SDK parameters
tts_kwargs: Dict[str, Any] = {
"text": text,
@@ -348,7 +346,6 @@ class CambTTSService(TTSService):
tts_kwargs["user_instructions"] = self._settings.user_instructions
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
assert self._client is not None, "Camb.ai TTS service not initialized"
@@ -384,5 +381,3 @@ class CambTTSService(TTSService):
except Exception as e:
yield ErrorFrame(error=f"Camb.ai TTS error: {e}")
finally:
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService
from pipecat.services.tts_service import TextAggregationMode, TTSService, WebsocketTTSService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator
@@ -203,7 +203,7 @@ class CartesiaTTSSettings(TTSSettings):
pronunciation_dict_id: str | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
class CartesiaTTSService(AudioContextTTSService):
class CartesiaTTSService(WebsocketTTSService):
"""Cartesia TTS service with WebSocket streaming and word timestamps.
Provides text-to-speech using Cartesia's streaming WebSocket API.
@@ -334,9 +334,9 @@ class CartesiaTTSService(AudioContextTTSService):
text_aggregation_mode=text_aggregation_mode,
aggregate_sentences=aggregate_sentences,
push_text_frames=False,
pause_frame_processing=True,
supports_word_timestamps=True,
pause_frame_processing=False,
sample_rate=sample_rate,
push_start_frame=True,
text_aggregator=text_aggregator,
settings=default_settings,
**kwargs,
@@ -452,7 +452,11 @@ class CartesiaTTSService(AudioContextTTSService):
return list(zip(words, starts))
def _build_msg(
self, text: str = "", continue_transcript: bool = True, add_timestamps: bool = True
self,
text: str = "",
continue_transcript: bool = True,
add_timestamps: bool = True,
context_id: str = "",
):
voice_config = {}
voice_config["mode"] = "id"
@@ -461,7 +465,7 @@ class CartesiaTTSService(AudioContextTTSService):
msg = {
"transcript": text,
"continue": continue_transcript,
"context_id": self.get_active_audio_context_id(),
"context_id": context_id,
"model_id": self._settings.model,
"voice": voice_config,
"output_format": {
@@ -580,15 +584,19 @@ class CartesiaTTSService(AudioContextTTSService):
"""
pass
async def flush_audio(self):
"""Flush any pending audio and finalize the current context."""
context_id = self.get_active_audio_context_id()
if not context_id or not self._websocket:
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio and finalize the current context.
Args:
context_id: The specific context to flush. If None, falls back to the
currently active context.
"""
flush_id = context_id or self.get_active_audio_context_id()
if not flush_id or not self._websocket:
return
logger.trace(f"{self}: flushing audio")
msg = self._build_msg(text="", continue_transcript=False)
msg = self._build_msg(text="", continue_transcript=False, context_id=flush_id)
await self._websocket.send(msg)
self.reset_active_audio_context()
async def _process_messages(self):
async for message in self._get_websocket():
@@ -607,8 +615,6 @@ class CartesiaTTSService(AudioContextTTSService):
)
await self.add_word_timestamps(processed_timestamps, ctx_id)
elif msg["type"] == "chunk":
await self.stop_ttfb_metrics()
await self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["data"]),
sample_rate=self.sample_rate,
@@ -652,12 +658,7 @@ class CartesiaTTSService(AudioContextTTSService):
if not self._websocket or self._websocket.state is State.CLOSED:
await self._connect()
if not self.has_active_audio_context():
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
await self.create_audio_context(context_id)
msg = self._build_msg(text=text)
msg = self._build_msg(text=text, context_id=context_id)
try:
await self._get_websocket().send(msg)
@@ -777,6 +778,8 @@ class CartesiaHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -863,8 +866,6 @@ class CartesiaHttpTTSService(TTSService):
try:
voice_config = {"mode": "id", "id": self._settings.voice}
await self.start_ttfb_metrics()
output_format = {
"container": self._output_container,
"encoding": self._output_encoding,
@@ -889,8 +890,6 @@ class CartesiaHttpTTSService(TTSService):
if self._settings.pronunciation_dict_id:
payload["pronunciation_dict_id"] = self._settings.pronunciation_dict_id
yield TTSStartedFrame(context_id=context_id)
headers = {
"Cartesia-Version": self._cartesia_version,
"X-API-Key": self._api_key,
@@ -922,4 +921,3 @@ class CartesiaHttpTTSService(TTSService):
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -328,7 +328,7 @@ class DeepgramSageMakerTTSService(TTSService):
except Exception as e:
logger.error(f"{self} error sending Clear message: {e}")
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis by sending Flush command.
This should be called when the LLM finishes a complete response to force
@@ -355,12 +355,12 @@ class DeepgramSageMakerTTSService(TTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
if not self._ttfb_started:
await self.start_ttfb_metrics()
self._ttfb_started = True
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
if not self._ttfb_started:
await self.start_ttfb_metrics()
self._ttfb_started = True
yield TTSStartedFrame(context_id=context_id)
self._context_id = context_id
await self._client.send_json({"type": "Speak", "text": text})

View File

@@ -26,8 +26,6 @@ from pipecat.frames.frames import (
LLMFullResponseEndFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
@@ -120,6 +118,7 @@ class DeepgramTTSService(WebsocketTTSService):
sample_rate=sample_rate,
pause_frame_processing=True,
push_stop_frames=True,
push_start_frame=True,
append_trailing_space=True,
settings=default_settings,
**kwargs,
@@ -130,7 +129,6 @@ class DeepgramTTSService(WebsocketTTSService):
self._encoding = encoding
self._receive_task = None
self._context_id: Optional[str] = None
def can_generate_metrics(self) -> bool:
"""Check if the service can generate metrics.
@@ -267,7 +265,6 @@ class DeepgramTTSService(WebsocketTTSService):
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
finally:
self._context_id = None
self._websocket = None
await self._call_event_handler("on_disconnected")
@@ -299,7 +296,9 @@ class DeepgramTTSService(WebsocketTTSService):
if isinstance(message, bytes):
# Binary message contains audio data
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(message, self.sample_rate, 1, context_id=self._context_id)
frame = TTSAudioRawFrame(
message, self.sample_rate, 1, context_id=self.get_active_audio_context_id()
)
await self.push_frame(frame)
elif isinstance(message, str):
# Text message contains metadata or control messages
@@ -326,7 +325,7 @@ class DeepgramTTSService(WebsocketTTSService):
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis by sending Flush command.
This should be called when the LLM finishes a complete response to force
@@ -357,13 +356,8 @@ class DeepgramTTSService(WebsocketTTSService):
if not self._websocket or self._websocket.state is State.CLOSED:
await self._connect()
await self.start_ttfb_metrics()
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
# Store context_id for use in _receive_messages
self._context_id = context_id
# Send text message to Deepgram
# Note: We don't send Flush here - that should only be sent when the
# LLM finishes a complete response via flush_audio()
@@ -435,6 +429,8 @@ class DeepgramHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -492,7 +488,6 @@ class DeepgramHttpTTSService(TTSService):
raise Exception(f"HTTP {response.status}: {error_text}")
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
CHUNK_SIZE = self.chunk_size
@@ -510,7 +505,5 @@ class DeepgramHttpTTSService(TTSService):
context_id=context_id,
)
yield TTSStoppedFrame(context_id=context_id)
except Exception as e:
yield ErrorFrame(f"Error getting audio: {str(e)}")

View File

@@ -46,9 +46,9 @@ from pipecat.frames.frames import (
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import (
AudioContextTTSService,
TextAggregationMode,
TTSService,
WebsocketTTSService,
)
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -308,7 +308,7 @@ def calculate_word_times(
return (word_times, new_partial_word, new_partial_word_start_time)
class ElevenLabsTTSService(AudioContextTTSService):
class ElevenLabsTTSService(WebsocketTTSService):
"""ElevenLabs WebSocket-based TTS service with word timestamps.
Provides real-time text-to-speech using ElevenLabs' WebSocket streaming API.
@@ -479,7 +479,6 @@ class ElevenLabsTTSService(AudioContextTTSService):
push_text_frames=False,
push_stop_frames=True,
pause_frame_processing=True,
supports_word_timestamps=True,
sample_rate=sample_rate,
settings=default_settings,
**kwargs,
@@ -559,20 +558,15 @@ class ElevenLabsTTSService(AudioContextTTSService):
)
await self._disconnect()
await self._connect()
elif voice_settings_changed and self.has_active_audio_context():
elif voice_settings_changed:
logger.debug(
f"Voice settings changed ({changed.keys() & ElevenLabsTTSSettings.VOICE_SETTINGS_FIELDS}), "
f"closing current context to apply changes"
)
context_id = self.get_active_audio_context_id()
try:
if self._websocket:
await self._websocket.send(
json.dumps({"context_id": context_id, "close_context": True})
)
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self.reset_active_audio_context()
audio_contexts = self.get_audio_contexts()
if audio_contexts:
for ctx_id in audio_contexts:
await self._close_context(ctx_id)
if not url_changed:
# Reconnect applies all settings; only warn about fields not handled
@@ -610,13 +604,18 @@ class ElevenLabsTTSService(AudioContextTTSService):
await super().cancel(frame)
await self._disconnect()
async def flush_audio(self):
"""Flush any pending audio and finalize the current context."""
context_id = self.get_active_audio_context_id()
if not context_id or not self._websocket:
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio and finalize the current context.
Args:
context_id: The specific context to flush. If None, falls back to the
currently active context.
"""
flush_id = context_id or self.get_active_audio_context_id()
if not flush_id or not self._websocket:
return
logger.trace(f"{self}: flushing audio")
msg = {"context_id": context_id, "flush": True}
msg = {"context_id": flush_id, "flush": True}
await self._websocket.send(json.dumps(msg))
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
@@ -703,9 +702,7 @@ class ElevenLabsTTSService(AudioContextTTSService):
if self._websocket:
logger.debug("Disconnecting from ElevenLabs")
# Close all contexts and the socket
if self.has_active_audio_context():
await self._websocket.send(json.dumps({"close_socket": True}))
await self._websocket.send(json.dumps({"close_socket": True}))
await self._websocket.close()
logger.debug("Disconnected from ElevenLabs")
except Exception as e:
@@ -737,6 +734,7 @@ class ElevenLabsTTSService(AudioContextTTSService):
)
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._cumulative_time = 0.0
self._partial_word = ""
self._partial_word_start_time = 0.0
@@ -782,9 +780,6 @@ class ElevenLabsTTSService(AudioContextTTSService):
continue
if msg.get("audio"):
await self.stop_ttfb_metrics()
await self.start_word_timestamps()
audio = base64.b64decode(msg["audio"])
frame = TTSAudioRawFrame(audio, self.sample_rate, 1, context_id=received_ctx_id)
await self.append_to_audio_context(received_ctx_id, frame)
@@ -845,9 +840,8 @@ class ElevenLabsTTSService(AudioContextTTSService):
logger.warning(f"{self} keepalive error: {e}")
break
async def _send_text(self, text: str):
async def _send_text(self, text: str, context_id: str):
"""Send text to the WebSocket for synthesis."""
context_id = self.get_active_audio_context_id()
if self._websocket and context_id:
msg = {"text": text, "context_id": context_id}
await self._websocket.send(json.dumps(msg))
@@ -870,16 +864,14 @@ class ElevenLabsTTSService(AudioContextTTSService):
await self._connect()
try:
if not self.has_active_audio_context():
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
self._cumulative_time = 0
self._partial_word = ""
self._partial_word_start_time = 0.0
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
# Initialize context with voice settings and pronunciation dictionaries
msg = {"text": " ", "context_id": context_id}
if self._voice_settings:
@@ -892,7 +884,7 @@ class ElevenLabsTTSService(AudioContextTTSService):
await self._websocket.send(json.dumps(msg))
logger.trace(f"Created new context {context_id}")
await self._send_text(text)
await self._send_text(text, context_id)
await self.start_tts_usage_metrics(text)
except Exception as e:
yield TTSStoppedFrame(context_id=context_id)
@@ -1046,7 +1038,7 @@ class ElevenLabsHttpTTSService(TTSService):
aggregate_sentences=aggregate_sentences,
push_text_frames=False,
push_stop_frames=True,
supports_word_timestamps=True,
push_start_frame=True,
sample_rate=sample_rate,
settings=default_settings,
**kwargs,
@@ -1266,8 +1258,6 @@ class ElevenLabsHttpTTSService(TTSService):
params["optimize_streaming_latency"] = self._settings.optimize_streaming_latency
try:
await self.start_ttfb_metrics()
async with self._session.post(
url, json=payload, headers=headers, params=params
) as response:
@@ -1278,10 +1268,6 @@ class ElevenLabsHttpTTSService(TTSService):
await self.start_tts_usage_metrics(text)
# Start TTS sequence
await self.start_word_timestamps()
yield TTSStartedFrame(context_id=context_id)
# Track the duration of this utterance based on the last character's end time
utterance_duration = 0
async for line in response.content:
@@ -1347,4 +1333,3 @@ class ElevenLabsHttpTTSService(TTSService):
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
# Let the parent class handle TTSStoppedFrame

View File

@@ -209,6 +209,7 @@ class FishAudioTTSService(InterruptibleTTSService):
super().__init__(
push_stop_frames=True,
push_start_frame=True,
pause_frame_processing=True,
sample_rate=sample_rate,
settings=default_settings,
@@ -219,7 +220,6 @@ class FishAudioTTSService(InterruptibleTTSService):
self._base_url = "wss://api.fish.audio/v1/tts/live"
self._websocket = None
self._receive_task = None
self._request_id = None
# Init-only audio format config (not runtime-updatable).
self._fish_sample_rate = 0 # Set in start()
@@ -341,11 +341,10 @@ class FishAudioTTSService(InterruptibleTTSService):
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._request_id = None
self._websocket = None
await self._call_event_handler("on_disconnected")
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any buffered audio by sending a flush event to Fish Audio."""
logger.trace(f"{self}: Flushing audio buffers")
if not self._websocket or self._websocket.state is State.CLOSED:
@@ -361,7 +360,6 @@ class FishAudioTTSService(InterruptibleTTSService):
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
await self.stop_all_metrics()
self._request_id = None
async def _receive_messages(self):
async for message in self._get_websocket():
@@ -398,12 +396,6 @@ class FishAudioTTSService(InterruptibleTTSService):
if not self._websocket or self._websocket.state is State.CLOSED:
await self._connect()
if not self._request_id:
await self.start_ttfb_metrics()
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
self._request_id = str(uuid.uuid4())
# Send the text
text_message = {
"event": "text",

View File

@@ -34,8 +34,6 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import (
NOT_GIVEN,
@@ -655,6 +653,8 @@ class GoogleHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -803,8 +803,6 @@ class GoogleHttpTTSService(TTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
# Check if the voice is a Chirp voice (including Chirp 3) or Journey voice
is_chirp_voice = "chirp" in self._settings.voice.lower()
is_journey_voice = "journey" in self._settings.voice.lower()
@@ -840,8 +838,6 @@ class GoogleHttpTTSService(TTSService):
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
# Skip the first 44 bytes to remove the WAV header
audio_content = response.audio_content[44:]
@@ -855,8 +851,6 @@ class GoogleHttpTTSService(TTSService):
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1, context_id=context_id)
yield frame
yield TTSStoppedFrame(context_id=context_id)
except Exception as e:
error_message = f"TTS generation error: {str(e)}"
yield ErrorFrame(error=error_message)
@@ -967,8 +961,6 @@ class GoogleBaseTTSService(TTSService):
streaming_responses = await self._client.streaming_synthesize(request_generator())
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
audio_buffer = b""
first_chunk_for_ttfb = False
@@ -992,8 +984,6 @@ class GoogleBaseTTSService(TTSService):
if audio_buffer:
yield TTSAudioRawFrame(audio_buffer, self.sample_rate, 1, context_id=context_id)
yield TTSStoppedFrame(context_id=context_id)
class GoogleTTSService(GoogleBaseTTSService):
"""Google Cloud Text-to-Speech streaming service.
@@ -1096,6 +1086,8 @@ class GoogleTTSService(GoogleBaseTTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -1135,8 +1127,6 @@ class GoogleTTSService(GoogleBaseTTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
# Build voice selection params
if self._voice_cloning_key:
voice_clone_params = texttospeech_v1.VoiceCloneParams(
@@ -1352,6 +1342,8 @@ class GeminiTTSService(GoogleBaseTTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -1414,8 +1406,6 @@ class GeminiTTSService(GoogleBaseTTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
# Build voice selection params
if self._settings.multi_speaker and self._settings.speaker_configs:
# Multi-speaker mode

View File

@@ -19,11 +19,10 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
from pipecat.services.tts_service import AudioContextTTSService
from pipecat.services.tts_service import WebsocketTTSService
from pipecat.utils.tracing.service_decorators import traced_tts
try:
@@ -45,7 +44,7 @@ class GradiumTTSSettings(TTSSettings):
pass
class GradiumTTSService(AudioContextTTSService):
class GradiumTTSService(WebsocketTTSService):
"""Text-to-Speech service using Gradium's websocket API."""
_settings: GradiumTTSSettings
@@ -125,9 +124,9 @@ class GradiumTTSService(AudioContextTTSService):
super().__init__(
push_stop_frames=True,
push_start_frame=True,
push_text_frames=False,
pause_frame_processing=True,
supports_word_timestamps=True,
sample_rate=SAMPLE_RATE,
settings=default_settings,
**kwargs,
@@ -166,12 +165,9 @@ class GradiumTTSService(AudioContextTTSService):
self._warn_unhandled_updated_settings(changed)
return changed
def _build_msg(self, text: str = "") -> dict:
def _build_msg(self, text: str = "", context_id: str = "") -> dict:
"""Build JSON message for Gradium API."""
msg = {"text": text, "type": "text"}
context_id = self.get_active_audio_context_id()
if context_id:
msg["client_req_id"] = context_id
msg = {"text": text, "type": "text", "client_req_id": context_id}
return msg
async def start(self, frame: StartFrame):
@@ -280,15 +276,14 @@ class GradiumTTSService(AudioContextTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis."""
context_id = self.get_active_audio_context_id()
if not context_id or not self._websocket:
flush_id = context_id or self.get_active_audio_context_id()
if not flush_id or not self._websocket:
return
try:
msg = {"type": "end_of_stream", "client_req_id": context_id}
msg = {"type": "end_of_stream", "client_req_id": flush_id}
await self._websocket.send(json.dumps(msg))
self.reset_active_audio_context()
except ConnectionClosedOK:
logger.debug(f"{self}: connection closed normally during flush")
except Exception as e:
@@ -326,8 +321,6 @@ class GradiumTTSService(AudioContextTTSService):
if msg["type"] == "audio":
if not ctx_id or not self.audio_context_available(ctx_id):
continue
await self.stop_ttfb_metrics()
await self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["audio"]),
sample_rate=self.sample_rate,
@@ -369,12 +362,7 @@ class GradiumTTSService(AudioContextTTSService):
await self._connect()
try:
if not self.has_active_audio_context():
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
await self.create_audio_context(context_id)
msg = self._build_msg(text=text)
msg = self._build_msg(text=text, context_id=context_id)
await self._get_websocket().send(json.dumps(msg))
await self.start_tts_usage_metrics(text)
except Exception as e:

View File

@@ -18,8 +18,6 @@ from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -140,6 +138,8 @@ class GroqTTSService(TTSService):
super().__init__(
pause_frame_processing=True,
push_start_frame=True,
push_stop_frames=True,
sample_rate=sample_rate,
settings=default_settings,
**kwargs,
@@ -171,9 +171,6 @@ class GroqTTSService(TTSService):
"""
logger.debug(f"{self}: Generating TTS [{text}]")
measuring_ttfb = True
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
try:
response = await self._client.audio.speech.create(
model=self._settings.model,
@@ -198,5 +195,3 @@ class GroqTTSService(TTSService):
yield TTSAudioRawFrame(bytes, frame_rate, channels, context_id=context_id)
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -22,7 +22,6 @@ from pipecat.frames.frames import (
InterruptionFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
@@ -166,7 +165,7 @@ class HumeTTSService(TTSService):
sample_rate=sample_rate,
push_text_frames=False,
push_stop_frames=True,
supports_word_timestamps=True,
push_start_frame=True,
settings=default_settings,
**kwargs,
)
@@ -181,7 +180,6 @@ class HumeTTSService(TTSService):
# Track cumulative time for word timestamps across utterances
self._cumulative_time = 0.0
self._started = False
def can_generate_metrics(self) -> bool:
"""Can generate metrics.
@@ -203,7 +201,6 @@ class HumeTTSService(TTSService):
def _reset_state(self):
"""Reset internal state variables."""
self._cumulative_time = 0.0
self._started = False
async def stop(self, frame: EndFrame) -> None:
"""Stop the service and cleanup resources.
@@ -310,15 +307,8 @@ class HumeTTSService(TTSService):
# Request raw PCM chunks in the streaming JSON
pcm_fmt = FormatPcm(type="pcm")
await self.start_ttfb_metrics()
await self.start_tts_usage_metrics(text)
# Start TTS sequence if not already started
if not self._started:
await self.start_word_timestamps()
yield TTSStartedFrame(context_id=context_id)
self._started = True
try:
# Instant mode is always enabled here (not user-configurable)
# Hume emits mono PCM at 48 kHz; downstream can resample if needed.
@@ -395,4 +385,3 @@ class HumeTTSService(TTSService):
finally:
# Ensure TTFB timer is stopped even on early failures
await self.stop_ttfb_metrics()
# Let the parent class handle TTSStoppedFrame via push_stop_frames

View File

@@ -62,7 +62,7 @@ from pipecat.frames.frames import (
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import AudioContextTTSService, TextAggregationMode, TTSService
from pipecat.services.tts_service import TextAggregationMode, TTSService, WebsocketTTSService
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -212,7 +212,7 @@ class InworldHttpTTSService(TTSService):
super().__init__(
push_text_frames=False,
push_stop_frames=True,
supports_word_timestamps=True,
push_start_frame=True,
sample_rate=sample_rate,
settings=default_settings,
**kwargs,
@@ -359,11 +359,6 @@ class InworldHttpTTSService(TTSService):
}
try:
await self.start_ttfb_metrics()
await self.start_word_timestamps()
yield TTSStartedFrame(context_id=context_id)
async with self._session.post(
self._base_url, json=payload, headers=headers
) as response:
@@ -514,7 +509,7 @@ class InworldHttpTTSService(TTSService):
)
class InworldTTSService(AudioContextTTSService):
class InworldTTSService(WebsocketTTSService):
"""Inworld AI WebSocket-based TTS service.
Uses bidirectional WebSocket for lower latency streaming. Supports multiple
@@ -650,7 +645,6 @@ class InworldTTSService(AudioContextTTSService):
push_text_frames=False,
push_stop_frames=True,
pause_frame_processing=True,
supports_word_timestamps=True,
sample_rate=sample_rate,
aggregate_sentences=aggregate_sentences,
text_aggregation_mode=text_aggregation_mode,
@@ -719,17 +713,17 @@ class InworldTTSService(AudioContextTTSService):
await super().cancel(frame)
await self._disconnect()
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio without closing the context.
This triggers synthesis of all accumulated text in the buffer while
keeping the context open for subsequent text. The context is only
closed on interruption, disconnect, or end of session.
"""
context_id = self.get_active_audio_context_id()
if context_id and self._websocket:
logger.trace(f"Flushing audio for context {context_id}")
await self._send_flush(context_id)
flush_id = context_id or self.get_active_audio_context_id()
if flush_id and self._websocket:
logger.trace(f"Flushing audio for context {flush_id}")
await self._send_flush(flush_id)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Push a frame and handle state changes.
@@ -899,12 +893,10 @@ class InworldTTSService(AudioContextTTSService):
if self._websocket:
logger.debug("Disconnecting from Inworld WebSocket TTS")
context_id = self.get_active_audio_context_id()
if context_id:
try:
await self._send_close_context(context_id)
except Exception:
pass
audio_contexts = self.get_audio_contexts()
if audio_contexts:
for ctx_id in audio_contexts:
await self._send_close_context(ctx_id)
await self._websocket.close()
logger.debug("Disconnected from Inworld WebSocket TTS")
except Exception as e:
@@ -934,10 +926,7 @@ class InworldTTSService(AudioContextTTSService):
for k in ["contextCreated", "audioChunk", "flushCompleted", "contextClosed"]
if k in result
]
logger.debug(
f"{self}: Received message types={msg_types}, ctx_id={ctx_id}, "
f"current_ctx={self.get_active_audio_context_id()}, available={self.audio_context_available(ctx_id) if ctx_id else 'N/A'}"
)
logger.debug(f"{self}: Received message types={msg_types}, ctx_id={ctx_id}")
# Check for errors
status = result.get("status", {})
@@ -948,9 +937,7 @@ class InworldTTSService(AudioContextTTSService):
# Handle "Context not found" error (code 5)
# This can happen when a keepalive message is sent but no context is available.
if error_code == 5 and "not found" in error_msg.lower():
logger.debug(
f"{self}: Context {ctx_id or self.get_active_audio_context_id()} not found."
)
logger.debug(f"{self}: Context {ctx_id} not found.")
continue
# For other errors, push error frame
@@ -961,17 +948,10 @@ class InworldTTSService(AudioContextTTSService):
await self.push_error(error_msg=str(msg["error"]))
continue
# Check if this message belongs to an available context.
# If the context isn't available but matches our current context ID,
# recreate it (handles race conditions during interruption recovery).
# If the context isn't available recreate it (handles race conditions during interruption recovery).
if ctx_id and not self.audio_context_available(ctx_id):
if self.get_active_audio_context_id() == ctx_id:
logger.trace(f"{self}: Recreating audio context for current context: {ctx_id}")
await self.create_audio_context(ctx_id)
else:
# This is a message from an old/closed context - skip it
logger.trace(f"{self}: Skipping message from unavailable context: {ctx_id}")
continue
logger.trace(f"{self}: Recreating audio context for current context: {ctx_id}")
await self.create_audio_context(ctx_id)
# Process audio chunk
audio_chunk = result.get("audioChunk", {})
@@ -979,8 +959,6 @@ class InworldTTSService(AudioContextTTSService):
if audio_b64:
logger.trace(f"{self}: Processing audio chunk for context {ctx_id}")
await self.stop_ttfb_metrics()
await self.start_word_timestamps()
audio = base64.b64decode(audio_b64)
if len(audio) > 44 and audio.startswith(b"RIFF"):
audio = audio[44:]
@@ -1012,12 +990,8 @@ class InworldTTSService(AudioContextTTSService):
if "contextClosed" in result:
logger.trace(f"{self}: Context closed on server: {ctx_id}")
await self.stop_ttfb_metrics()
# Only reset if this is our current context
if ctx_id == self.get_active_audio_context_id():
self.reset_active_audio_context()
if ctx_id and self.audio_context_available(ctx_id):
await self.remove_audio_context(ctx_id)
await self.add_word_timestamps([("TTSStoppedFrame", 0), ("Reset", 0)], ctx_id)
await self.remove_audio_context(ctx_id)
async def _keepalive_task_handler(self):
"""Send periodic keepalive messages to maintain WebSocket connection."""
@@ -1128,10 +1102,10 @@ class InworldTTSService(AudioContextTTSService):
await self._connect()
try:
if not self.has_active_audio_context():
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
await self.create_audio_context(context_id)
await self._send_context(context_id)
await self._send_text(context_id, text)

View File

@@ -20,8 +20,6 @@ from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -170,6 +168,8 @@ class KokoroTTSService(TTSService):
default_settings.apply_update(settings)
super().__init__(
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -212,9 +212,7 @@ class KokoroTTSService(TTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
stream = self._kokoro.create_stream(
text, voice=self._settings.voice, lang=self._settings.language, speed=1.0
@@ -238,4 +236,3 @@ class KokoroTTSService(TTSService):
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -143,6 +143,7 @@ class LmntTTSService(InterruptibleTTSService):
super().__init__(
push_stop_frames=True,
push_start_frame=True,
pause_frame_processing=True,
sample_rate=sample_rate,
settings=default_settings,
@@ -152,7 +153,6 @@ class LmntTTSService(InterruptibleTTSService):
self._api_key = api_key
self._output_format = "raw"
self._receive_task = None
self._context_id: Optional[str] = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -289,7 +289,6 @@ class LmntTTSService(InterruptibleTTSService):
except Exception as e:
await self.push_error(error_msg=f"Error disconnecting from LMNT: {e}", exception=e)
finally:
self._context_id = None
self._websocket = None
await self._call_event_handler("on_disconnected")
@@ -299,7 +298,7 @@ class LmntTTSService(InterruptibleTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis."""
if not self._websocket or self._websocket.state is State.CLOSED:
return
@@ -315,7 +314,7 @@ class LmntTTSService(InterruptibleTTSService):
audio=message,
sample_rate=self.sample_rate,
num_channels=1,
context_id=self._context_id,
context_id=self.get_active_audio_context_id(),
)
await self.push_frame(frame)
else:
@@ -347,11 +346,6 @@ class LmntTTSService(InterruptibleTTSService):
await self._connect()
try:
await self.start_ttfb_metrics()
# Store context_id for use in _receive_messages
self._context_id = context_id
yield TTSStartedFrame(context_id=context_id)
# Send text to LMNT
await self._get_websocket().send(json.dumps({"text": text}))
# Force synthesis

View File

@@ -23,8 +23,6 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -305,6 +303,8 @@ class MiniMaxHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -402,8 +402,6 @@ class MiniMaxHttpTTSService(TTSService):
payload["language_boost"] = self._settings.language_boost
try:
await self.start_ttfb_metrics()
async with self._session.post(
self._base_url, headers=headers, json=payload
) as response:
@@ -413,7 +411,6 @@ class MiniMaxHttpTTSService(TTSService):
return
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
# Process the streaming response
buffer = bytearray()
@@ -490,4 +487,3 @@ class MiniMaxHttpTTSService(TTSService):
yield ErrorFrame(error=f"Unknown error occurred: {e}", exception=e)
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -180,6 +180,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
aggregate_sentences=aggregate_sentences,
text_aggregation_mode=text_aggregation_mode,
push_stop_frames=True,
push_start_frame=True,
stop_frame_timeout_s=2.0,
sample_rate=sample_rate,
settings=default_settings,
@@ -188,12 +189,8 @@ class NeuphonicTTSService(InterruptibleTTSService):
self._api_key = api_key
self._url = url
self._cumulative_time = 0
self._receive_task = None
self._keepalive_task = None
self._context_id: Optional[str] = None
self._encoding = encoding
self._sampling_rate = sample_rate
@@ -252,7 +249,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
await super().cancel(frame)
await self._disconnect()
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis by sending stop command."""
if self._websocket:
msg = {"text": "<STOP>"}
@@ -358,7 +355,6 @@ class NeuphonicTTSService(InterruptibleTTSService):
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._context_id = None
self._websocket = None
await self._call_event_handler("on_disconnected")
@@ -372,7 +368,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
audio = base64.b64decode(msg["data"]["audio"])
frame = TTSAudioRawFrame(
audio, self.sample_rate, 1, context_id=self._context_id
audio, self.sample_rate, 1, context_id=self.get_active_audio_context_id()
)
await self.push_frame(frame)
@@ -415,12 +411,6 @@ class NeuphonicTTSService(InterruptibleTTSService):
await self._connect()
try:
await self.start_ttfb_metrics()
# Store context_id for use in _receive_messages
self._context_id = context_id
yield TTSStartedFrame(context_id=context_id)
self._cumulative_time = 0
await self._send_text(text)
await self.start_tts_usage_metrics(text)
except Exception as e:
@@ -523,6 +513,8 @@ class NeuphonicHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_stop_frames=True,
push_start_frame=True,
settings=default_settings,
**kwargs,
)
@@ -559,7 +551,7 @@ class NeuphonicHttpTTSService(TTSService):
"""
await super().start(frame)
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis.
Note:
@@ -633,8 +625,6 @@ class NeuphonicHttpTTSService(TTSService):
payload["voice_id"] = self._settings.voice
try:
await self.start_ttfb_metrics()
async with self._session.post(url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
@@ -643,7 +633,6 @@ class NeuphonicHttpTTSService(TTSService):
return
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
# Process SSE stream line by line
async for line in response.content:
@@ -681,4 +670,3 @@ class NeuphonicHttpTTSService(TTSService):
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -28,8 +28,6 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -145,6 +143,8 @@ class NvidiaTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -271,9 +271,6 @@ class NvidiaTTSService(TTSService):
assert self._service is not None, "TTS service not initialized"
assert self._config is not None, "Synthesis configuration not created"
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
logger.debug(f"{self}: Generating TTS [{text}]")
responses = await asyncio.to_thread(read_audio_responses)
@@ -289,7 +286,6 @@ class NvidiaTTSService(TTSService):
yield frame
await self.start_tts_usage_metrics(text)
yield TTSStoppedFrame(context_id=context_id)
except asyncio.TimeoutError as e:
logger.error(f"{self} timeout waiting for audio response")
yield ErrorFrame(error=f"{self} error: {e}")

View File

@@ -22,8 +22,6 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -194,6 +192,8 @@ class OpenAITTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -234,8 +234,6 @@ class OpenAITTSService(TTSService):
"""
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
# Setup API parameters
create_params = {
"input": text,
@@ -267,12 +265,10 @@ class OpenAITTSService(TTSService):
CHUNK_SIZE = self.chunk_size
yield TTSStartedFrame(context_id=context_id)
async for chunk in r.iter_bytes(CHUNK_SIZE):
if len(chunk) > 0:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1, context_id=context_id)
yield frame
yield TTSStoppedFrame(context_id=context_id)
except BadRequestError as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")

View File

@@ -17,8 +17,6 @@ from loguru import logger
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -91,6 +89,8 @@ class PiperTTSService(TTSService):
default_settings.apply_update(settings)
super().__init__(
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -159,12 +159,8 @@ class PiperTTSService(TTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
async for frame in self._stream_audio_frames_from_iterator(
async_iterator(self._voice.synthesize(text)),
in_sample_rate=self._voice.config.sample_rate,
@@ -178,7 +174,6 @@ class PiperTTSService(TTSService):
finally:
logger.debug(f"{self}: Finished TTS [{text}]")
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)
# This assumes a running TTS service running:
@@ -244,6 +239,8 @@ class PiperHttpTTSService(TTSService):
default_settings.apply_update(settings)
super().__init__(
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -279,8 +276,6 @@ class PiperHttpTTSService(TTSService):
"Content-Type": "application/json",
}
try:
await self.start_ttfb_metrics()
data = {
"text": text,
"voice": self._settings.voice,
@@ -296,8 +291,6 @@ class PiperHttpTTSService(TTSService):
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
CHUNK_SIZE = self.chunk_size
async for frame in self._stream_audio_frames_from_iterator(
@@ -311,4 +304,3 @@ class PiperHttpTTSService(TTSService):
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -24,7 +24,7 @@ from pipecat.frames.frames import (
TTSStoppedFrame,
)
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
from pipecat.services.tts_service import AudioContextTTSService
from pipecat.services.tts_service import WebsocketTTSService
from pipecat.utils.tracing.service_decorators import traced_tts
try:
@@ -43,7 +43,7 @@ class ResembleAITTSSettings(TTSSettings):
pass
class ResembleAITTSService(AudioContextTTSService):
class ResembleAITTSService(WebsocketTTSService):
"""Resemble AI TTS service with WebSocket streaming and word timestamps.
Provides text-to-speech using Resemble AI's streaming WebSocket API.
@@ -103,7 +103,6 @@ class ResembleAITTSService(AudioContextTTSService):
super().__init__(
sample_rate=sample_rate,
reuse_context_id_within_turn=False,
supports_word_timestamps=True,
settings=default_settings,
**kwargs,
)
@@ -268,7 +267,7 @@ class ResembleAITTSService(AudioContextTTSService):
"""
pass
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio and finalize the current context."""
logger.trace(f"{self}: flushing audio")
# For Resemble AI, we just wait for the audio_end message
@@ -297,9 +296,6 @@ class ResembleAITTSService(AudioContextTTSService):
continue
if msg_type == "audio":
await self.stop_ttfb_metrics()
await self.start_word_timestamps()
# Decode base64 audio content
audio_content = msg.get("audio_content", "")
if not audio_content:
@@ -447,14 +443,14 @@ class ResembleAITTSService(AudioContextTTSService):
if not self._websocket or self._websocket.state is State.CLOSED:
await self._connect()
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
# Map request_id to context_id for tracking
self._request_id_to_context[self._request_id_counter] = context_id
await self.create_audio_context(context_id)
msg = self._build_msg(text=text)
try:

View File

@@ -33,10 +33,10 @@ from pipecat.frames.frames import (
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import (
AudioContextTTSService,
InterruptibleTTSService,
TextAggregationMode,
TTSService,
WebsocketTTSService,
)
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
@@ -123,7 +123,7 @@ class RimeNonJsonTTSSettings(TTSSettings):
_aliases: ClassVar[Dict[str, str]] = {"speaker": "voice"}
class RimeTTSService(AudioContextTTSService):
class RimeTTSService(WebsocketTTSService):
"""Text-to-Speech service using Rime's websocket API.
Uses Rime's websocket JSON API to convert text to speech with word-level timing
@@ -276,7 +276,6 @@ class RimeTTSService(AudioContextTTSService):
push_text_frames=False,
push_stop_frames=True,
pause_frame_processing=True,
supports_word_timestamps=True,
append_trailing_space=True,
sample_rate=sample_rate,
settings=default_settings,
@@ -408,9 +407,9 @@ class RimeTTSService(AudioContextTTSService):
return changed
def _build_msg(self, text: str = "") -> dict:
def _build_msg(self, text: str = "", context_id: str = "") -> dict:
"""Build JSON message for Rime API."""
msg = {"text": text, "contextId": self.get_active_audio_context_id()}
msg = {"text": text, "contextId": context_id}
if self._extra_msg_fields:
msg |= self._extra_msg_fields
self._extra_msg_fields = {}
@@ -557,15 +556,14 @@ class RimeTTSService(AudioContextTTSService):
return word_pairs
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis."""
context_id = self.get_active_audio_context_id()
if not context_id or not self._websocket:
flush_id = context_id or self.get_active_audio_context_id()
if not flush_id or not self._websocket:
return
logger.trace(f"{self}: flushing audio")
await self._get_websocket().send(json.dumps({"operation": "flush"}))
self.reset_active_audio_context()
async def _receive_messages(self):
"""Process incoming websocket messages."""
@@ -578,8 +576,6 @@ class RimeTTSService(AudioContextTTSService):
context_id = msg["contextId"]
if msg["type"] == "chunk":
# Process audio chunk
await self.stop_ttfb_metrics()
await self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["data"]),
sample_rate=self.sample_rate,
@@ -638,13 +634,13 @@ class RimeTTSService(AudioContextTTSService):
await self._connect()
try:
if not self.has_active_audio_context():
if not self.audio_context_available(context_id):
await self.create_audio_context(context_id)
await self.start_ttfb_metrics()
yield TTSStartedFrame(context_id=context_id)
self._cumulative_time = 0
await self.create_audio_context(context_id)
msg = self._build_msg(text=text)
msg = self._build_msg(text=text, context_id=context_id)
await self._get_websocket().send(json.dumps(msg))
await self.start_tts_usage_metrics(text)
except Exception as e:
@@ -773,6 +769,8 @@ class RimeHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_stop_frames=True,
push_start_frame=True,
settings=default_settings,
**kwargs,
)
@@ -844,8 +842,6 @@ class RimeHttpTTSService(TTSService):
need_to_strip_wav_header = False
try:
await self.start_ttfb_metrics()
async with self._session.post(
self._base_url, json=payload, headers=headers
) as response:
@@ -856,8 +852,6 @@ class RimeHttpTTSService(TTSService):
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
CHUNK_SIZE = self.chunk_size
async for frame in self._stream_audio_frames_from_iterator(
@@ -872,7 +866,6 @@ class RimeHttpTTSService(TTSService):
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)
class RimeNonJsonTTSService(InterruptibleTTSService):
@@ -1005,6 +998,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
aggregate_sentences=aggregate_sentences,
text_aggregation_mode=text_aggregation_mode,
push_stop_frames=True,
push_start_frame=True,
pause_frame_processing=True,
append_trailing_space=True,
settings=default_settings,
@@ -1022,7 +1016,6 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
self._settings.extra.update(params.extra)
self._receive_task = None
self._context_id: Optional[str] = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -1138,7 +1131,6 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._context_id = None
self._websocket = None
await self._call_event_handler("on_disconnected")
@@ -1148,7 +1140,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
return self._websocket
raise Exception("Websocket not connected")
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis."""
if not self._websocket:
return
@@ -1168,7 +1160,7 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
audio=message,
sample_rate=self.sample_rate,
num_channels=1,
context_id=self._context_id,
context_id=self.get_active_audio_context_id(),
)
await self.push_frame(frame)
except Exception as e:
@@ -1190,10 +1182,6 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
if not self._websocket or self._websocket.state is State.CLOSED:
await self._connect()
try:
await self.start_ttfb_metrics()
# Store context_id for use in _receive_messages
self._context_id = context_id
yield TTSStartedFrame(context_id=context_id)
# Send bare text (not JSON)
await self._get_websocket().send(text)
await self.start_tts_usage_metrics(text)

View File

@@ -524,6 +524,8 @@ class SarvamHttpTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_stop_frames=True,
push_start_frame=True,
settings=default_settings,
**kwargs,
)
@@ -573,8 +575,6 @@ class SarvamHttpTTSService(TTSService):
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_ttfb_metrics()
# Build payload with common parameters
payload = {
"text": text,
@@ -606,8 +606,6 @@ class SarvamHttpTTSService(TTSService):
url = f"{self._base_url}/text-to-speech"
yield TTSStartedFrame(context_id=context_id)
async with self._session.post(url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
@@ -645,7 +643,6 @@ class SarvamHttpTTSService(TTSService):
yield ErrorFrame(error=f"Error generating TTS: {e}", exception=e)
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame(context_id=context_id)
class SarvamTTSService(InterruptibleTTSService):
@@ -951,6 +948,7 @@ class SarvamTTSService(InterruptibleTTSService):
push_text_frames=True,
pause_frame_processing=True,
push_stop_frames=True,
push_start_frame=True,
sample_rate=sample_rate,
settings=default_settings,
**kwargs,
@@ -967,7 +965,6 @@ class SarvamTTSService(InterruptibleTTSService):
self._receive_task = None
self._keepalive_task = None
self._context_id: Optional[str] = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -1018,7 +1015,7 @@ class SarvamTTSService(InterruptibleTTSService):
await super().cancel(frame)
await self._disconnect()
async def flush_audio(self):
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis by sending flush command."""
try:
if self._websocket:
@@ -1151,7 +1148,6 @@ class SarvamTTSService(InterruptibleTTSService):
except Exception as e:
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
finally:
self._context_id = None
self._websocket = None
await self._call_event_handler("on_disconnected")
@@ -1170,7 +1166,7 @@ class SarvamTTSService(InterruptibleTTSService):
await self.stop_ttfb_metrics()
audio = base64.b64decode(msg["data"]["audio"])
frame = TTSAudioRawFrame(
audio, self.sample_rate, 1, context_id=self._context_id
audio, self.sample_rate, 1, context_id=self.get_active_audio_context_id()
)
await self.push_frame(frame)
elif msg.get("type") == "error":
@@ -1224,10 +1220,6 @@ class SarvamTTSService(InterruptibleTTSService):
await self._connect()
try:
await self.start_ttfb_metrics()
# Store context_id for use in _receive_messages
self._context_id = context_id
yield TTSStartedFrame(context_id=context_id)
await self._send_text(text)
await self.start_tts_usage_metrics(text)
except Exception as e:

View File

@@ -19,8 +19,6 @@ from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -135,6 +133,8 @@ class SpeechmaticsTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -185,9 +185,6 @@ class SpeechmaticsTTSService(TTSService):
url = _get_endpoint_url(self._base_url, self._settings.voice, self.sample_rate)
try:
# Start TTS TTFB metrics
await self.start_ttfb_metrics()
# Track attempt
attempt = 0
@@ -238,9 +235,6 @@ class SpeechmaticsTTSService(TTSService):
# Update Pipecat metrics
await self.start_tts_usage_metrics(text)
# Emit the TTS started frame
yield TTSStartedFrame(context_id=context_id)
# Process the response in streaming chunks
first_chunk = True
buffer = b""
@@ -277,8 +271,7 @@ class SpeechmaticsTTSService(TTSService):
except Exception as e:
yield ErrorFrame(error=f"Error generating TTS: {e}")
finally:
# Emit the TTS stopped frame
yield TTSStoppedFrame(context_id=context_id)
await self.stop_ttfb_metrics()
def _get_endpoint_url(base_url: str, voice: str, sample_rate: int) -> str:

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,7 @@ text-to-speech synthesis using local Docker deployment.
"""
from dataclasses import dataclass
from typing import AsyncGenerator, Dict, Optional
from typing import Any, AsyncGenerator, Dict, Optional
import aiohttp
from loguru import logger
@@ -22,8 +22,6 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
@@ -132,6 +130,8 @@ class XTTSService(TTSService):
super().__init__(
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
@@ -213,8 +213,6 @@ class XTTSService(TTSService):
"stream_chunk_size": 20,
}
await self.start_ttfb_metrics()
async with self._aiohttp_session.post(url, json=payload) as r:
if r.status != 200:
text = await r.text()
@@ -223,8 +221,6 @@ class XTTSService(TTSService):
await self.start_tts_usage_metrics(text)
yield TTSStartedFrame(context_id=context_id)
CHUNK_SIZE = self.chunk_size
buffer = bytearray()
@@ -262,5 +258,3 @@ class XTTSService(TTSService):
resampled_audio, self.sample_rate, 1, context_id=context_id
)
yield frame
yield TTSStoppedFrame(context_id=context_id)

View File

@@ -199,13 +199,13 @@ async def run_test(
#
# Down frames
#
received_down_frames: Sequence[Frame] = []
if expected_down_frames is not None:
while not received_down.empty():
frame = await received_down.get()
if not isinstance(frame, EndFrame) or not send_end_frame:
received_down_frames.append(frame)
received_down_frames: list[Frame] = []
while not received_down.empty():
frame = await received_down.get()
if not isinstance(frame, EndFrame) or not send_end_frame:
received_down_frames.append(frame)
if expected_down_frames is not None:
down_frames_printed = "["
for frame in received_down_frames:
down_frames_printed += f"{frame.__class__.__name__}, "
@@ -225,12 +225,12 @@ async def run_test(
#
# Up frames
#
received_up_frames: Sequence[Frame] = []
if expected_up_frames is not None:
while not received_up.empty():
frame = await received_up.get()
received_up_frames.append(frame)
received_up_frames: list[Frame] = []
while not received_up.empty():
frame = await received_up.get()
received_up_frames.append(frame)
if expected_up_frames is not None:
print("received UP frames =", received_up_frames)
print("expected UP frames =", expected_up_frames)

View File

@@ -44,12 +44,15 @@ from pipecat.frames.frames import (
StartFrame,
SystemFrame,
TTSAudioRawFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.transports.base_transport import TransportParams
from pipecat.utils.time import nanoseconds_to_seconds
BOT_VAD_STOP_SECS = 0.35
# Only used as a fallback
BOT_VAD_STOP_FALLBACK_SECS = 3
class BaseOutputTransport(FrameProcessor):
@@ -354,6 +357,8 @@ class BaseOutputTransport(FrameProcessor):
await sender.handle_sync_frame(frame)
elif isinstance(frame, MixerControlFrame):
await sender.handle_mixer_control_frame(frame)
elif isinstance(frame, TTSStoppedFrame):
await sender.handle_sync_frame(frame)
elif frame.pts:
await sender.handle_timed_frame(frame)
else:
@@ -412,6 +417,8 @@ class BaseOutputTransport(FrameProcessor):
# Indicates if the bot is currently speaking.
self._bot_speaking = False
# Indicates if TTS audio has been received since the last stop.
self._tts_audio_received = False
# Last time a BotSpeakingFrame was pushed.
self._bot_speaking_frame_time = 0
# How often a BotSpeakingFrame should be pushed (value should be
@@ -639,6 +646,7 @@ class BaseOutputTransport(FrameProcessor):
return
self._bot_speaking = False
self._tts_audio_received = False
# Clean audio buffer (there could be tiny left overs if not multiple
# to our output chunk size).
@@ -682,6 +690,9 @@ class BaseOutputTransport(FrameProcessor):
async def _handle_bot_speech(self, frame: Frame):
# TTS case.
if isinstance(frame, TTSAudioRawFrame):
# We will only trigger bot stopped speaking based on the TTSStoppedFrame,
# if we have received audio from TTS
self._tts_audio_received = True
await self._bot_currently_speaking()
# Speech stream case.
elif isinstance(frame, SpeechOutputAudioRawFrame):
@@ -703,6 +714,12 @@ class BaseOutputTransport(FrameProcessor):
await self._transport.send_message(frame)
elif isinstance(frame, OutputDTMFFrame):
await self._transport.write_dtmf(frame)
elif isinstance(frame, TTSStoppedFrame):
# We will only trigger bot stopped speaking based on the TTSStoppedFrame,
# if we have received audio from TTS
if self._tts_audio_received:
logger.debug("Bot stopped speaking based on TTSStoppedFrame")
await self._bot_stopped_speaking()
else:
await self._transport.write_transport_frame(frame)
@@ -722,7 +739,7 @@ class BaseOutputTransport(FrameProcessor):
yield frame
self._audio_queue.task_done()
except asyncio.TimeoutError:
# Notify the bot stopped speaking upstream if necessary.
# Fallback: notify the bot stopped speaking upstream if necessary based on timeout.
await self._bot_stopped_speaking()
async def with_mixer(vad_stop_secs: float) -> AsyncGenerator[Frame, None]:
@@ -737,7 +754,7 @@ class BaseOutputTransport(FrameProcessor):
yield frame
self._audio_queue.task_done()
except asyncio.QueueEmpty:
# Notify the bot stopped speaking upstream if necessary.
# Fallback: notify the bot stopped speaking upstream if necessary based on timeout.
diff_time = time.time() - last_frame_time
if diff_time > vad_stop_secs:
await self._bot_stopped_speaking()
@@ -755,9 +772,9 @@ class BaseOutputTransport(FrameProcessor):
await asyncio.sleep(0)
if self._mixer:
return with_mixer(BOT_VAD_STOP_SECS)
return with_mixer(BOT_VAD_STOP_FALLBACK_SECS)
else:
return without_mixer(BOT_VAD_STOP_SECS)
return without_mixer(BOT_VAD_STOP_FALLBACK_SECS)
async def _send_silence(self, secs: int):
if secs <= 0:

View File

@@ -77,28 +77,36 @@ async def test_run_piper_tts_success(aiohttp_client):
TTSSpeakFrame(text="Hello world."),
]
expected_returned_frames = [
AggregatedTextFrame,
TTSStartedFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
TTSStoppedFrame,
TTSTextFrame,
]
frames_received = await run_test(
tts_service,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
)
down_frames = frames_received[0]
frame_types = [type(f) for f in down_frames]
# Verify key frames are present
assert AggregatedTextFrame in frame_types
assert TTSStartedFrame in frame_types
assert TTSStoppedFrame in frame_types
assert TTSTextFrame in frame_types
# Verify ordering: Started → audio → Stopped → Text
started_idx = frame_types.index(TTSStartedFrame)
stopped_idx = frame_types.index(TTSStoppedFrame)
text_idx = frame_types.index(TTSTextFrame)
assert started_idx < text_idx < stopped_idx, (
"Expected: TTSStartedFrame < TTSTextFrame < TTSStoppedFrame"
)
# Frames between Started and Stopped must all be audio or text
for i in range(started_idx + 1, stopped_idx):
assert frame_types[i] in (TTSAudioRawFrame, TTSTextFrame), (
f"Unexpected frame type between Started and Stopped: {frame_types[i]}"
)
# All audio frames have correct sample rate
audio_frames = [f for f in down_frames if isinstance(f, TTSAudioRawFrame)]
assert len(audio_frames) >= 1, "Expected at least one audio frame"
for a_frame in audio_frames:
assert a_frame.sample_rate == 24000, "Sample rate should match the default (24000)"
@@ -128,7 +136,7 @@ async def test_run_piper_tts_error(aiohttp_client):
TTSSpeakFrame(text="Error case.", append_to_context=False),
]
expected_down_frames = [AggregatedTextFrame, TTSStoppedFrame, TTSTextFrame]
expected_down_frames = [AggregatedTextFrame, TTSStartedFrame, TTSTextFrame, TTSStoppedFrame]
expected_up_frames = [ErrorFrame]