Fix TTFB metric and add multi-context WebSocket support for Async TTS

This commit is contained in:
Ashot
2025-12-23 16:35:45 +04:00
parent 86ed485711
commit 9cdbc56be3
2 changed files with 259 additions and 73 deletions

View File

@@ -6,6 +6,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
<!-- towncrier release notes start -->
## [Unreleased]
### Changed
- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management.
### Fixed
- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`.
## [0.0.99] - 2026-01-13

View File

@@ -9,8 +9,9 @@
import asyncio
import base64
import json
from typing import AsyncGenerator, Optional
from typing import AsyncGenerator, Optional, Dict
import uuid
import aiohttp
from loguru import logger
from pydantic import BaseModel
@@ -27,7 +28,7 @@ from pipecat.frames.frames import (
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import InterruptibleTTSService, TTSService
from pipecat.services.tts_service import WebsocketTTSService, TTSService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -72,7 +73,7 @@ def language_to_async_language(language: Language) -> Optional[str]:
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
class AsyncAITTSService(InterruptibleTTSService):
class AsyncAITTSService(WebsocketTTSService):
"""Async TTS service with WebSocket streaming.
Provides text-to-speech using Async's streaming WebSocket API.
@@ -126,6 +127,10 @@ class AsyncAITTSService(InterruptibleTTSService):
**kwargs,
)
self._contexts: Dict[str, asyncio.Queue] = {}
self._audio_context_task = None
self._context_id = None
params = params or AsyncAITTSService.InputParams()
self._api_key = api_key
@@ -148,6 +153,148 @@ class AsyncAITTSService(InterruptibleTTSService):
self._receive_task = None
self._keepalive_task = None
self._started = False
async def create_audio_context(self, context_id: str):
    """Register a new audio context and schedule it for ordered playback.

    The context ID is appended to the queue of contexts awaiting processing,
    and a dedicated frame queue is created for the audio belonging to it.

    Args:
        context_id: Unique identifier for the audio context.
    """
    frame_queue: asyncio.Queue = asyncio.Queue()
    await self._contexts_queue.put(context_id)
    self._contexts[context_id] = frame_queue
    logger.trace(f"{self} created audio context {context_id}")
async def append_to_audio_context(self, context_id: str, frame: TTSAudioRawFrame):
    """Queue an audio frame on a registered context.

    When the context is not registered the frame is dropped and a warning
    is logged.

    Args:
        context_id: The context to append audio to.
        frame: The audio frame to append.
    """
    if not self.audio_context_available(context_id):
        logger.warning(f"{self} unable to append audio to context {context_id}")
        return

    logger.trace(f"{self} appending audio {frame} to audio context {context_id}")
    await self._contexts[context_id].put(frame)
async def remove_audio_context(self, context_id: str):
    """Mark a registered audio context as finished.

    Args:
        context_id: The context to remove.
    """
    if not self.audio_context_available(context_id):
        logger.warning(f"{self} unable to remove context {context_id}")
        return

    # Removal is deferred: a None sentinel is queued behind any pending
    # audio, and the context is dropped once the consumer reaches it.
    logger.trace(f"{self} marking audio context {context_id} for deletion")
    await self._contexts[context_id].put(None)
def audio_context_available(self, context_id: str) -> bool:
    """Report whether an audio context with this ID is currently registered.

    Args:
        context_id: The context ID to look up.

    Returns:
        True when the ID is a key of the registered contexts, False otherwise.
    """
    registered = self._contexts
    return context_id in registered
async def start(self, frame: StartFrame):
    """Start the Async TTS service.

    Runs the parent start logic, creates the audio context task, copies the
    service sample rate into the output format settings and then connects.

    Args:
        frame: The start frame containing initialization parameters.
    """
    await super().start(frame)
    # Create the task that drains audio contexts before connecting.
    self._create_audio_context_task()
    # Request audio from the API at the sample rate this service uses.
    self._settings["output_format"]["sample_rate"] = self.sample_rate
    await self._connect()
async def stop(self, frame: EndFrame):
    """Stop the Async TTS service.

    If the audio context task is running, it is asked to finish and awaited
    before the connection is closed.

    Args:
        frame: The end frame.
    """
    await super().stop(frame)
    if self._audio_context_task:
        # Indicate no more audio contexts are available. This will end the
        # task cleanly after all contexts have been processed.
        await self._contexts_queue.put(None)
        await self._audio_context_task
        self._audio_context_task = None
    await self._disconnect()
async def cancel(self, frame: CancelFrame):
    """Cancel the Async TTS service.

    The audio context task is cancelled outright rather than asked to
    finish via the queue sentinel, and then the connection is closed.

    Args:
        frame: The cancel frame.
    """
    await super().cancel(frame)
    await self._stop_audio_context_task()
    await self._disconnect()
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
    """Handle an interruption by restarting the audio context task.

    Args:
        frame: The interruption frame.
        direction: The direction the frame is travelling in.
    """
    # NOTE(review): a second `_handle_interruption` is defined further down in
    # this change. If both end up in the same class body, the later one
    # replaces this one and this reset never runs -- confirm and merge.
    await super()._handle_interruption(frame, direction)
    # Cancel the running task, then create a new one with fresh queues.
    await self._stop_audio_context_task()
    self._create_audio_context_task()
def _create_audio_context_task(self):
    """Create the audio context task and its queues unless one already exists."""
    if self._audio_context_task:
        return

    # Fresh state for the new task: the ordered queue of context IDs and the
    # per-context frame queues.
    self._contexts_queue = asyncio.Queue()
    self._contexts = {}
    self._audio_context_task = self.create_task(self._audio_context_task_handler())
async def _stop_audio_context_task(self):
    """Cancel the audio context task, if any, and clear the reference to it."""
    if not self._audio_context_task:
        return

    await self.cancel_task(self._audio_context_task)
    self._audio_context_task = None
async def _audio_context_task_handler(self):
    """Process audio contexts one at a time, in the order they were queued.

    The loop ends when a falsy context ID (the None sentinel) is dequeued.
    """
    while True:
        context_id = await self._contexts_queue.get()

        if context_id:
            # Push this context's audio until its end marker (None) or a
            # timeout is reached.
            await self._handle_audio_context(context_id)

            # The context is fully processed, so it can be dropped.
            del self._contexts[context_id]

            # Append some silence between sentences.
            gap = TTSAudioRawFrame(
                audio=b"\x00" * self.sample_rate,
                sample_rate=self.sample_rate,
                num_channels=1,
            )
            await self.push_frame(gap)

        self._contexts_queue.task_done()

        if not context_id:
            break
async def _handle_audio_context(self, context_id: str):
    """Push the frames queued for one audio context until it is finished.

    A context is finished when its None end marker is dequeued, or when no
    item arrives within the timeout.

    Args:
        context_id: The context whose queued frames should be pushed.
    """
    # If no audio arrives within this many seconds, the context is
    # considered finished.
    AUDIO_CONTEXT_TIMEOUT = 3.0
    pending = self._contexts[context_id]

    while True:
        try:
            item = await asyncio.wait_for(pending.get(), timeout=AUDIO_CONTEXT_TIMEOUT)
            if item:
                await self.push_frame(item)
            if item is None:
                break
        except asyncio.TimeoutError:
            # No audio arrived in time, so treat the context as finished.
            logger.trace(f"{self} time out on audio context {context_id}")
            break
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -168,38 +315,10 @@ class AsyncAITTSService(InterruptibleTTSService):
"""
return language_to_async_language(language)
def _build_msg(self, text: str = "", force: bool = False) -> str:
msg = {"transcript": text, "force": force}
def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str:
msg = {"transcript": text, "context_id": context_id, "force": force}
return json.dumps(msg)
async def start(self, frame: StartFrame):
"""Start the Async TTS service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
self._settings["output_format"]["sample_rate"] = self.sample_rate
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the Async TTS service.
Args:
frame: The end frame.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the Async TTS service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
await self._disconnect()
async def _connect(self):
await super()._connect()
@@ -253,11 +372,16 @@ class AsyncAITTSService(InterruptibleTTSService):
if self._websocket:
logger.debug("Disconnecting from Async")
# Close all contexts and the socket
if self._context_id:
await self._websocket.send(json.dumps({"terminate": True}))
await self._websocket.close()
logger.debug("Disconnected from Async")
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
logger.error(f"{self} error closing websocket: {e}")
finally:
self._websocket = None
self._context_id = None
self._started = False
await self._call_event_handler("on_disconnected")
@@ -268,10 +392,10 @@ class AsyncAITTSService(InterruptibleTTSService):
async def flush_audio(self):
"""Flush any pending audio."""
if not self._websocket:
if not self._context_id or not self._websocket:
return
logger.trace(f"{self}: flushing audio")
msg = self._build_msg(text=" ", force=True)
msg = self._build_msg(text=" ", context_id=self._context_id, force=True)
await self._websocket.send(msg)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
@@ -291,35 +415,70 @@ class AsyncAITTSService(InterruptibleTTSService):
if not msg:
continue
elif msg.get("audio"):
received_ctx_id = msg.get("context_id")
# Handle final messages first, regardless of context availability
# At the moment, this message is received AFTER the close_context message is
# sent, so it doesn't serve any functional purpose. For now, we'll just log it.
if msg.get("final") is True:
logger.trace(f"Received final message for context {received_ctx_id}")
continue
# Check if this message belongs to the current context.
if not self.audio_context_available(received_ctx_id):
if self._context_id == received_ctx_id:
logger.debug(
f"Received a delayed message, recreating the context: {self._context_id}"
)
await self.create_audio_context(self._context_id)
else:
# This can happen if a message is received _after_ we have closed a context
# due to user interruption but _before_ the `isFinal` message for the context
# is received.
logger.debug(f"Ignoring message from unavailable context: {received_ctx_id}")
continue
if msg.get("audio"):
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["audio"]),
sample_rate=self.sample_rate,
num_channels=1,
)
await self.push_frame(frame)
elif msg.get("error_code"):
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(error_msg=f"Error: {msg['message']}")
else:
await self.push_error(error_msg=f"Unknown message type: {msg}")
audio = base64.b64decode(msg["audio"])
frame = TTSAudioRawFrame(audio, self.sample_rate, 1)
await self.append_to_audio_context(received_ctx_id, frame)
async def _keepalive_task_handler(self):
"""Send periodic keepalive messages to maintain WebSocket connection."""
KEEPALIVE_SLEEP = 3
KEEPALIVE_SLEEP = 10
while True:
await asyncio.sleep(KEEPALIVE_SLEEP)
try:
if self._websocket and self._websocket.state is State.OPEN:
keepalive_message = {"transcript": " "}
logger.trace("Sending keepalive message")
if self._context_id:
keepalive_message = {"transcript": " ", "context_id": self._context_id,}
logger.trace("Sending keepalive message")
else:
# It's possible to have a user interruption which clears the context
# without generating a new TTS response. In this case, we'll just send
# an empty message to keep the connection alive.
keepalive_message = {"transcript": " "}
logger.trace("Sending keepalive without context")
await self._websocket.send(json.dumps(keepalive_message))
except websockets.ConnectionClosed as e:
logger.warning(f"{self} keepalive error: {e}")
break
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
    """Handle an interruption: drop queued audio and close the current context.

    This is the last `_handle_interruption` defined for the class, so it is
    the one that runs. It therefore also restarts the audio context task;
    otherwise audio already queued for the interrupted context would keep
    being pushed downstream.

    Args:
        frame: The interruption frame.
        direction: The direction the frame is travelling in.
    """
    await super()._handle_interruption(frame, direction)

    # Discard audio buffered for the interrupted context(s) and start over
    # with empty queues.
    await self._stop_audio_context_task()
    self._create_audio_context_task()

    # Close the current context when interrupted without closing the websocket
    if self._context_id and self._websocket:
        try:
            await self._websocket.send(
                json.dumps(
                    {"context_id": self._context_id, "close_context": True, "transcript": ""}
                )
            )
        except Exception as e:
            # Best effort: a failed close must not break interruption handling.
            logger.error(f"Error closing context on interruption: {e}")
        # The next run_tts call allocates a new context and restarts metrics.
        self._context_id = None
        self._started = False
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Async API websocket endpoint.
@@ -336,26 +495,35 @@ class AsyncAITTSService(InterruptibleTTSService):
if not self._websocket or self._websocket.state is State.CLOSED:
await self._connect()
if not self._started:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
self._started = True
try:
if not self._started:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
self._started = True
msg = self._build_msg(text=text, force=True)
if not self._context_id:
self._context_id = str(uuid.uuid4())
if not self.audio_context_available(self._context_id):
await self.create_audio_context(self._context_id)
try:
await self._get_websocket().send(msg)
await self.start_tts_usage_metrics(text)
msg = self._build_msg(text=" ", context_id=self._context_id)
await self._get_websocket().send(msg)
msg = self._build_msg(text=text, force=True, context_id=self._context_id)
await self._get_websocket().send(msg)
await self.start_tts_usage_metrics(text)
else:
if self._websocket and self._context_id:
msg = self._build_msg(text=text, force=True, context_id=self._context_id)
await self._get_websocket().send(msg)
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
logger.error(f"{self} error sending message: {e}")
yield TTSStoppedFrame()
await self._disconnect()
await self._connect()
self._started = False
return
yield None
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
logger.error(f"{self} exception: {e}")
class AsyncAIHttpTTSService(TTSService):
"""HTTP-based Async TTS service.
@@ -466,9 +634,9 @@ class AsyncAIHttpTTSService(TTSService):
"""
logger.debug(f"{self}: Generating TTS [{text}]")
first_byte_seen = False
try:
voice_config = {"mode": "id", "id": self._voice_id}
await self.start_ttfb_metrics()
payload = {
"model_id": self._model_name,
"transcript": text,
@@ -476,7 +644,6 @@ class AsyncAIHttpTTSService(TTSService):
"output_format": self._settings["output_format"],
"language": self._settings["language"],
}
yield TTSStartedFrame()
headers = {
"version": self._api_version,
"x-api-key": self._api_key,
@@ -484,26 +651,36 @@ class AsyncAIHttpTTSService(TTSService):
}
url = f"{self._base_url}/text_to_speech/streaming"
yield TTSStartedFrame()
await self.start_ttfb_metrics()
async with self._session.post(url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
await self.push_error(error_msg=f"Async API error: {error_text}")
raise Exception(f"Async API returned status {response.status}: {error_text}")
audio_data = await response.read()
# Read streaming bytes; stop TTFB on the *first* received chunk
buffer = bytearray()
async for chunk in response.content.iter_chunked(64 * 1024):
if not chunk:
continue
if not first_byte_seen:
first_byte_seen = True
await self.stop_ttfb_metrics()
await self.start_tts_usage_metrics(text)
await self.start_tts_usage_metrics(text)
buffer.extend(chunk)
audio_data = bytes(buffer)
frame = TTSAudioRawFrame(
yield TTSAudioRawFrame(
audio=audio_data,
sample_rate=self.sample_rate,
num_channels=1,
)
yield frame
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
await self.stop_ttfb_metrics()
if not first_byte_seen:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame()