diff --git a/examples/voice/voice-xai-http.py b/examples/voice/voice-xai-http.py new file mode 100644 index 000000000..6d0d54e79 --- /dev/null +++ b/examples/voice/voice-xai-http.py @@ -0,0 +1,128 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import os + +import aiohttp +from dotenv import load_dotenv +from loguru import logger + +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.frames.frames import LLMRunFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_response_universal import ( + LLMContextAggregatorPair, + LLMUserAggregatorParams, +) +from pipecat.runner.types import RunnerArguments +from pipecat.runner.utils import create_transport +from pipecat.services.deepgram.stt import DeepgramSTTService +from pipecat.services.xai.llm import GrokLLMService +from pipecat.services.xai.tts import XAIHttpTTSService +from pipecat.transports.base_transport import BaseTransport, TransportParams +from pipecat.transports.daily.transport import DailyParams +from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams + +load_dotenv(override=True) + +# We use lambdas to defer transport parameter creation until the transport +# type is selected at runtime. +transport_params = { + "daily": lambda: DailyParams( + audio_in_enabled=True, + audio_out_enabled=True, + ), + "twilio": lambda: FastAPIWebsocketParams( + audio_in_enabled=True, + audio_out_enabled=True, + ), + "webrtc": lambda: TransportParams( + audio_in_enabled=True, + audio_out_enabled=True, + ), +} + + +async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): + logger.info(f"Starting bot") + + async with aiohttp.ClientSession() as session: + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + + tts = XAIHttpTTSService( + api_key=os.getenv("XAI_API_KEY"), + aiohttp_session=session, + settings=XAIHttpTTSService.Settings( + voice="eve", + ), + ) + + llm = GrokLLMService( + api_key=os.getenv("XAI_API_KEY"), + settings=GrokLLMService.Settings( + system_instruction="You are a helpful assistant in a voice conversation. Your responses will be spoken aloud, so avoid emojis, bullet points, or other formatting that can't be spoken. Respond to what the user said in a creative, helpful, and brief way.", + ), + ) + + context = LLMContext() + user_aggregator, assistant_aggregator = LLMContextAggregatorPair( + context, + user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()), + ) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + stt, + user_aggregator, # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output + assistant_aggregator, # Assistant spoken responses + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + enable_metrics=True, + enable_usage_metrics=True, + ), + idle_timeout_secs=runner_args.pipeline_idle_timeout_secs, + ) + + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info(f"Client connected") + # Kick off the conversation. + context.add_message( + {"role": "developer", "content": "Please introduce yourself to the user."} + ) + await task.queue_frames([LLMRunFrame()]) + + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info(f"Client disconnected") + await task.cancel() + + runner = PipelineRunner(handle_sigint=runner_args.handle_sigint) + + await runner.run(task) + + +async def bot(runner_args: RunnerArguments): + """Main bot entry point compatible with Pipecat Cloud.""" + transport = await create_transport(runner_args, transport_params) + await run_bot(transport, runner_args) + + +if __name__ == "__main__": + from pipecat.runner.run import main + + main() diff --git a/examples/voice/voice-xai.py b/examples/voice/voice-xai.py index 833f9b816..541d66c77 100644 --- a/examples/voice/voice-xai.py +++ b/examples/voice/voice-xai.py @@ -24,7 +24,7 @@ from pipecat.runner.types import RunnerArguments from pipecat.runner.utils import create_transport from pipecat.services.xai.llm import GrokLLMService from pipecat.services.xai.stt import XAISTTService -from pipecat.services.xai.tts import XAIHttpTTSService +from pipecat.services.xai.tts import XAITTSService from pipecat.transports.base_transport import BaseTransport, TransportParams from pipecat.transports.daily.transport import DailyParams from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams @@ -55,10 +55,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): async with aiohttp.ClientSession() as session: stt = XAISTTService(api_key=os.environ["XAI_API_KEY"]) - tts = XAIHttpTTSService( + tts = XAITTSService( api_key=os.environ["XAI_API_KEY"], aiohttp_session=session, - settings=XAIHttpTTSService.Settings( + settings=XAITTSService.Settings( voice="eve", ), ) diff --git a/src/pipecat/services/xai/tts.py b/src/pipecat/services/xai/tts.py index 17f67cf9a..f99772baa 100644 --- a/src/pipecat/services/xai/tts.py +++ b/src/pipecat/services/xai/tts.py @@ -6,22 +6,48 @@ """xAI text-to-speech service implementation. -Uses xAI's HTTP TTS endpoint documented at: -https://docs.x.ai/developers/model-capabilities/audio/text-to-speech +Provides two TTS services against xAI's voice API: + +- :class:`XAIHttpTTSService` uses the batch HTTP endpoint at + ``https://api.x.ai/v1/tts``. +- :class:`XAITTSService` uses the streaming WebSocket endpoint at + ``wss://api.x.ai/v1/tts``. + +See https://docs.x.ai/developers/rest-api-reference/inference/voice. """ +import base64 +import json from collections.abc import AsyncGenerator from dataclasses import dataclass +from typing import Any +from urllib.parse import urlencode import aiohttp from loguru import logger -from pipecat.frames.frames import ErrorFrame, Frame, TTSAudioRawFrame +from pipecat.frames.frames import ( + CancelFrame, + EndFrame, + ErrorFrame, + Frame, + StartFrame, + TTSAudioRawFrame, + TTSStoppedFrame, +) from pipecat.services.settings import TTSSettings -from pipecat.services.tts_service import TTSService +from pipecat.services.tts_service import InterruptibleTTSService, TTSService from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.tracing.service_decorators import traced_tts +try: + from websockets.asyncio.client import connect as websocket_connect + from websockets.protocol import State +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use XAITTSService, you need to `pip install pipecat-ai[xai]`.") + raise Exception(f"Missing module: {e}") + def language_to_xai_language(language: Language) -> str | None: """Convert a Language enum to xAI language code. @@ -214,3 +240,249 @@ class XAIHttpTTSService(TTSService): ) except Exception as e: yield ErrorFrame(error=f"Unknown error occurred: {e}") + + +@dataclass +class XAIWebsocketTTSSettings(TTSSettings): + """Settings for XAITTSService (WebSocket streaming).""" + + pass + + +class XAITTSService(InterruptibleTTSService): + """xAI streaming text-to-speech service. + + Connects to xAI's WebSocket TTS endpoint and streams audio chunks back as + they are synthesized. Text can be sent incrementally via ``text.delta`` + messages and each utterance is terminated with ``text.done``. The server + responds with ``audio.delta`` chunks followed by an ``audio.done`` message. + + Audio parameters (voice, language, codec, sample rate, bit rate) are passed + as query string parameters on the WebSocket URL; changing any of them at + runtime reconnects the WebSocket. + """ + + Settings = XAIWebsocketTTSSettings + _settings: Settings + + def __init__( + self, + *, + api_key: str, + base_url: str = "wss://api.x.ai/v1/tts", + sample_rate: int | None = None, + codec: str = "pcm", + settings: Settings | None = None, + **kwargs, + ): + """Initialize the xAI WebSocket TTS service. + + Args: + api_key: xAI API key for authentication. + base_url: xAI TTS WebSocket endpoint. Defaults to + ``wss://api.x.ai/v1/tts``. + sample_rate: Output audio sample rate in Hz. If None, uses the + pipeline default. + codec: Output audio codec. One of ``pcm``, ``wav``, ``mulaw``, + ``alaw``. Defaults to ``pcm`` so emitted ``TTSAudioRawFrame`` + objects need no decoding downstream. + settings: Runtime-updatable settings. + **kwargs: Additional arguments passed to parent + ``InterruptibleTTSService``. + """ + default_settings = self.Settings( + model=None, + voice="eve", + language=Language.EN, + ) + + if settings is not None: + default_settings.apply_update(settings) + + super().__init__( + push_start_frame=True, + push_stop_frames=True, + sample_rate=sample_rate, + settings=default_settings, + **kwargs, + ) + + self._api_key = api_key + self._base_url = base_url + self._codec = codec + self._receive_task = None + + def can_generate_metrics(self) -> bool: + """Check if this service can generate processing metrics.""" + return True + + def language_to_service_language(self, language: Language) -> str | None: + """Convert a Language enum to xAI language format.""" + return language_to_xai_language(language) + + async def start(self, frame: StartFrame): + """Start the xAI WebSocket TTS service.""" + await super().start(frame) + await self._connect() + + async def stop(self, frame: EndFrame): + """Stop the xAI WebSocket TTS service.""" + await super().stop(frame) + await self._disconnect() + + async def cancel(self, frame: CancelFrame): + """Cancel the xAI WebSocket TTS service.""" + await super().cancel(frame) + await self._disconnect() + + async def _connect(self): + await super()._connect() + + await self._connect_websocket() + + if self._websocket and not self._receive_task: + self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) + + async def _disconnect(self): + await super()._disconnect() + + if self._receive_task: + await self.cancel_task(self._receive_task) + self._receive_task = None + + await self._disconnect_websocket() + + async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]: + """Apply a settings delta. Reconnects if any URL-baked field changes.""" + changed = await super()._update_settings(delta) + + if changed: + await self._disconnect() + await self._connect() + + return changed + + def _build_url(self) -> str: + language = self._settings.language + if isinstance(language, Language): + language_value = language_to_xai_language(language) or language.value + else: + language_value = str(language) if language is not None else "auto" + + params: dict[str, Any] = { + "voice": self._settings.voice, + "language": language_value, + "codec": self._codec, + "sample_rate": self.sample_rate, + } + return f"{self._base_url}?{urlencode(params)}" + + async def _connect_websocket(self): + try: + if self._websocket and self._websocket.state is State.OPEN: + return + + logger.debug("Connecting to xAI TTS") + + url = self._build_url() + headers = {"Authorization": f"Bearer {self._api_key}"} + self._websocket = await websocket_connect(url, additional_headers=headers) + + await self._call_event_handler("on_connected") + except Exception as e: + await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) + self._websocket = None + await self._call_event_handler("on_connection_error", f"{e}") + + async def _disconnect_websocket(self): + try: + await self.stop_all_metrics() + + if self._websocket: + logger.debug("Disconnecting from xAI TTS") + await self._websocket.close() + except Exception as e: + await self.push_error(error_msg=f"Error disconnecting from xAI TTS: {e}", exception=e) + finally: + self._websocket = None + await self._call_event_handler("on_disconnected") + + def _get_websocket(self): + if self._websocket: + return self._websocket + raise Exception("Websocket not connected") + + async def flush_audio(self, context_id: str | None = None): + """Signal end-of-utterance so xAI begins synthesizing what it has buffered.""" + if not self._websocket or self._websocket.state is State.CLOSED: + return + await self._get_websocket().send(json.dumps({"type": "text.done"})) + + async def _receive_messages(self): + async for message in self._get_websocket(): + if isinstance(message, bytes): + logger.warning(f"{self}: unexpected binary frame from xAI TTS") + continue + try: + msg = json.loads(message) + except json.JSONDecodeError: + logger.error(f"{self}: invalid JSON message: {message}") + continue + + msg_type = msg.get("type") + context_id = self.get_active_audio_context_id() + + if msg_type == "audio.delta": + audio_b64 = msg.get("delta") + if not audio_b64: + continue + audio = base64.b64decode(audio_b64) + await self.stop_ttfb_metrics() + if context_id: + frame = TTSAudioRawFrame( + audio=audio, + sample_rate=self.sample_rate, + num_channels=1, + context_id=context_id, + ) + await self.append_to_audio_context(context_id, frame) + elif msg_type == "audio.done": + await self.stop_all_metrics() + if context_id: + await self.append_to_audio_context( + context_id, TTSStoppedFrame(context_id=context_id) + ) + await self.remove_audio_context(context_id) + elif msg_type == "error": + await self.stop_all_metrics() + error_detail = msg.get("message") or msg.get("error") or str(msg) + if context_id: + await self.append_to_audio_context( + context_id, TTSStoppedFrame(context_id=context_id) + ) + await self.remove_audio_context(context_id) + await self.push_error(error_msg=f"xAI TTS error: {error_detail}") + else: + logger.debug(f"{self}: unhandled xAI message type: {msg_type}") + + @traced_tts + async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]: + """Generate TTS audio from text using xAI's streaming WebSocket API.""" + logger.debug(f"{self}: Generating TTS [{text}]") + + try: + if not self._websocket or self._websocket.state is State.CLOSED: + await self._connect() + + try: + await self._get_websocket().send(json.dumps({"type": "text.delta", "delta": text})) + await self.start_tts_usage_metrics(text) + except Exception as e: + yield ErrorFrame(error=f"Unknown error occurred: {e}") + yield TTSStoppedFrame(context_id=context_id) + await self._disconnect() + await self._connect() + return + yield None + except Exception as e: + yield ErrorFrame(error=f"Unknown error occurred: {e}") diff --git a/tests/test_xai_tts.py b/tests/test_xai_tts.py index aab984567..98ac48fd4 100644 --- a/tests/test_xai_tts.py +++ b/tests/test_xai_tts.py @@ -4,14 +4,19 @@ # SPDX-License-Identifier: BSD 2-Clause License # -"""Tests for XAIHttpTTSService.""" +"""Tests for XAIHttpTTSService and XAITTSService.""" import asyncio +import base64 +import json import unittest +from urllib.parse import parse_qs, urlparse import aiohttp import pytest +import websockets from aiohttp import web +from websockets.asyncio.server import serve from pipecat.frames.frames import ( AggregatedTextFrame, @@ -21,7 +26,7 @@ from pipecat.frames.frames import ( TTSStoppedFrame, TTSTextFrame, ) -from pipecat.services.xai.tts import XAIHttpTTSService +from pipecat.services.xai.tts import XAIHttpTTSService, XAITTSService from pipecat.tests.utils import run_test @@ -87,5 +92,87 @@ async def test_run_xai_tts_success(aiohttp_client): } +@pytest.mark.asyncio +async def test_run_xai_websocket_tts_success(): + """xAI WS TTS should send text.delta+text.done and emit frames from audio.delta+audio.done.""" + + captured: dict = { + "request_path": None, + "auth_header": None, + "messages": [], + } + + audio_bytes = b"\x00\x01\x02\x03" * 1024 + + async def handler(ws): + request = ws.request + captured["request_path"] = request.path + captured["auth_header"] = request.headers.get("Authorization") + + try: + async for raw in ws: + msg = json.loads(raw) + captured["messages"].append(msg) + if msg.get("type") == "text.done": + await ws.send( + json.dumps( + { + "type": "audio.delta", + "delta": base64.b64encode(audio_bytes).decode("ascii"), + } + ) + ) + await ws.send(json.dumps({"type": "audio.done", "trace_id": "test-trace"})) + except websockets.ConnectionClosed: + pass + + async with serve(handler, "127.0.0.1", 0) as server: + host, port = next(iter(server.sockets)).getsockname()[:2] + base_url = f"ws://{host}:{port}/v1/tts" + + tts_service = XAITTSService( + api_key="test-key", + base_url=base_url, + sample_rate=24000, + ) + + down_frames, _ = await run_test( + tts_service, + frames_to_send=[TTSSpeakFrame(text="Hello from xAI."), _SleepAfterSpeak(0.3)], + ) + + frame_types = [type(frame) for frame in down_frames] + assert TTSStartedFrame in frame_types + assert TTSAudioRawFrame in frame_types + assert TTSStoppedFrame in frame_types + + audio_frames = [frame for frame in down_frames if isinstance(frame, TTSAudioRawFrame)] + assert audio_frames + assert all(frame.sample_rate == 24000 for frame in audio_frames) + assert all(frame.num_channels == 1 for frame in audio_frames) + assert b"".join(f.audio for f in audio_frames) == audio_bytes + + assert captured["auth_header"] == "Bearer test-key" + parsed = urlparse(captured["request_path"]) + query = parse_qs(parsed.query) + assert query["voice"] == ["eve"] + assert query["language"] == ["en"] + assert query["codec"] == ["pcm"] + assert query["sample_rate"] == ["24000"] + + types_sent = [m.get("type") for m in captured["messages"]] + assert "text.delta" in types_sent + assert "text.done" in types_sent + delta_msg = next(m for m in captured["messages"] if m.get("type") == "text.delta") + assert delta_msg["delta"] == "Hello from xAI." + + +# Small helper imported lazily to avoid circular import in fixture-lite tests. +def _SleepAfterSpeak(duration: float): + from pipecat.tests.utils import SleepFrame + + return SleepFrame(sleep=duration) + + if __name__ == "__main__": unittest.main()