Add XAITTSService for xAI streaming WebSocket TTS

Adds XAITTSService in the existing xai/tts.py module, alongside the
existing XAIHttpTTSService. Connects to xAI's streaming endpoint at
wss://api.x.ai/v1/tts, streams text.delta chunks up and base64 audio.delta
chunks down on the same connection so audio starts flowing before the full
utterance is synthesized.

Extends InterruptibleTTSService since xAI's protocol is strictly sequential
per connection and exposes neither a cancel verb nor a context ID — the
only way to stop an in-flight utterance is to tear down the WebSocket,
which is exactly what InterruptibleTTSService does on interruption when
the bot is speaking.

Voice, language, codec, and sample_rate are passed as query-string params
at connect time; runtime setting changes reconnect the socket. Defaults to
raw PCM so emitted TTSAudioRawFrame objects need no decoding downstream.

Splits the existing example into voice-xai.py (WebSocket) and
voice-xai-http.py (batch HTTP) so each variant has its own entry point.
Promotes the xai extra to depend on pipecat-ai[websockets-base] since the
new service imports the websockets library.
This commit is contained in:
Mark Backman
2026-04-20 19:30:04 -04:00
parent 93393ea91c
commit d8f5c0be71
4 changed files with 496 additions and 9 deletions

View File

