Address PR review: rename to XAIHttpTTSService, add language map, clean up API

- Rename XAITTSService → XAIHttpTTSService and XAITTSSettings → XAIHttpTTSSettings
- Add language_to_xai_language() with explicit LANGUAGE_MAP using resolve_language()
- Remove deprecated InputParams, params, voice, language init params
- Remove XAI_DEFAULT_SAMPLE_RATE and XAI_PCM_CODEC constants; add encoding param
- Set sample_rate=None default (picked up from PipelineParams or user)
- Use Language.EN enum instead of string "en" for default language
- Add changelog/4031.added.md
- Add 07e-interruptible-xai.py foundational example
- Update 14g-function-calling-grok.py to use XAIHttpTTSService
- Register 07e in run-release-evals.py
This commit is contained in:
Nicholas Zhao
2026-03-24 22:21:17 -07:00
committed by Mark Backman
parent 02b97035f8
commit bbd14de9c5
5 changed files with 276 additions and 137 deletions

1
changelog/4031.added.md Normal file
View File

@@ -0,0 +1 @@
- Added `XAIHttpTTSService` for text-to-speech using xAI's HTTP TTS API.

View File

@@ -0,0 +1,129 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import os
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
LLMUserAggregatorParams,
)
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.grok.llm import GrokLLMService
from pipecat.services.xai.tts import XAIHttpTTSService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
# We use lambdas to defer transport parameter creation until the transport
# type is selected at runtime.
transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
),
}
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
async with aiohttp.ClientSession() as session:
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
tts = XAIHttpTTSService(
api_key=os.getenv("GROK_API_KEY"),
aiohttp_session=session,
settings=XAIHttpTTSService.Settings(
voice="eve",
),
)
llm = GrokLLMService(
api_key=os.getenv("GROK_API_KEY"),
settings=GrokLLMService.Settings(
system_instruction="You are a helpful assistant in a voice conversation. Your responses will be spoken aloud, so avoid emojis, bullet points, or other formatting that can't be spoken. Respond to what the user said in a creative, helpful, and brief way.",
),
)
context = LLMContext()
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
context,
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
)
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt,
user_aggregator, # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
assistant_aggregator, # Assistant spoken responses
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
audio_out_sample_rate=8000,
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
context.add_message(
{"role": "user", "content": "Please introduce yourself to the user."}
)
await task.queue_frames([LLMRunFrame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
await runner.run(task)
async def bot(runner_args: RunnerArguments):
"""Main bot entry point compatible with Pipecat Cloud."""
transport = await create_transport(runner_args, transport_params)
await run_bot(transport, runner_args)
if __name__ == "__main__":
from pipecat.runner.run import main
main()

View File

@@ -7,6 +7,7 @@
import os
import aiohttp
from dotenv import load_dotenv
from loguru import logger
@@ -24,10 +25,10 @@ from pipecat.processors.aggregators.llm_response_universal import (
)
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.grok.llm import GrokLLMService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.xai.tts import XAIHttpTTSService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
@@ -60,83 +61,88 @@ transport_params = {
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
async with aiohttp.ClientSession() as session:
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
settings=CartesiaTTSService.Settings(
voice="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
),
)
tts = XAIHttpTTSService(
api_key=os.getenv("GROK_API_KEY"),
aiohttp_session=session,
settings=XAIHttpTTSService.Settings(
voice="eve",
),
)
llm = GrokLLMService(
api_key=os.getenv("GROK_API_KEY"),
settings=GrokLLMService.Settings(
system_instruction="You are a helpful assistant in a voice conversation. Your responses will be spoken aloud, so avoid emojis, bullet points, or other formatting that can't be spoken. Respond to what the user said in a creative, helpful, and brief way.",
),
)
# You can also register a function_name of None to get all functions
# sent to the same callback with an additional function_name parameter.
llm.register_function("get_current_weather", fetch_weather_from_api)
llm = GrokLLMService(
api_key=os.getenv("GROK_API_KEY"),
settings=GrokLLMService.Settings(
system_instruction="You are a helpful assistant in a voice conversation. Your responses will be spoken aloud, so avoid emojis, bullet points, or other formatting that can't be spoken. Respond to what the user said in a creative, helpful, and brief way.",
),
)
# You can also register a function_name of None to get all functions
# sent to the same callback with an additional function_name parameter.
llm.register_function("get_current_weather", fetch_weather_from_api)
weather_function = FunctionSchema(
name="get_current_weather",
description="Get the current weather",
properties={
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
weather_function = FunctionSchema(
name="get_current_weather",
description="Get the current weather",
properties={
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the user's location.",
},
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the user's location.",
},
},
required=["location", "format"],
)
tools = ToolsSchema(standard_tools=[weather_function])
context = LLMContext(tools=tools)
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
context,
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
)
required=["location", "format"],
)
tools = ToolsSchema(standard_tools=[weather_function])
context = LLMContext(tools=tools)
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
context,
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
)
pipeline = Pipeline(
[
transport.input(),
stt,
user_aggregator,
llm,
tts,
transport.output(),
assistant_aggregator,
]
)
pipeline = Pipeline(
[
transport.input(),
stt,
user_aggregator,
llm,
tts,
transport.output(),
assistant_aggregator,
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
await task.queue_frames([LLMRunFrame()])
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
context.add_message(
{"role": "user", "content": "Please introduce yourself to the user."}
)
await task.queue_frames([LLMRunFrame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
await runner.run(task)
await runner.run(task)
async def bot(runner_args: RunnerArguments):

View File

@@ -15,57 +15,77 @@ from typing import AsyncGenerator, Optional
import aiohttp
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import ErrorFrame, Frame, TTSAudioRawFrame
from pipecat.services.settings import TTSSettings
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.tracing.service_decorators import traced_tts
def language_to_xai_language(language: Language) -> Optional[str]:
"""Convert a Language enum to xAI language code.
Args:
language: The Language enum value to convert.
Returns:
The corresponding xAI language code, or None if not supported.
"""
LANGUAGE_MAP = {
Language.AR: "ar-EG",
Language.AR_EG: "ar-EG",
Language.AR_SA: "ar-SA",
Language.AR_AE: "ar-AE",
Language.BN: "bn",
Language.DE: "de",
Language.EN: "en",
Language.ES: "es-ES",
Language.ES_ES: "es-ES",
Language.ES_MX: "es-MX",
Language.FR: "fr",
Language.HI: "hi",
Language.ID: "id",
Language.IT: "it",
Language.JA: "ja",
Language.KO: "ko",
Language.PT: "pt-PT",
Language.PT_BR: "pt-BR",
Language.PT_PT: "pt-PT",
Language.RU: "ru",
Language.TR: "tr",
Language.VI: "vi",
Language.ZH: "zh",
}
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
@dataclass
class XAITTSSettings(TTSSettings):
"""Settings for XAITTSService."""
class XAIHttpTTSSettings(TTSSettings):
"""Settings for XAIHttpTTSService."""
pass
class XAITTSService(TTSService):
class XAIHttpTTSService(TTSService):
"""xAI HTTP text-to-speech service.
The service requests raw PCM audio so emitted ``TTSAudioRawFrame`` objects
match Pipecat's downstream expectations without extra decoding.
"""
Settings = XAITTSSettings
Settings = XAIHttpTTSSettings
_settings: Settings
XAI_DEFAULT_SAMPLE_RATE = 24000
XAI_PCM_CODEC = "pcm"
class InputParams(BaseModel):
"""Input parameters for xAI TTS configuration.
.. deprecated:: 0.0.105
Use ``settings=XAITTSService.Settings(...)`` instead.
Parameters:
language: Language for speech synthesis.
"""
language: Optional[Language] = None
def __init__(
self,
*,
api_key: str,
base_url: str = "https://api.x.ai/v1/tts",
voice: Optional[str] = None,
language: Optional[str | Language] = None,
sample_rate: Optional[int] = None,
encoding: Optional[str] = "pcm",
aiohttp_session: Optional[aiohttp.ClientSession] = None,
params: Optional[InputParams] = None,
settings: Optional[Settings] = None,
**kwargs,
):
@@ -74,54 +94,25 @@ class XAITTSService(TTSService):
Args:
api_key: xAI API key for authentication.
base_url: xAI TTS endpoint. Defaults to ``https://api.x.ai/v1/tts``.
voice: Voice identifier. Defaults to ``"eve"``.
.. deprecated:: 0.0.105
Use ``settings=XAITTSService.Settings(voice=...)`` instead.
language: BCP-47 or base language code (for example ``"en"`` or ``"pt-BR"``).
Defaults to ``"en"``.
.. deprecated:: 0.0.105
Use ``settings=XAITTSService.Settings(language=...)`` instead.
sample_rate: Output sample rate for PCM audio. Defaults to 24000 Hz.
sample_rate: Audio sample rate. If None, uses default.
encoding: Output encoding format. Defaults to "pcm".
aiohttp_session: Optional shared aiohttp session.
params: Deprecated input parameters object.
settings: Runtime-updatable settings. When provided alongside deprecated
parameters, ``settings`` values take precedence.
settings: Runtime-updatable settings.
**kwargs: Additional keyword arguments passed to ``TTSService``.
"""
default_settings = self.Settings(
model=None,
voice="eve",
language="en",
language=Language.EN,
)
if voice is not None:
self._warn_init_param_moved_to_settings("voice", "voice")
default_settings.voice = voice
if language is not None:
self._warn_init_param_moved_to_settings("language", "language")
default_settings.language = (
self.language_to_service_language(language)
if isinstance(language, Language)
else language
)
if params is not None:
self._warn_init_param_moved_to_settings("params")
if not settings and params.language is not None:
default_settings.language = self.language_to_service_language(params.language)
if settings is not None:
default_settings.apply_update(settings)
super().__init__(
pause_frame_processing=True,
sample_rate=sample_rate,
push_start_frame=True,
push_stop_frames=True,
sample_rate=sample_rate or self.XAI_DEFAULT_SAMPLE_RATE,
settings=default_settings,
**kwargs,
)
@@ -130,14 +121,22 @@ class XAITTSService(TTSService):
self._base_url = base_url
self._session = aiohttp_session
self._session_owner = aiohttp_session is None
self._encoding = encoding
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics."""
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a Language enum to xAI's language format."""
return str(language)
"""Convert a Language enum to xAI language format.
Args:
language: The language to convert.
Returns:
The xAI-specific language code, or None if not supported.
"""
return language_to_xai_language(language)
async def start(self, frame):
"""Start the xAI TTS service."""
@@ -175,7 +174,7 @@ class XAITTSService(TTSService):
"text": text,
"voice_id": self._settings.voice,
"output_format": {
"codec": self.XAI_PCM_CODEC,
"codec": self._encoding,
"sample_rate": self.sample_rate,
},
}
@@ -189,12 +188,11 @@ class XAITTSService(TTSService):
measuring_ttfb = True
try:
async with self._session.post(self._base_url, json=payload, headers=headers) as response:
async with self._session.post(
self._base_url, json=payload, headers=headers
) as response:
if response.status != 200:
error = await response.text(errors="ignore")
logger.error(
f"{self} error getting audio (status: {response.status}, error: {error})"
)
yield ErrorFrame(
error=f"Error getting audio (status: {response.status}, error: {error})"
)
@@ -208,6 +206,11 @@ class XAITTSService(TTSService):
if measuring_ttfb:
await self.stop_ttfb_metrics()
measuring_ttfb = False
yield TTSAudioRawFrame(chunk, self.sample_rate, 1, context_id=context_id)
yield TTSAudioRawFrame(
chunk,
self.sample_rate,
1,
context_id=context_id,
)
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")

View File

@@ -4,7 +4,7 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Tests for XAITTSService."""
"""Tests for XAIHttpTTSService."""
import asyncio
import unittest
@@ -21,7 +21,7 @@ from pipecat.frames.frames import (
TTSStoppedFrame,
TTSTextFrame,
)
from pipecat.services.xai.tts import XAITTSService
from pipecat.services.xai.tts import XAIHttpTTSService
from pipecat.tests.utils import run_test
@@ -52,7 +52,7 @@ async def test_run_xai_tts_success(aiohttp_client):
base_url = str(client.make_url("/v1/tts"))
async with aiohttp.ClientSession() as session:
tts_service = XAITTSService(
tts_service = XAIHttpTTSService(
api_key="test-key",
base_url=base_url,
aiohttp_session=session,