Fix TTFB metric and add multi-context WebSocket support for Async TTS
This commit is contained in:
@@ -6,6 +6,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
<!-- towncrier release notes start -->
|
||||
## [Unreleased]
|
||||
|
||||
### Changed
|
||||
|
||||
- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`.
|
||||
|
||||
## [0.0.99] - 2026-01-13
|
||||
|
||||
|
||||
@@ -9,8 +9,9 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from typing import AsyncGenerator, Optional
|
||||
from typing import AsyncGenerator, Optional, Dict
|
||||
|
||||
import uuid
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
@@ -27,7 +28,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import InterruptibleTTSService, TTSService
|
||||
from pipecat.services.tts_service import WebsocketTTSService, TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -72,7 +73,7 @@ def language_to_async_language(language: Language) -> Optional[str]:
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
|
||||
|
||||
|
||||
class AsyncAITTSService(InterruptibleTTSService):
|
||||
class AsyncAITTSService(WebsocketTTSService):
|
||||
"""Async TTS service with WebSocket streaming.
|
||||
|
||||
Provides text-to-speech using Async's streaming WebSocket API.
|
||||
@@ -126,6 +127,10 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._contexts: Dict[str, asyncio.Queue] = {}
|
||||
self._audio_context_task = None
|
||||
self._context_id = None
|
||||
|
||||
params = params or AsyncAITTSService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
@@ -148,6 +153,148 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
self._receive_task = None
|
||||
self._keepalive_task = None
|
||||
self._started = False
|
||||
|
||||
async def create_audio_context(self, context_id: str):
|
||||
"""Create a new audio context for grouping related audio.
|
||||
|
||||
Args:
|
||||
context_id: Unique identifier for the audio context.
|
||||
"""
|
||||
await self._contexts_queue.put(context_id)
|
||||
self._contexts[context_id] = asyncio.Queue()
|
||||
logger.trace(f"{self} created audio context {context_id}")
|
||||
|
||||
async def append_to_audio_context(self, context_id: str, frame: TTSAudioRawFrame):
|
||||
"""Append audio to an existing context.
|
||||
|
||||
Args:
|
||||
context_id: The context to append audio to.
|
||||
frame: The audio frame to append.
|
||||
"""
|
||||
if self.audio_context_available(context_id):
|
||||
logger.trace(f"{self} appending audio {frame} to audio context {context_id}")
|
||||
await self._contexts[context_id].put(frame)
|
||||
else:
|
||||
logger.warning(f"{self} unable to append audio to context {context_id}")
|
||||
|
||||
async def remove_audio_context(self, context_id: str):
|
||||
"""Remove an existing audio context.
|
||||
|
||||
Args:
|
||||
context_id: The context to remove.
|
||||
"""
|
||||
if self.audio_context_available(context_id):
|
||||
# We just mark the audio context for deletion by appending
|
||||
# None. Once we reach None while handling audio we know we can
|
||||
# safely remove the context.
|
||||
logger.trace(f"{self} marking audio context {context_id} for deletion")
|
||||
await self._contexts[context_id].put(None)
|
||||
else:
|
||||
logger.warning(f"{self} unable to remove context {context_id}")
|
||||
|
||||
def audio_context_available(self, context_id: str) -> bool:
|
||||
"""Check whether the given audio context is registered.
|
||||
|
||||
Args:
|
||||
context_id: The context ID to check.
|
||||
|
||||
Returns:
|
||||
True if the context exists and is available.
|
||||
"""
|
||||
return context_id in self._contexts
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Async TTS service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._create_audio_context_task()
|
||||
self._settings["output_format"]["sample_rate"] = self.sample_rate
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Async TTS service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
if self._audio_context_task:
|
||||
# Indicate no more audio contexts are available. this will end the
|
||||
# task cleanly after all contexts have been processed.
|
||||
await self._contexts_queue.put(None)
|
||||
await self._audio_context_task
|
||||
self._audio_context_task = None
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Async TTS service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._stop_audio_context_task()
|
||||
await self._disconnect()
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self._stop_audio_context_task()
|
||||
self._create_audio_context_task()
|
||||
|
||||
def _create_audio_context_task(self):
|
||||
if not self._audio_context_task:
|
||||
self._contexts_queue = asyncio.Queue()
|
||||
self._contexts: Dict[str, asyncio.Queue] = {}
|
||||
self._audio_context_task = self.create_task(self._audio_context_task_handler())
|
||||
|
||||
async def _stop_audio_context_task(self):
|
||||
if self._audio_context_task:
|
||||
await self.cancel_task(self._audio_context_task)
|
||||
self._audio_context_task = None
|
||||
|
||||
async def _audio_context_task_handler(self):
|
||||
"""In this task we process audio contexts in order."""
|
||||
running = True
|
||||
while running:
|
||||
context_id = await self._contexts_queue.get()
|
||||
|
||||
if context_id:
|
||||
# Process the audio context until the context doesn't have more
|
||||
# audio available (i.e. we find None).
|
||||
await self._handle_audio_context(context_id)
|
||||
|
||||
# We just finished processing the context, so we can safely remove it.
|
||||
del self._contexts[context_id]
|
||||
|
||||
# Append some silence between sentences.
|
||||
silence = b"\x00" * self.sample_rate
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=silence, sample_rate=self.sample_rate, num_channels=1
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
running = False
|
||||
|
||||
self._contexts_queue.task_done()
|
||||
|
||||
async def _handle_audio_context(self, context_id: str):
|
||||
# If we don't receive any audio during this time, we consider the context finished.
|
||||
AUDIO_CONTEXT_TIMEOUT = 3.0
|
||||
queue = self._contexts[context_id]
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
frame = await asyncio.wait_for(queue.get(), timeout=AUDIO_CONTEXT_TIMEOUT)
|
||||
if frame:
|
||||
await self.push_frame(frame)
|
||||
running = frame is not None
|
||||
except asyncio.TimeoutError:
|
||||
# We didn't get audio, so let's consider this context finished.
|
||||
logger.trace(f"{self} time out on audio context {context_id}")
|
||||
break
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -168,38 +315,10 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
"""
|
||||
return language_to_async_language(language)
|
||||
|
||||
def _build_msg(self, text: str = "", force: bool = False) -> str:
|
||||
msg = {"transcript": text, "force": force}
|
||||
def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str:
|
||||
msg = {"transcript": text, "context_id": context_id, "force": force}
|
||||
return json.dumps(msg)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Async TTS service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._settings["output_format"]["sample_rate"] = self.sample_rate
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Async TTS service.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Async TTS service.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
await super()._connect()
|
||||
|
||||
@@ -253,11 +372,16 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from Async")
|
||||
# Close all contexts and the socket
|
||||
if self._context_id:
|
||||
await self._websocket.send(json.dumps({"terminate": True}))
|
||||
await self._websocket.close()
|
||||
logger.debug("Disconnected from Async")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
finally:
|
||||
self._websocket = None
|
||||
self._context_id = None
|
||||
self._started = False
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
@@ -268,10 +392,10 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio."""
|
||||
if not self._websocket:
|
||||
if not self._context_id or not self._websocket:
|
||||
return
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
msg = self._build_msg(text=" ", force=True)
|
||||
msg = self._build_msg(text=" ", context_id=self._context_id, force=True)
|
||||
await self._websocket.send(msg)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
@@ -291,35 +415,70 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
if not msg:
|
||||
continue
|
||||
|
||||
elif msg.get("audio"):
|
||||
received_ctx_id = msg.get("context_id")
|
||||
# Handle final messages first, regardless of context availability
|
||||
# At the moment, this message is received AFTER the close_context message is
|
||||
# sent, so it doesn't serve any functional purpose. For now, we'll just log it.
|
||||
if msg.get("final") is True:
|
||||
logger.trace(f"Received final message for context {received_ctx_id}")
|
||||
continue
|
||||
|
||||
# Check if this message belongs to the current context.
|
||||
if not self.audio_context_available(received_ctx_id):
|
||||
if self._context_id == received_ctx_id:
|
||||
logger.debug(
|
||||
f"Received a delayed message, recreating the context: {self._context_id}"
|
||||
)
|
||||
await self.create_audio_context(self._context_id)
|
||||
else:
|
||||
# This can happen if a message is received _after_ we have closed a context
|
||||
# due to user interruption but _before_ the `isFinal` message for the context
|
||||
# is received.
|
||||
logger.debug(f"Ignoring message from unavailable context: {received_ctx_id}")
|
||||
continue
|
||||
|
||||
if msg.get("audio"):
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["audio"]),
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
elif msg.get("error_code"):
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg['message']}")
|
||||
else:
|
||||
await self.push_error(error_msg=f"Unknown message type: {msg}")
|
||||
audio = base64.b64decode(msg["audio"])
|
||||
frame = TTSAudioRawFrame(audio, self.sample_rate, 1)
|
||||
await self.append_to_audio_context(received_ctx_id, frame)
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Send periodic keepalive messages to maintain WebSocket connection."""
|
||||
KEEPALIVE_SLEEP = 3
|
||||
KEEPALIVE_SLEEP = 10
|
||||
while True:
|
||||
await asyncio.sleep(KEEPALIVE_SLEEP)
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
keepalive_message = {"transcript": " "}
|
||||
logger.trace("Sending keepalive message")
|
||||
if self._context_id:
|
||||
keepalive_message = {"transcript": " ", "context_id": self._context_id,}
|
||||
logger.trace("Sending keepalive message")
|
||||
else:
|
||||
# It's possible to have a user interruption which clears the context
|
||||
# without generating a new TTS response. In this case, we'll just send
|
||||
# an empty message to keep the connection alive.
|
||||
keepalive_message = {"transcript": " "}
|
||||
logger.trace("Sending keepalive without context")
|
||||
await self._websocket.send(json.dumps(keepalive_message))
|
||||
except websockets.ConnectionClosed as e:
|
||||
logger.warning(f"{self} keepalive error: {e}")
|
||||
break
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by closing the current context."""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
|
||||
# Close the current context when interrupted without closing the websocket
|
||||
if self._context_id and self._websocket:
|
||||
try:
|
||||
await self._websocket.send(
|
||||
json.dumps({"context_id": self._context_id, "close_context": True, "transcript": ""})
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing context on interruption: {e}")
|
||||
self._context_id = None
|
||||
self._started = False
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Async API websocket endpoint.
|
||||
@@ -336,26 +495,35 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
if not self._started:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
try:
|
||||
if not self._started:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
|
||||
msg = self._build_msg(text=text, force=True)
|
||||
if not self._context_id:
|
||||
self._context_id = str(uuid.uuid4())
|
||||
if not self.audio_context_available(self._context_id):
|
||||
await self.create_audio_context(self._context_id)
|
||||
|
||||
try:
|
||||
await self._get_websocket().send(msg)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
msg = self._build_msg(text=" ", context_id=self._context_id)
|
||||
await self._get_websocket().send(msg)
|
||||
msg = self._build_msg(text=text, force=True, context_id=self._context_id)
|
||||
await self._get_websocket().send(msg)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
else:
|
||||
if self._websocket and self._context_id:
|
||||
msg = self._build_msg(text=text, force=True, context_id=self._context_id)
|
||||
await self._get_websocket().send(msg)
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
self._started = False
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
class AsyncAIHttpTTSService(TTSService):
|
||||
"""HTTP-based Async TTS service.
|
||||
@@ -466,9 +634,9 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
first_byte_seen = False
|
||||
try:
|
||||
voice_config = {"mode": "id", "id": self._voice_id}
|
||||
await self.start_ttfb_metrics()
|
||||
payload = {
|
||||
"model_id": self._model_name,
|
||||
"transcript": text,
|
||||
@@ -476,7 +644,6 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
"output_format": self._settings["output_format"],
|
||||
"language": self._settings["language"],
|
||||
}
|
||||
yield TTSStartedFrame()
|
||||
headers = {
|
||||
"version": self._api_version,
|
||||
"x-api-key": self._api_key,
|
||||
@@ -484,26 +651,36 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
}
|
||||
url = f"{self._base_url}/text_to_speech/streaming"
|
||||
|
||||
yield TTSStartedFrame()
|
||||
await self.start_ttfb_metrics()
|
||||
async with self._session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
await self.push_error(error_msg=f"Async API error: {error_text}")
|
||||
raise Exception(f"Async API returned status {response.status}: {error_text}")
|
||||
|
||||
audio_data = await response.read()
|
||||
# Read streaming bytes; stop TTFB on the *first* received chunk
|
||||
buffer = bytearray()
|
||||
async for chunk in response.content.iter_chunked(64 * 1024):
|
||||
if not chunk:
|
||||
continue
|
||||
if not first_byte_seen:
|
||||
first_byte_seen = True
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
buffer.extend(chunk)
|
||||
audio_data = bytes(buffer)
|
||||
|
||||
frame = TTSAudioRawFrame(
|
||||
yield TTSAudioRawFrame(
|
||||
audio=audio_data,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
if not first_byte_seen:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
Reference in New Issue
Block a user