diff --git a/changelog/4031.added.md b/changelog/4031.added.md new file mode 100644 index 000000000..0c228a24a --- /dev/null +++ b/changelog/4031.added.md @@ -0,0 +1 @@ +- Added `XAIHttpTTSService` for text-to-speech using xAI's HTTP TTS API. diff --git a/examples/foundational/07e-interruptible-xai.py b/examples/foundational/07e-interruptible-xai.py new file mode 100644 index 000000000..762c27e7c --- /dev/null +++ b/examples/foundational/07e-interruptible-xai.py @@ -0,0 +1,129 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import os + +import aiohttp +from dotenv import load_dotenv +from loguru import logger + +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.frames.frames import LLMRunFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_response_universal import ( + LLMContextAggregatorPair, + LLMUserAggregatorParams, +) +from pipecat.runner.types import RunnerArguments +from pipecat.runner.utils import create_transport +from pipecat.services.deepgram.stt import DeepgramSTTService +from pipecat.services.grok.llm import GrokLLMService +from pipecat.services.xai.tts import XAIHttpTTSService +from pipecat.transports.base_transport import BaseTransport, TransportParams +from pipecat.transports.daily.transport import DailyParams +from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams + +load_dotenv(override=True) + +# We use lambdas to defer transport parameter creation until the transport +# type is selected at runtime. +transport_params = { + "daily": lambda: DailyParams( + audio_in_enabled=True, + audio_out_enabled=True, + ), + "twilio": lambda: FastAPIWebsocketParams( + audio_in_enabled=True, + audio_out_enabled=True, + ), + "webrtc": lambda: TransportParams( + audio_in_enabled=True, + audio_out_enabled=True, + ), +} + + +async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): + logger.info(f"Starting bot") + + async with aiohttp.ClientSession() as session: + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + + tts = XAIHttpTTSService( + api_key=os.getenv("GROK_API_KEY"), + aiohttp_session=session, + settings=XAIHttpTTSService.Settings( + voice="eve", + ), + ) + + llm = GrokLLMService( + api_key=os.getenv("GROK_API_KEY"), + settings=GrokLLMService.Settings( + system_instruction="You are a helpful assistant in a voice conversation. Your responses will be spoken aloud, so avoid emojis, bullet points, or other formatting that can't be spoken. Respond to what the user said in a creative, helpful, and brief way.", + ), + ) + + context = LLMContext() + user_aggregator, assistant_aggregator = LLMContextAggregatorPair( + context, + user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()), + ) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + stt, + user_aggregator, # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output + assistant_aggregator, # Assistant spoken responses + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + enable_metrics=True, + enable_usage_metrics=True, + audio_out_sample_rate=8000, + ), + idle_timeout_secs=runner_args.pipeline_idle_timeout_secs, + ) + + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info(f"Client connected") + # Kick off the conversation. + context.add_message( + {"role": "user", "content": "Please introduce yourself to the user."} + ) + await task.queue_frames([LLMRunFrame()]) + + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info(f"Client disconnected") + await task.cancel() + + runner = PipelineRunner(handle_sigint=runner_args.handle_sigint) + + await runner.run(task) + + +async def bot(runner_args: RunnerArguments): + """Main bot entry point compatible with Pipecat Cloud.""" + transport = await create_transport(runner_args, transport_params) + await run_bot(transport, runner_args) + + +if __name__ == "__main__": + from pipecat.runner.run import main + + main() diff --git a/examples/foundational/14g-function-calling-grok.py b/examples/foundational/14g-function-calling-grok.py index 4de1f6528..34ec82b5f 100644 --- a/examples/foundational/14g-function-calling-grok.py +++ b/examples/foundational/14g-function-calling-grok.py @@ -7,6 +7,7 @@ import os +import aiohttp from dotenv import load_dotenv from loguru import logger @@ -24,10 +25,10 @@ from pipecat.processors.aggregators.llm_response_universal import ( ) from pipecat.runner.types import RunnerArguments from pipecat.runner.utils import create_transport -from pipecat.services.cartesia.tts import CartesiaTTSService from pipecat.services.deepgram.stt import DeepgramSTTService from pipecat.services.grok.llm import GrokLLMService from pipecat.services.llm_service import FunctionCallParams +from pipecat.services.xai.tts import XAIHttpTTSService from pipecat.transports.base_transport import BaseTransport, TransportParams from pipecat.transports.daily.transport import DailyParams from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams @@ -60,83 +61,88 @@ transport_params = { async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): logger.info(f"Starting bot") - stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + async with aiohttp.ClientSession() as session: + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) - tts = CartesiaTTSService( - api_key=os.getenv("CARTESIA_API_KEY"), - settings=CartesiaTTSService.Settings( - voice="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady - ), - ) + tts = XAIHttpTTSService( + api_key=os.getenv("GROK_API_KEY"), + aiohttp_session=session, + settings=XAIHttpTTSService.Settings( + voice="eve", + ), + ) - llm = GrokLLMService( - api_key=os.getenv("GROK_API_KEY"), - settings=GrokLLMService.Settings( - system_instruction="You are a helpful assistant in a voice conversation. Your responses will be spoken aloud, so avoid emojis, bullet points, or other formatting that can't be spoken. Respond to what the user said in a creative, helpful, and brief way.", - ), - ) - # You can also register a function_name of None to get all functions - # sent to the same callback with an additional function_name parameter. - llm.register_function("get_current_weather", fetch_weather_from_api) + llm = GrokLLMService( + api_key=os.getenv("GROK_API_KEY"), + settings=GrokLLMService.Settings( + system_instruction="You are a helpful assistant in a voice conversation. Your responses will be spoken aloud, so avoid emojis, bullet points, or other formatting that can't be spoken. Respond to what the user said in a creative, helpful, and brief way.", + ), + ) + # You can also register a function_name of None to get all functions + # sent to the same callback with an additional function_name parameter. + llm.register_function("get_current_weather", fetch_weather_from_api) - weather_function = FunctionSchema( - name="get_current_weather", - description="Get the current weather", - properties={ - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", + weather_function = FunctionSchema( + name="get_current_weather", + description="Get the current weather", + properties={ + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the user's location.", + }, }, - "format": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "The temperature unit to use. Infer this from the user's location.", - }, - }, - required=["location", "format"], - ) - tools = ToolsSchema(standard_tools=[weather_function]) - context = LLMContext(tools=tools) - user_aggregator, assistant_aggregator = LLMContextAggregatorPair( - context, - user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()), - ) + required=["location", "format"], + ) + tools = ToolsSchema(standard_tools=[weather_function]) + context = LLMContext(tools=tools) + user_aggregator, assistant_aggregator = LLMContextAggregatorPair( + context, + user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()), + ) - pipeline = Pipeline( - [ - transport.input(), - stt, - user_aggregator, - llm, - tts, - transport.output(), - assistant_aggregator, - ] - ) + pipeline = Pipeline( + [ + transport.input(), + stt, + user_aggregator, + llm, + tts, + transport.output(), + assistant_aggregator, + ] + ) - task = PipelineTask( - pipeline, - params=PipelineParams( - enable_metrics=True, - enable_usage_metrics=True, - ), - idle_timeout_secs=runner_args.pipeline_idle_timeout_secs, - ) + task = PipelineTask( + pipeline, + params=PipelineParams( + enable_metrics=True, + enable_usage_metrics=True, + ), + idle_timeout_secs=runner_args.pipeline_idle_timeout_secs, + ) - @transport.event_handler("on_client_connected") - async def on_client_connected(transport, client): - logger.info(f"Client connected") - # Kick off the conversation. - await task.queue_frames([LLMRunFrame()]) + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info(f"Client connected") + # Kick off the conversation. + context.add_message( + {"role": "user", "content": "Please introduce yourself to the user."} + ) + await task.queue_frames([LLMRunFrame()]) - @transport.event_handler("on_client_disconnected") - async def on_client_disconnected(transport, client): - logger.info(f"Client disconnected") - await task.cancel() + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info(f"Client disconnected") + await task.cancel() - runner = PipelineRunner(handle_sigint=runner_args.handle_sigint) + runner = PipelineRunner(handle_sigint=runner_args.handle_sigint) - await runner.run(task) + await runner.run(task) async def bot(runner_args: RunnerArguments): diff --git a/src/pipecat/services/xai/tts.py b/src/pipecat/services/xai/tts.py index 37a3db80a..a7026b2ae 100644 --- a/src/pipecat/services/xai/tts.py +++ b/src/pipecat/services/xai/tts.py @@ -15,57 +15,77 @@ from typing import AsyncGenerator, Optional import aiohttp from loguru import logger -from pydantic import BaseModel from pipecat.frames.frames import ErrorFrame, Frame, TTSAudioRawFrame from pipecat.services.settings import TTSSettings from pipecat.services.tts_service import TTSService -from pipecat.transcriptions.language import Language +from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.tracing.service_decorators import traced_tts +def language_to_xai_language(language: Language) -> Optional[str]: + """Convert a Language enum to xAI language code. + + Args: + language: The Language enum value to convert. + + Returns: + The corresponding xAI language code, or None if not supported. + """ + LANGUAGE_MAP = { + Language.AR: "ar-EG", + Language.AR_EG: "ar-EG", + Language.AR_SA: "ar-SA", + Language.AR_AE: "ar-AE", + Language.BN: "bn", + Language.DE: "de", + Language.EN: "en", + Language.ES: "es-ES", + Language.ES_ES: "es-ES", + Language.ES_MX: "es-MX", + Language.FR: "fr", + Language.HI: "hi", + Language.ID: "id", + Language.IT: "it", + Language.JA: "ja", + Language.KO: "ko", + Language.PT: "pt-PT", + Language.PT_BR: "pt-BR", + Language.PT_PT: "pt-PT", + Language.RU: "ru", + Language.TR: "tr", + Language.VI: "vi", + Language.ZH: "zh", + } + + return resolve_language(language, LANGUAGE_MAP, use_base_code=True) + + @dataclass -class XAITTSSettings(TTSSettings): - """Settings for XAITTSService.""" +class XAIHttpTTSSettings(TTSSettings): + """Settings for XAIHttpTTSService.""" pass -class XAITTSService(TTSService): +class XAIHttpTTSService(TTSService): """xAI HTTP text-to-speech service. The service requests raw PCM audio so emitted ``TTSAudioRawFrame`` objects match Pipecat's downstream expectations without extra decoding. """ - Settings = XAITTSSettings + Settings = XAIHttpTTSSettings _settings: Settings - XAI_DEFAULT_SAMPLE_RATE = 24000 - XAI_PCM_CODEC = "pcm" - - class InputParams(BaseModel): - """Input parameters for xAI TTS configuration. - - .. deprecated:: 0.0.105 - Use ``settings=XAITTSService.Settings(...)`` instead. - - Parameters: - language: Language for speech synthesis. - """ - - language: Optional[Language] = None - def __init__( self, *, api_key: str, base_url: str = "https://api.x.ai/v1/tts", - voice: Optional[str] = None, - language: Optional[str | Language] = None, sample_rate: Optional[int] = None, + encoding: Optional[str] = "pcm", aiohttp_session: Optional[aiohttp.ClientSession] = None, - params: Optional[InputParams] = None, settings: Optional[Settings] = None, **kwargs, ): @@ -74,54 +94,25 @@ class XAITTSService(TTSService): Args: api_key: xAI API key for authentication. base_url: xAI TTS endpoint. Defaults to ``https://api.x.ai/v1/tts``. - voice: Voice identifier. Defaults to ``"eve"``. - - .. deprecated:: 0.0.105 - Use ``settings=XAITTSService.Settings(voice=...)`` instead. - - language: BCP-47 or base language code (for example ``"en"`` or ``"pt-BR"``). - Defaults to ``"en"``. - - .. deprecated:: 0.0.105 - Use ``settings=XAITTSService.Settings(language=...)`` instead. - - sample_rate: Output sample rate for PCM audio. Defaults to 24000 Hz. + sample_rate: Audio sample rate. If None, uses default. + encoding: Output encoding format. Defaults to "pcm". aiohttp_session: Optional shared aiohttp session. - params: Deprecated input parameters object. - settings: Runtime-updatable settings. When provided alongside deprecated - parameters, ``settings`` values take precedence. + settings: Runtime-updatable settings. **kwargs: Additional keyword arguments passed to ``TTSService``. """ default_settings = self.Settings( model=None, voice="eve", - language="en", + language=Language.EN, ) - if voice is not None: - self._warn_init_param_moved_to_settings("voice", "voice") - default_settings.voice = voice - if language is not None: - self._warn_init_param_moved_to_settings("language", "language") - default_settings.language = ( - self.language_to_service_language(language) - if isinstance(language, Language) - else language - ) - - if params is not None: - self._warn_init_param_moved_to_settings("params") - if not settings and params.language is not None: - default_settings.language = self.language_to_service_language(params.language) - if settings is not None: default_settings.apply_update(settings) super().__init__( - pause_frame_processing=True, + sample_rate=sample_rate, push_start_frame=True, push_stop_frames=True, - sample_rate=sample_rate or self.XAI_DEFAULT_SAMPLE_RATE, settings=default_settings, **kwargs, ) @@ -130,14 +121,22 @@ class XAITTSService(TTSService): self._base_url = base_url self._session = aiohttp_session self._session_owner = aiohttp_session is None + self._encoding = encoding def can_generate_metrics(self) -> bool: """Check if this service can generate processing metrics.""" return True def language_to_service_language(self, language: Language) -> Optional[str]: - """Convert a Language enum to xAI's language format.""" - return str(language) + """Convert a Language enum to xAI language format. + + Args: + language: The language to convert. + + Returns: + The xAI-specific language code, or None if not supported. + """ + return language_to_xai_language(language) async def start(self, frame): """Start the xAI TTS service.""" @@ -175,7 +174,7 @@ class XAITTSService(TTSService): "text": text, "voice_id": self._settings.voice, "output_format": { - "codec": self.XAI_PCM_CODEC, + "codec": self._encoding, "sample_rate": self.sample_rate, }, } @@ -189,12 +188,11 @@ class XAITTSService(TTSService): measuring_ttfb = True try: - async with self._session.post(self._base_url, json=payload, headers=headers) as response: + async with self._session.post( + self._base_url, json=payload, headers=headers + ) as response: if response.status != 200: error = await response.text(errors="ignore") - logger.error( - f"{self} error getting audio (status: {response.status}, error: {error})" - ) yield ErrorFrame( error=f"Error getting audio (status: {response.status}, error: {error})" ) @@ -208,6 +206,11 @@ class XAITTSService(TTSService): if measuring_ttfb: await self.stop_ttfb_metrics() measuring_ttfb = False - yield TTSAudioRawFrame(chunk, self.sample_rate, 1, context_id=context_id) + yield TTSAudioRawFrame( + chunk, + self.sample_rate, + 1, + context_id=context_id, + ) except Exception as e: yield ErrorFrame(error=f"Unknown error occurred: {e}") diff --git a/tests/test_xai_tts.py b/tests/test_xai_tts.py index b4c1513f6..aab984567 100644 --- a/tests/test_xai_tts.py +++ b/tests/test_xai_tts.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD 2-Clause License # -"""Tests for XAITTSService.""" +"""Tests for XAIHttpTTSService.""" import asyncio import unittest @@ -21,7 +21,7 @@ from pipecat.frames.frames import ( TTSStoppedFrame, TTSTextFrame, ) -from pipecat.services.xai.tts import XAITTSService +from pipecat.services.xai.tts import XAIHttpTTSService from pipecat.tests.utils import run_test @@ -52,7 +52,7 @@ async def test_run_xai_tts_success(aiohttp_client): base_url = str(client.make_url("/v1/tts")) async with aiohttp.ClientSession() as session: - tts_service = XAITTSService( + tts_service = XAIHttpTTSService( api_key="test-key", base_url=base_url, aiohttp_session=session,