@@ -0,0 +1,128 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import os
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
LLMUserAggregatorParams,
)
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.xai.llm import GrokLLMService
from pipecat.services.xai.tts import XAIHttpTTSService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
# We use lambdas to defer transport parameter creation until the transport
# type is selected at runtime.
transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
),
}
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
async with aiohttp.ClientSession() as session:
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
tts = XAIHttpTTSService(
api_key=os.getenv("XAI_API_KEY"),
aiohttp_session=session,
settings=XAIHttpTTSService.Settings(
voice="eve",
),
)
llm = GrokLLMService(
api_key=os.getenv("XAI_API_KEY"),
settings=GrokLLMService.Settings(
system_instruction="You are a helpful assistant in a voice conversation. Your responses will be spoken aloud, so avoid emojis, bullet points, or other formatting that can't be spoken. Respond to what the user said in a creative, helpful, and brief way.",
),
)
context = LLMContext()
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
context,
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
)
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt,
user_aggregator, # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
assistant_aggregator, # Assistant spoken responses
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
context.add_message(
{"role": "developer", "content": "Please introduce yourself to the user."}
)
await task.queue_frames([LLMRunFrame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
await runner.run(task)
async def bot(runner_args: RunnerArguments):
"""Main bot entry point compatible with Pipecat Cloud."""
transport = await create_transport(runner_args, transport_params)
await run_bot(transport, runner_args)
if __name__ == "__main__":
from pipecat.runner.run import main
main()

View File

@@ -24,7 +24,7 @@ from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.xai.llm import GrokLLMService
from pipecat.services.xai.stt import XAISTTService
from pipecat.services.xai.tts import XAIHttpTTSService
from pipecat.services.xai.tts import XAITTSService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
@@ -55,10 +55,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
async with aiohttp.ClientSession() as session:
stt = XAISTTService(api_key=os.environ["XAI_API_KEY"])
tts = XAIHttpTTSService(
tts = XAITTSService(
api_key=os.environ["XAI_API_KEY"],
aiohttp_session=session,
settings=XAIHttpTTSService.Settings(
settings=XAITTSService.Settings(
voice="eve",
),
)

View File

@@ -6,22 +6,48 @@
"""xAI text-to-speech service implementation.
Uses xAI's HTTP TTS endpoint documented at:
https://docs.x.ai/developers/model-capabilities/audio/text-to-speech
Provides two TTS services against xAI's voice API:
- :class:`XAIHttpTTSService` uses the batch HTTP endpoint at
``https://api.x.ai/v1/tts``.
- :class:`XAITTSService` uses the streaming WebSocket endpoint at
``wss://api.x.ai/v1/tts``.
See https://docs.x.ai/developers/rest-api-reference/inference/voice.
"""
import base64
import json
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any
from urllib.parse import urlencode
import aiohttp
from loguru import logger
from pipecat.frames.frames import ErrorFrame, Frame, TTSAudioRawFrame
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import TTSSettings
from pipecat.services.tts_service import TTSService
from pipecat.services.tts_service import InterruptibleTTSService, TTSService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.tracing.service_decorators import traced_tts
try:
from websockets.asyncio.client import connect as websocket_connect
from websockets.protocol import State
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use XAITTSService, you need to `pip install pipecat-ai[xai]`.")
raise Exception(f"Missing module: {e}")
def language_to_xai_language(language: Language) -> str | None:
"""Convert a Language enum to xAI language code.
@@ -214,3 +240,249 @@ class XAIHttpTTSService(TTSService):
)
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
@dataclass
class XAIWebsocketTTSSettings(TTSSettings):
"""Settings for XAITTSService (WebSocket streaming)."""
pass
class XAITTSService(InterruptibleTTSService):
"""xAI streaming text-to-speech service.
Connects to xAI's WebSocket TTS endpoint and streams audio chunks back as
they are synthesized. Text can be sent incrementally via ``text.delta``
messages and each utterance is terminated with ``text.done``. The server
responds with ``audio.delta`` chunks followed by an ``audio.done`` message.
Audio parameters (voice, language, codec, sample rate, bit rate) are passed
as query string parameters on the WebSocket URL; changing any of them at
runtime reconnects the WebSocket.
"""
Settings = XAIWebsocketTTSSettings
_settings: Settings
def __init__(
self,
*,
api_key: str,
base_url: str = "wss://api.x.ai/v1/tts",
sample_rate: int | None = None,
codec: str = "pcm",
settings: Settings | None = None,
**kwargs,
):
"""Initialize the xAI WebSocket TTS service.
Args:
api_key: xAI API key for authentication.
base_url: xAI TTS WebSocket endpoint. Defaults to
``wss://api.x.ai/v1/tts``.
sample_rate: Output audio sample rate in Hz. If None, uses the
pipeline default.
codec: Output audio codec. One of ``pcm``, ``wav``, ``mulaw``,
``alaw``. Defaults to ``pcm`` so emitted ``TTSAudioRawFrame``
objects need no decoding downstream.
settings: Runtime-updatable settings.
**kwargs: Additional arguments passed to parent
``InterruptibleTTSService``.
"""
default_settings = self.Settings(
model=None,
voice="eve",
language=Language.EN,
)
if settings is not None:
default_settings.apply_update(settings)
super().__init__(
push_start_frame=True,
push_stop_frames=True,
sample_rate=sample_rate,
settings=default_settings,
**kwargs,
)
self._api_key = api_key
self._base_url = base_url
self._codec = codec
self._receive_task = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics."""
return True
def language_to_service_language(self, language: Language) -> str | None:
"""Convert a Language enum to xAI language format."""
return language_to_xai_language(language)
async def start(self, frame: StartFrame):
"""Start the xAI WebSocket TTS service."""
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the xAI WebSocket TTS service."""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the xAI WebSocket TTS service."""
await super().cancel(frame)
await self._disconnect()
async def _connect(self):
await super()._connect()
await self._connect_websocket()
if self._websocket and not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _disconnect(self):
await super()._disconnect()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None
await self._disconnect_websocket()
async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]:
"""Apply a settings delta. Reconnects if any URL-baked field changes."""
changed = await super()._update_settings(delta)
if changed:
await self._disconnect()
await self._connect()
return changed
def _build_url(self) -> str:
language = self._settings.language
if isinstance(language, Language):
language_value = language_to_xai_language(language) or language.value
else:
language_value = str(language) if language is not None else "auto"
params: dict[str, Any] = {
"voice": self._settings.voice,
"language": language_value,
"codec": self._codec,
"sample_rate": self.sample_rate,
}
return f"{self._base_url}?{urlencode(params)}"
async def _connect_websocket(self):
try:
if self._websocket and self._websocket.state is State.OPEN:
return
logger.debug("Connecting to xAI TTS")
url = self._build_url()
headers = {"Authorization": f"Bearer {self._api_key}"}
self._websocket = await websocket_connect(url, additional_headers=headers)
await self._call_event_handler("on_connected")
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")
async def _disconnect_websocket(self):
try:
await self.stop_all_metrics()
if self._websocket:
logger.debug("Disconnecting from xAI TTS")
await self._websocket.close()
except Exception as e:
await self.push_error(error_msg=f"Error disconnecting from xAI TTS: {e}", exception=e)
finally:
self._websocket = None
await self._call_event_handler("on_disconnected")
def _get_websocket(self):
if self._websocket:
return self._websocket
raise Exception("Websocket not connected")
async def flush_audio(self, context_id: str | None = None):
"""Signal end-of-utterance so xAI begins synthesizing what it has buffered."""
if not self._websocket or self._websocket.state is State.CLOSED:
return
await self._get_websocket().send(json.dumps({"type": "text.done"}))
async def _receive_messages(self):
async for message in self._get_websocket():
if isinstance(message, bytes):
logger.warning(f"{self}: unexpected binary frame from xAI TTS")
continue
try:
msg = json.loads(message)
except json.JSONDecodeError:
logger.error(f"{self}: invalid JSON message: {message}")
continue
msg_type = msg.get("type")
context_id = self.get_active_audio_context_id()
if msg_type == "audio.delta":
audio_b64 = msg.get("delta")
if not audio_b64:
continue
audio = base64.b64decode(audio_b64)
await self.stop_ttfb_metrics()
if context_id:
frame = TTSAudioRawFrame(
audio=audio,
sample_rate=self.sample_rate,
num_channels=1,
context_id=context_id,
)
await self.append_to_audio_context(context_id, frame)
elif msg_type == "audio.done":
await self.stop_all_metrics()
if context_id:
await self.append_to_audio_context(
context_id, TTSStoppedFrame(context_id=context_id)
)
await self.remove_audio_context(context_id)
elif msg_type == "error":
await self.stop_all_metrics()
error_detail = msg.get("message") or msg.get("error") or str(msg)
if context_id:
await self.append_to_audio_context(
context_id, TTSStoppedFrame(context_id=context_id)
)
await self.remove_audio_context(context_id)
await self.push_error(error_msg=f"xAI TTS error: {error_detail}")
else:
logger.debug(f"{self}: unhandled xAI message type: {msg_type}")
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
"""Generate TTS audio from text using xAI's streaming WebSocket API."""
logger.debug(f"{self}: Generating TTS [{text}]")
try:
if not self._websocket or self._websocket.state is State.CLOSED:
await self._connect()
try:
await self._get_websocket().send(json.dumps({"type": "text.delta", "delta": text}))
await self.start_tts_usage_metrics(text)
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame(context_id=context_id)
await self._disconnect()
await self._connect()
return
yield None
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")

View File

@@ -4,14 +4,19 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Tests for XAIHttpTTSService."""
"""Tests for XAIHttpTTSService and XAITTSService."""
import asyncio
import base64
import json
import unittest
from urllib.parse import parse_qs, urlparse
import aiohttp
import pytest
import websockets
from aiohttp import web
from websockets.asyncio.server import serve
from pipecat.frames.frames import (
AggregatedTextFrame,
@@ -21,7 +26,7 @@ from pipecat.frames.frames import (
TTSStoppedFrame,
TTSTextFrame,
)
from pipecat.services.xai.tts import XAIHttpTTSService
from pipecat.services.xai.tts import XAIHttpTTSService, XAITTSService
from pipecat.tests.utils import run_test
@@ -87,5 +92,87 @@ async def test_run_xai_tts_success(aiohttp_client):
}
@pytest.mark.asyncio
async def test_run_xai_websocket_tts_success():
"""xAI WS TTS should send text.delta+text.done and emit frames from audio.delta+audio.done."""
captured: dict = {
"request_path": None,
"auth_header": None,
"messages": [],
}
audio_bytes = b"\x00\x01\x02\x03" * 1024
async def handler(ws):
request = ws.request
captured["request_path"] = request.path
captured["auth_header"] = request.headers.get("Authorization")
try:
async for raw in ws:
msg = json.loads(raw)
captured["messages"].append(msg)
if msg.get("type") == "text.done":
await ws.send(
json.dumps(
{
"type": "audio.delta",
"delta": base64.b64encode(audio_bytes).decode("ascii"),
}
)
)
await ws.send(json.dumps({"type": "audio.done", "trace_id": "test-trace"}))
except websockets.ConnectionClosed:
pass
async with serve(handler, "127.0.0.1", 0) as server:
host, port = next(iter(server.sockets)).getsockname()[:2]
base_url = f"ws://{host}:{port}/v1/tts"
tts_service = XAITTSService(
api_key="test-key",
base_url=base_url,
sample_rate=24000,
)
down_frames, _ = await run_test(
tts_service,
frames_to_send=[TTSSpeakFrame(text="Hello from xAI."), _SleepAfterSpeak(0.3)],
)
frame_types = [type(frame) for frame in down_frames]
assert TTSStartedFrame in frame_types
assert TTSAudioRawFrame in frame_types
assert TTSStoppedFrame in frame_types
audio_frames = [frame for frame in down_frames if isinstance(frame, TTSAudioRawFrame)]
assert audio_frames
assert all(frame.sample_rate == 24000 for frame in audio_frames)
assert all(frame.num_channels == 1 for frame in audio_frames)
assert b"".join(f.audio for f in audio_frames) == audio_bytes
assert captured["auth_header"] == "Bearer test-key"
parsed = urlparse(captured["request_path"])
query = parse_qs(parsed.query)
assert query["voice"] == ["eve"]
assert query["language"] == ["en"]
assert query["codec"] == ["pcm"]
assert query["sample_rate"] == ["24000"]
types_sent = [m.get("type") for m in captured["messages"]]
assert "text.delta" in types_sent
assert "text.done" in types_sent
delta_msg = next(m for m in captured["messages"] if m.get("type") == "text.delta")
assert delta_msg["delta"] == "Hello from xAI."
# Small helper imported lazily to avoid circular import in fixture-lite tests.
def _SleepAfterSpeak(duration: float):
from pipecat.tests.utils import SleepFrame
return SleepFrame(sleep=duration)
if __name__ == "__main__":
unittest.main()