Merge pull request #4341 from pipecat-ai/mb/xai-tts
Add XAITTSService for xAI streaming WebSocket TTS
This commit is contained in:
1
changelog/4341.added.md
Normal file
1
changelog/4341.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `XAITTSService` for streaming text-to-speech using xAI's WebSocket TTS endpoint (`wss://api.x.ai/v1/tts`). Streams `text.delta` chunks up and base64 `audio.delta` chunks down on the same connection so audio begins flowing before the full utterance finishes synthesizing; complements the batch-HTTP `XAIHttpTTSService`. Defaults to raw PCM output so `TTSAudioRawFrame` needs no decoding. The `xai` optional extra now pulls in `pipecat-ai[websockets-base]`.
|
||||
128
examples/voice/voice-xai-http.py
Normal file
128
examples/voice/voice-xai-http.py
Normal file
@@ -0,0 +1,128 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.xai.llm import GrokLLMService
|
||||
from pipecat.services.xai.tts import XAIHttpTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = XAIHttpTTSService(
|
||||
api_key=os.getenv("XAI_API_KEY"),
|
||||
aiohttp_session=session,
|
||||
settings=XAIHttpTTSService.Settings(
|
||||
voice="eve",
|
||||
),
|
||||
)
|
||||
|
||||
llm = GrokLLMService(
|
||||
api_key=os.getenv("XAI_API_KEY"),
|
||||
settings=GrokLLMService.Settings(
|
||||
system_instruction="You are a helpful assistant in a voice conversation. Your responses will be spoken aloud, so avoid emojis, bullet points, or other formatting that can't be spoken. Respond to what the user said in a creative, helpful, and brief way.",
|
||||
),
|
||||
)
|
||||
|
||||
context = LLMContext()
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
user_aggregator, # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
context.add_message(
|
||||
{"role": "developer", "content": "Please introduce yourself to the user."}
|
||||
)
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -24,7 +24,7 @@ from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.xai.llm import GrokLLMService
|
||||
from pipecat.services.xai.stt import XAISTTService
|
||||
from pipecat.services.xai.tts import XAIHttpTTSService
|
||||
from pipecat.services.xai.tts import XAITTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -55,10 +55,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
stt = XAISTTService(api_key=os.environ["XAI_API_KEY"])
|
||||
|
||||
tts = XAIHttpTTSService(
|
||||
tts = XAITTSService(
|
||||
api_key=os.environ["XAI_API_KEY"],
|
||||
aiohttp_session=session,
|
||||
settings=XAIHttpTTSService.Settings(
|
||||
settings=XAITTSService.Settings(
|
||||
voice="eve",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -108,6 +108,7 @@ TESTS_VOICE = [
|
||||
("voice/voice-elevenlabs.py", EVAL_SIMPLE_MATH),
|
||||
("voice/voice-elevenlabs-http.py", EVAL_SIMPLE_MATH),
|
||||
("voice/voice-xai.py", EVAL_SIMPLE_MATH),
|
||||
("voice/voice-xai-http.py", EVAL_SIMPLE_MATH),
|
||||
("voice/voice-azure.py", EVAL_SIMPLE_MATH),
|
||||
("voice/voice-azure-http.py", EVAL_SIMPLE_MATH),
|
||||
("voice/voice-openai.py", EVAL_SIMPLE_MATH),
|
||||
|
||||
@@ -6,22 +6,48 @@
|
||||
|
||||
"""xAI text-to-speech service implementation.
|
||||
|
||||
Uses xAI's HTTP TTS endpoint documented at:
|
||||
https://docs.x.ai/developers/model-capabilities/audio/text-to-speech
|
||||
Provides two TTS services against xAI's voice API:
|
||||
|
||||
- :class:`XAIHttpTTSService` uses the batch HTTP endpoint at
|
||||
``https://api.x.ai/v1/tts``.
|
||||
- :class:`XAITTSService` uses the streaming WebSocket endpoint at
|
||||
``wss://api.x.ai/v1/tts``.
|
||||
|
||||
See https://docs.x.ai/developers/rest-api-reference/inference/voice.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, TTSAudioRawFrame
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import TTSSettings
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.services.tts_service import InterruptibleTTSService, TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use XAITTSService, you need to `pip install pipecat-ai[xai]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_xai_language(language: Language) -> str | None:
|
||||
"""Convert a Language enum to xAI language code.
|
||||
@@ -214,3 +240,249 @@ class XAIHttpTTSService(TTSService):
|
||||
)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class XAIWebsocketTTSSettings(TTSSettings):
|
||||
"""Settings for XAITTSService (WebSocket streaming)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class XAITTSService(InterruptibleTTSService):
|
||||
"""xAI streaming text-to-speech service.
|
||||
|
||||
Connects to xAI's WebSocket TTS endpoint and streams audio chunks back as
|
||||
they are synthesized. Text can be sent incrementally via ``text.delta``
|
||||
messages and each utterance is terminated with ``text.done``. The server
|
||||
responds with ``audio.delta`` chunks followed by an ``audio.done`` message.
|
||||
|
||||
Audio parameters (voice, language, codec, sample rate, bit rate) are passed
|
||||
as query string parameters on the WebSocket URL; changing any of them at
|
||||
runtime reconnects the WebSocket.
|
||||
"""
|
||||
|
||||
Settings = XAIWebsocketTTSSettings
|
||||
_settings: Settings
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "wss://api.x.ai/v1/tts",
|
||||
sample_rate: int | None = None,
|
||||
codec: str = "pcm",
|
||||
settings: Settings | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the xAI WebSocket TTS service.
|
||||
|
||||
Args:
|
||||
api_key: xAI API key for authentication.
|
||||
base_url: xAI TTS WebSocket endpoint. Defaults to
|
||||
``wss://api.x.ai/v1/tts``.
|
||||
sample_rate: Output audio sample rate in Hz. If None, uses the
|
||||
pipeline default.
|
||||
codec: Output audio codec. One of ``pcm``, ``wav``, ``mulaw``,
|
||||
``alaw``. Defaults to ``pcm`` so emitted ``TTSAudioRawFrame``
|
||||
objects need no decoding downstream.
|
||||
settings: Runtime-updatable settings.
|
||||
**kwargs: Additional arguments passed to parent
|
||||
``InterruptibleTTSService``.
|
||||
"""
|
||||
default_settings = self.Settings(
|
||||
model=None,
|
||||
voice="eve",
|
||||
language=Language.EN,
|
||||
)
|
||||
|
||||
if settings is not None:
|
||||
default_settings.apply_update(settings)
|
||||
|
||||
super().__init__(
|
||||
push_start_frame=True,
|
||||
push_stop_frames=True,
|
||||
sample_rate=sample_rate,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url
|
||||
self._codec = codec
|
||||
self._receive_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics."""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
"""Convert a Language enum to xAI language format."""
|
||||
return language_to_xai_language(language)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the xAI WebSocket TTS service."""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the xAI WebSocket TTS service."""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the xAI WebSocket TTS service."""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
await super()._connect()
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
await super()._disconnect()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]:
|
||||
"""Apply a settings delta. Reconnects if any URL-baked field changes."""
|
||||
changed = await super()._update_settings(delta)
|
||||
|
||||
if changed:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
return changed
|
||||
|
||||
def _build_url(self) -> str:
|
||||
language = self._settings.language
|
||||
if isinstance(language, Language):
|
||||
language_value = language_to_xai_language(language) or language.value
|
||||
else:
|
||||
language_value = str(language) if language is not None else "auto"
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"voice": self._settings.voice,
|
||||
"language": language_value,
|
||||
"codec": self._codec,
|
||||
"sample_rate": self.sample_rate,
|
||||
}
|
||||
return f"{self._base_url}?{urlencode(params)}"
|
||||
|
||||
async def _connect_websocket(self):
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
logger.debug("Connecting to xAI TTS")
|
||||
|
||||
url = self._build_url()
|
||||
headers = {"Authorization": f"Bearer {self._api_key}"}
|
||||
self._websocket = await websocket_connect(url, additional_headers=headers)
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from xAI TTS")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error disconnecting from xAI TTS: {e}", exception=e)
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def flush_audio(self, context_id: str | None = None):
|
||||
"""Signal end-of-utterance so xAI begins synthesizing what it has buffered."""
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
return
|
||||
await self._get_websocket().send(json.dumps({"type": "text.done"}))
|
||||
|
||||
async def _receive_messages(self):
|
||||
async for message in self._get_websocket():
|
||||
if isinstance(message, bytes):
|
||||
logger.warning(f"{self}: unexpected binary frame from xAI TTS")
|
||||
continue
|
||||
try:
|
||||
msg = json.loads(message)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"{self}: invalid JSON message: {message}")
|
||||
continue
|
||||
|
||||
msg_type = msg.get("type")
|
||||
context_id = self.get_active_audio_context_id()
|
||||
|
||||
if msg_type == "audio.delta":
|
||||
audio_b64 = msg.get("delta")
|
||||
if not audio_b64:
|
||||
continue
|
||||
audio = base64.b64decode(audio_b64)
|
||||
await self.stop_ttfb_metrics()
|
||||
if context_id:
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=audio,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=context_id,
|
||||
)
|
||||
await self.append_to_audio_context(context_id, frame)
|
||||
elif msg_type == "audio.done":
|
||||
await self.stop_all_metrics()
|
||||
if context_id:
|
||||
await self.append_to_audio_context(
|
||||
context_id, TTSStoppedFrame(context_id=context_id)
|
||||
)
|
||||
await self.remove_audio_context(context_id)
|
||||
elif msg_type == "error":
|
||||
await self.stop_all_metrics()
|
||||
error_detail = msg.get("message") or msg.get("error") or str(msg)
|
||||
if context_id:
|
||||
await self.append_to_audio_context(
|
||||
context_id, TTSStoppedFrame(context_id=context_id)
|
||||
)
|
||||
await self.remove_audio_context(context_id)
|
||||
await self.push_error(error_msg=f"xAI TTS error: {error_detail}")
|
||||
else:
|
||||
logger.debug(f"{self}: unhandled xAI message type: {msg_type}")
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate TTS audio from text using xAI's streaming WebSocket API."""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
await self._get_websocket().send(json.dumps({"type": "text.delta", "delta": text}))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
yield TTSStoppedFrame(context_id=context_id)
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
|
||||
@@ -4,14 +4,19 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Tests for XAIHttpTTSService."""
|
||||
"""Tests for XAIHttpTTSService and XAITTSService."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import unittest
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
import websockets
|
||||
from aiohttp import web
|
||||
from websockets.asyncio.server import serve
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
@@ -21,7 +26,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
TTSTextFrame,
|
||||
)
|
||||
from pipecat.services.xai.tts import XAIHttpTTSService
|
||||
from pipecat.services.xai.tts import XAIHttpTTSService, XAITTSService
|
||||
from pipecat.tests.utils import run_test
|
||||
|
||||
|
||||
@@ -87,5 +92,87 @@ async def test_run_xai_tts_success(aiohttp_client):
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_xai_websocket_tts_success():
|
||||
"""xAI WS TTS should send text.delta+text.done and emit frames from audio.delta+audio.done."""
|
||||
|
||||
captured: dict = {
|
||||
"request_path": None,
|
||||
"auth_header": None,
|
||||
"messages": [],
|
||||
}
|
||||
|
||||
audio_bytes = b"\x00\x01\x02\x03" * 1024
|
||||
|
||||
async def handler(ws):
|
||||
request = ws.request
|
||||
captured["request_path"] = request.path
|
||||
captured["auth_header"] = request.headers.get("Authorization")
|
||||
|
||||
try:
|
||||
async for raw in ws:
|
||||
msg = json.loads(raw)
|
||||
captured["messages"].append(msg)
|
||||
if msg.get("type") == "text.done":
|
||||
await ws.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "audio.delta",
|
||||
"delta": base64.b64encode(audio_bytes).decode("ascii"),
|
||||
}
|
||||
)
|
||||
)
|
||||
await ws.send(json.dumps({"type": "audio.done", "trace_id": "test-trace"}))
|
||||
except websockets.ConnectionClosed:
|
||||
pass
|
||||
|
||||
async with serve(handler, "127.0.0.1", 0) as server:
|
||||
host, port = next(iter(server.sockets)).getsockname()[:2]
|
||||
base_url = f"ws://{host}:{port}/v1/tts"
|
||||
|
||||
tts_service = XAITTSService(
|
||||
api_key="test-key",
|
||||
base_url=base_url,
|
||||
sample_rate=24000,
|
||||
)
|
||||
|
||||
down_frames, _ = await run_test(
|
||||
tts_service,
|
||||
frames_to_send=[TTSSpeakFrame(text="Hello from xAI."), _SleepAfterSpeak(0.3)],
|
||||
)
|
||||
|
||||
frame_types = [type(frame) for frame in down_frames]
|
||||
assert TTSStartedFrame in frame_types
|
||||
assert TTSAudioRawFrame in frame_types
|
||||
assert TTSStoppedFrame in frame_types
|
||||
|
||||
audio_frames = [frame for frame in down_frames if isinstance(frame, TTSAudioRawFrame)]
|
||||
assert audio_frames
|
||||
assert all(frame.sample_rate == 24000 for frame in audio_frames)
|
||||
assert all(frame.num_channels == 1 for frame in audio_frames)
|
||||
assert b"".join(f.audio for f in audio_frames) == audio_bytes
|
||||
|
||||
assert captured["auth_header"] == "Bearer test-key"
|
||||
parsed = urlparse(captured["request_path"])
|
||||
query = parse_qs(parsed.query)
|
||||
assert query["voice"] == ["eve"]
|
||||
assert query["language"] == ["en"]
|
||||
assert query["codec"] == ["pcm"]
|
||||
assert query["sample_rate"] == ["24000"]
|
||||
|
||||
types_sent = [m.get("type") for m in captured["messages"]]
|
||||
assert "text.delta" in types_sent
|
||||
assert "text.done" in types_sent
|
||||
delta_msg = next(m for m in captured["messages"] if m.get("type") == "text.delta")
|
||||
assert delta_msg["delta"] == "Hello from xAI."
|
||||
|
||||
|
||||
# Small helper imported lazily to avoid circular import in fixture-lite tests.
|
||||
def _SleepAfterSpeak(duration: float):
|
||||
from pipecat.tests.utils import SleepFrame
|
||||
|
||||
return SleepFrame(sleep=duration)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user