Merge pull request #3171 from LaurentMazare/gradium

Gradium integration.
This commit is contained in:
Mark Backman
2025-12-05 09:43:44 -05:00
committed by GitHub
8 changed files with 696 additions and 3 deletions

View File

@@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
ensures all function calls complete before returning results to the LLM (i.e.,
before running a new inference with those results).
- Added new Gradium services, `GradiumSTTService` and `GradiumTTSService`, for
speech-to-text and text-to-speech functionality using Gradium's API.
### Changed
- If an unexpected exception is caught, or if `FrameProcessor.push_error()` is

View File

@@ -74,9 +74,9 @@ Catch new features, interviews, and how-tos on our [Pipecat TV](https://www.yout
| Category | Services |
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/stt/elevenlabs), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Sarvam](https://docs.pipecat.ai/server/services/stt/sarvam), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/stt/elevenlabs), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Gradium](https://docs.pipecat.ai/server/services/stt/gradium), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Sarvam](https://docs.pipecat.ai/server/services/stt/sarvam), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [Mistral](https://docs.pipecat.ai/server/services/llm/mistral), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [SambaNova](https://docs.pipecat.ai/server/services/llm/sambanova), [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
| Text-to-Speech | [Async](https://docs.pipecat.ai/server/services/tts/asyncai), [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [Groq](https://docs.pipecat.ai/server/services/tts/groq), [Hume](https://docs.pipecat.ai/server/services/tts/hume), [Inworld](https://docs.pipecat.ai/server/services/tts/inworld), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [Speechmatics](https://docs.pipecat.ai/server/services/tts/speechmatics), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
| Text-to-Speech | [Async](https://docs.pipecat.ai/server/services/tts/asyncai), [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [Gradium](https://docs.pipecat.ai/server/services/tts/gradium), [Groq](https://docs.pipecat.ai/server/services/tts/groq), [Hume](https://docs.pipecat.ai/server/services/tts/hume), [Inworld](https://docs.pipecat.ai/server/services/tts/inworld), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [Speechmatics](https://docs.pipecat.ai/server/services/tts/speechmatics), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
| Speech-to-Speech | [AWS Nova Sonic](https://docs.pipecat.ai/server/services/s2s/aws), [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
| Serializers | [Plivo](https://docs.pipecat.ai/server/utilities/serializers/plivo), [Twilio](https://docs.pipecat.ai/server/utilities/serializers/twilio), [Telnyx](https://docs.pipecat.ai/server/utilities/serializers/telnyx) |

View File

@@ -73,6 +73,9 @@ GOOGLE_CLOUD_PROJECT_ID=...
GOOGLE_CLOUD_LOCATION=...
GOOGLE_TEST_CREDENTIALS=...
# Gradium
GRADIUM_API_KEY=...
# Grok
GROK_API_KEY=...
@@ -191,4 +194,4 @@ TWILIO_AUTH_TOKEN=...
WHATSAPP_TOKEN=...
WHATSAPP_WEBHOOK_VERIFICATION_TOKEN=...
WHATSAPP_PHONE_NUMBER_ID=...
WHATSAPP_APP_SECRET=...
WHATSAPP_APP_SECRET=...

View File

@@ -0,0 +1,127 @@
#
# Copyright (c) 2024-2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import os
from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.gradium.stt import GradiumSTTService
from pipecat.services.gradium.tts import GradiumTTSService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
# Load variables from a .env file before the os.getenv() lookups further down.
# override=True is passed so that .env values take precedence over variables
# already set in the process environment (python-dotenv's documented behaviour).
load_dotenv(override=True)
# Every transport is configured with the same audio, VAD and turn-detection
# options, so those keyword arguments are produced in one place.
def _shared_transport_kwargs():
    """Return freshly built keyword arguments common to all transport params."""
    return {
        "audio_in_enabled": True,
        "audio_out_enabled": True,
        "vad_analyzer": SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
        "turn_analyzer": LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
    }


# Values are zero-argument factories rather than instances so that objects
# such as SileroVADAnalyzer are only created once the desired transport has
# been selected and its factory is called.
transport_params = {
    "daily": lambda: DailyParams(**_shared_transport_kwargs()),
    "twilio": lambda: FastAPIWebsocketParams(**_shared_transport_kwargs()),
    "webrtc": lambda: TransportParams(**_shared_transport_kwargs()),
}
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    """Build and run the Gradium STT -> OpenAI LLM -> Gradium TTS pipeline.

    Args:
        transport: Transport whose ``input()`` and ``output()`` processors are
            placed at the start and near the end of the pipeline.
        runner_args: Runner arguments; ``pipeline_idle_timeout_secs`` and
            ``handle_sigint`` are read from it.
    """
    logger.info("Starting bot")

    # Both Gradium services authenticate with the same key.
    gradium_api_key = os.getenv("GRADIUM_API_KEY")

    stt = GradiumSTTService(api_key=gradium_api_key)

    tts = GradiumTTSService(
        api_key=gradium_api_key,
        voice_id="YTpq7expH9539ERJ",
    )

    llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))

    messages = [
        {
            "role": "system",
            "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
        },
    ]

    context = LLMContext(messages)
    context_aggregator = LLMContextAggregatorPair(context)

    pipeline = Pipeline(
        [
            transport.input(),  # Transport user input
            stt,
            context_aggregator.user(),  # User responses
            llm,  # LLM
            tts,  # TTS
            transport.output(),  # Transport bot output
            context_aggregator.assistant(),  # Assistant spoken responses
        ]
    )

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            enable_metrics=True,
            enable_usage_metrics=True,
        ),
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected")
        # Kick off the conversation.
        # NOTE(review): this appends to the list handed to LLMContext above;
        # presumably the context keeps a reference to that list — confirm.
        messages.append({"role": "system", "content": "Please introduce yourself to the user."})
        await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")
        await task.cancel()

    runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)

    await runner.run(task)
async def bot(runner_args: RunnerArguments):
    """Main bot entry point compatible with Pipecat Cloud.

    Creates the transport selected by ``runner_args`` from ``transport_params``
    and then runs the bot on it.
    """
    selected_transport = await create_transport(runner_args, transport_params)
    return await run_bot(selected_transport, runner_args)
if __name__ == "__main__":
    # Imported here so the runner module is only loaded when this file is
    # executed as a script.
    from pipecat.runner.run import main

    main()

View File

@@ -63,6 +63,7 @@ fireworks = []
fish = [ "ormsgpack~=1.7.0", "pipecat-ai[websockets-base]" ]
gladia = [ "pipecat-ai[websockets-base]" ]
google = [ "google-cloud-speech>=2.33.0,<3", "google-cloud-texttospeech>=2.31.0,<3", "google-genai>=1.41.0,<2", "pipecat-ai[websockets-base]" ]
gradium = [ "pipecat-ai[websockets-base]" ]
grok = []
groq = [ "groq~=0.23.0" ]
gstreamer = [ "pygobject~=3.50.0" ]

View File

@@ -0,0 +1,5 @@
#
# Copyright (c) 2024-2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

View File

@@ -0,0 +1,239 @@
#
# Copyright (c) 2024-2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Gradium's speech-to-text service implementation.
This module provides integration with Gradium's real-time speech-to-text
WebSocket API for streaming audio transcription.
"""
import base64
import json
from typing import AsyncGenerator
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
StartFrame,
TranscriptionFrame,
)
from pipecat.services.stt_service import WebsocketSTTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
# The websockets package is an optional dependency; log an install hint and
# abort the import of this module when it is missing.
try:
    from websockets.asyncio.client import connect as websocket_connect
    from websockets.protocol import State
except ModuleNotFoundError as missing:
    logger.error(f"Exception: {missing}")
    logger.error('In order to use Gradium, you need to `pip install "pipecat-ai[gradium]"`.')
    raise Exception(f"Missing module: {missing}")

# Sample rate this module passes to the parent STT service constructor.
SAMPLE_RATE = 24000
class GradiumSTTService(WebsocketSTTService):
    """Gradium real-time speech-to-text service.

    Buffers incoming audio, sends it to Gradium's WebSocket API in fixed-size
    base64-encoded chunks, and pushes a ``TranscriptionFrame`` for every
    ``text`` message the server returns.
    """

    def __init__(
        self,
        *,
        api_key: str,
        api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr",
        json_config: str | None = None,
        **kwargs,
    ):
        """Initialize the Gradium STT service.

        Args:
            api_key: Gradium API key for authentication.
            api_endpoint_base_url: WebSocket endpoint URL. Defaults to Gradium's streaming endpoint.
            json_config: Optional JSON configuration string for additional model settings.
            **kwargs: Additional arguments passed to parent STTService class.
        """
        super().__init__(sample_rate=SAMPLE_RATE, **kwargs)

        self._api_key = api_key
        self._api_endpoint_base_url = api_endpoint_base_url
        self._websocket = None
        self._json_config = json_config
        self._receive_task = None

        # Audio is accumulated here and sent in chunks of `_chunk_size_ms`.
        self._audio_buffer = bytearray()
        self._chunk_size_ms = 80
        # Computed in start(), once the effective sample rate is known.
        self._chunk_size_bytes = 0

    def can_generate_metrics(self) -> bool:
        """Check if the service can generate metrics.

        Returns:
            True if metrics generation is supported.
        """
        return True

    async def start(self, frame: StartFrame):
        """Start the speech-to-text service.

        Args:
            frame: Start frame to begin processing.
        """
        await super().start(frame)
        # The factor 2 is bytes per sample — presumably 16-bit mono PCM; TODO confirm.
        self._chunk_size_bytes = int(self._chunk_size_ms * self.sample_rate * 2 / 1000)
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the speech-to-text service.

        Args:
            frame: End frame to stop processing.
        """
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the speech-to-text service.

        Args:
            frame: Cancel frame to abort processing.
        """
        await super().cancel(frame)
        await self._disconnect()

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Process audio data for speech-to-text conversion.

        The audio is appended to an internal buffer and every complete chunk is
        sent to the server. Chunks are dropped if the websocket is not open.

        Args:
            audio: Raw audio bytes to process.

        Yields:
            None (processing handled via WebSocket messages).
        """
        self._audio_buffer.extend(audio)

        # NOTE(review): these metrics are started on every call but no matching
        # stop call is visible in this class — confirm where they are stopped.
        await self.start_ttfb_metrics()
        await self.start_processing_metrics()

        chunk_size = self._chunk_size_bytes
        # The chunk size is 0 until start() has run; without this guard the
        # loop below would never terminate because the buffer never shrinks.
        if chunk_size > 0:
            while len(self._audio_buffer) >= chunk_size:
                chunk = bytes(self._audio_buffer[:chunk_size])
                # Trim in place instead of copying the remainder each time.
                del self._audio_buffer[:chunk_size]
                encoded = base64.b64encode(chunk).decode("utf-8")
                msg = {"type": "audio", "audio": encoded}
                if self._websocket and self._websocket.state is State.OPEN:
                    await self._websocket.send(json.dumps(msg))

        yield None

    @traced_stt
    async def _trace_transcription(self, transcript: str, is_final: bool, language: Language):
        """Record transcription event for tracing."""
        # NOTE(review): nothing in this class calls this method.
        pass

    async def _connect(self):
        """Open the websocket and start the receive task if it is not running."""
        await self._connect_websocket()

        if self._websocket and not self._receive_task:
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _connect_websocket(self):
        """Connect to Gradium, send the setup message and wait for ``ready``.

        Does nothing when the current websocket is already open. The
        ``on_connected`` event handler is called only after the server has
        answered with ``ready``.

        Raises:
            Exception: If connecting fails or the first server message is not
                ``ready``. The error is also reported through ``push_error``
                and the failed socket is closed and discarded.
        """
        try:
            if self._websocket and self._websocket.state is State.OPEN:
                return

            headers = {
                "x-api-key": self._api_key,
                "x-api-source": "pipecat",
            }
            self._websocket = await websocket_connect(
                self._api_endpoint_base_url,
                additional_headers=headers,
            )

            setup_msg = {
                "type": "setup",
                "input_format": "pcm",
            }
            if self._json_config is not None:
                setup_msg["json_config"] = self._json_config
            await self._websocket.send(json.dumps(setup_msg))

            ready_msg = json.loads(await self._websocket.recv())
            ready_type = ready_msg.get("type")
            if ready_type == "error":
                raise Exception(f"received error {ready_msg.get('message')}")
            if ready_type != "ready":
                raise Exception(f"unexpected first message type {ready_type}")

            await self._call_event_handler("on_connected")
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
            # Do not keep a socket that never completed the handshake.
            failed_ws, self._websocket = self._websocket, None
            if failed_ws is not None:
                try:
                    await failed_ws.close()
                except Exception as close_error:
                    logger.debug(f"{self}: error closing failed websocket: {close_error}")
            raise

    async def _disconnect(self):
        """Cancel the receive task and close the websocket."""
        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None

        await self._disconnect_websocket()

    async def _disconnect_websocket(self):
        """Close the websocket if open, then clear it and fire ``on_disconnected``."""
        try:
            if self._websocket and self._websocket.state is State.OPEN:
                logger.debug("Disconnecting from Gradium STT")
                await self._websocket.close()
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
        finally:
            self._websocket = None
            await self._call_event_handler("on_disconnected")

    def _get_websocket(self):
        """Return the current websocket.

        Raises:
            Exception: If no websocket is set.
        """
        if self._websocket:
            return self._websocket
        raise Exception("Websocket not connected")

    async def _process_messages(self):
        """Read messages until the connection ends, dispatching each JSON payload."""
        async for message in self._get_websocket():
            try:
                data = json.loads(message)
                await self._process_response(data)
            except json.JSONDecodeError:
                logger.warning(f"Received non-JSON message: {message}")

    async def _receive_messages(self):
        """Process messages forever, reconnecting whenever the stream ends."""
        while True:
            await self._process_messages()
            logger.debug(f"{self} Gradium connection was disconnected (timeout?), reconnecting")
            await self._connect_websocket()

    async def _process_response(self, msg):
        """Dispatch one decoded server message by its ``type`` field.

        Messages with any other type are ignored.
        """
        type_ = msg.get("type", "")
        if type_ == "text":
            await self._handle_text(msg["text"])
        elif type_ == "end_of_stream":
            await self._handle_end_of_stream()
        elif type_ == "error":
            await self.push_error(error_msg=f"Error: {msg}")

    async def _handle_end_of_stream(self):
        """Handle termination message."""
        logger.debug("Received end_of_stream message from server")

    async def _handle_text(self, text: str):
        """Push a ``TranscriptionFrame`` for the received text."""
        await self.push_frame(
            TranscriptionFrame(
                text,
                self._user_id,
                time_now_iso8601(),
            )
        )

View File

@@ -0,0 +1,315 @@
# Copyright (c) 2024-2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
"""Gradium Text-to-Speech service implementation."""
import base64
import json
import uuid
from typing import Any, AsyncGenerator, Mapping, Optional
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import InterruptibleWordTTSService
from pipecat.utils.tracing.service_decorators import traced_tts
# The websockets package is an optional dependency; log an install hint and
# abort the import of this module when it is missing.
try:
    from websockets import ConnectionClosedOK
    from websockets.asyncio.client import connect as websocket_connect
    from websockets.protocol import State
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    # The extra is quoted so the command also works in shells that treat
    # square brackets as glob patterns; this matches the STT module's hint.
    logger.error('In order to use Gradium, you need to `pip install "pipecat-ai[gradium]"`.')
    raise Exception(f"Missing module: {e}")

# Sample rate this module passes to the parent TTS service constructor.
SAMPLE_RATE = 48000
class GradiumTTSService(InterruptibleWordTTSService):
    """Text-to-Speech service using Gradium's websocket API.

    Sends text over a websocket and pushes the base64-decoded audio returned by
    the server as ``TTSAudioRawFrame`` objects, together with word timestamps
    taken from the server's ``text`` messages.
    """

    class InputParams(BaseModel):
        """Configuration parameters for Gradium TTS service.

        Parameters:
            temp: Temperature to be used for generation, defaults to 0.6.
        """

        temp: Optional[float] = 0.6

    def __init__(
        self,
        *,
        api_key: str,
        voice_id: str = "YTpq7expH9539ERJ",
        url: str = "wss://eu.api.gradium.ai/api/speech/tts",
        model: str = "default",
        json_config: Optional[str] = None,
        params: Optional[InputParams] = None,
        **kwargs,
    ):
        """Initialize the Gradium TTS service.

        Args:
            api_key: Gradium API key for authentication.
            voice_id: the voice identifier.
            url: Gradium websocket API endpoint.
            model: Model ID to use for synthesis.
            json_config: Optional JSON configuration string for additional model settings.
            params: Additional configuration parameters.
            **kwargs: Additional arguments passed to parent class.
        """
        # Initialize with parent class settings for proper frame handling
        super().__init__(
            push_stop_frames=True,
            pause_frame_processing=True,
            sample_rate=SAMPLE_RATE,
            **kwargs,
        )

        # NOTE(review): `params` (and therefore `temp`) is not read anywhere
        # else in this class — confirm whether it should be sent to the server.
        params = params or GradiumTTSService.InputParams()

        # Store service configuration
        self._api_key = api_key
        self._url = url
        self._voice_id = voice_id
        self._json_config = json_config
        self._model = model
        # NOTE(review): the setup message built in _connect_websocket() does
        # not read this dict; only `voice_id` is kept in sync in _update_settings().
        self._settings = {
            "voice_id": voice_id,
            "model_name": model,
            "output_format": "pcm",
        }

        # State tracking
        self._receive_task = None

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics.

        Returns:
            True, as Gradium service supports metrics generation.
        """
        return True

    async def set_model(self, model: str):
        """Update the TTS model.

        Args:
            model: The model name to use for synthesis.
        """
        self._model = model
        await super().set_model(model)

    async def _update_settings(self, settings: Mapping[str, Any]):
        """Update service settings and reconnect if voice changed."""
        prev_voice = self._voice_id
        await super()._update_settings(settings)
        if prev_voice != self._voice_id:
            self._settings["voice_id"] = self._voice_id
            logger.info(f"Switching TTS voice to: [{self._voice_id}]")
            # The voice is sent in the setup message, so a new connection is needed.
            await self._disconnect()
            await self._connect()

    def _build_msg(self, text: str = "") -> dict:
        """Build JSON message for Gradium API."""
        return {"text": text, "type": "text"}

    async def start(self, frame: StartFrame):
        """Start the service and establish websocket connection.

        Args:
            frame: The start frame containing initialization parameters.
        """
        await super().start(frame)
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the service and close connection.

        Args:
            frame: The end frame.
        """
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel current operation and clean up.

        Args:
            frame: The cancel frame.
        """
        await super().cancel(frame)
        await self._disconnect()

    async def _connect(self):
        """Establish websocket connection and start receive task."""
        logger.debug(f"{self}: connecting")

        # If the server disconnected, cancel the receive-task so that it can be reset below.
        # NOTE(review): `self._websocket` is not assigned in __init__ here;
        # presumably a parent class initialises it — confirm.
        if self._websocket is None or self._websocket.state is not State.OPEN:
            if self._receive_task:
                await self.cancel_task(self._receive_task)
                self._receive_task = None

        await self._connect_websocket()

        if self._websocket and not self._receive_task:
            logger.debug(f"{self}: setting receive task")
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _disconnect(self):
        """Close websocket connection and clean up tasks."""
        logger.debug(f"{self}: disconnecting")
        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None

        await self._disconnect_websocket()

    async def _connect_websocket(self):
        """Connect to Gradium websocket API with configured settings.

        Does nothing when the current websocket is already open. On failure the
        error is reported through ``push_error``, the socket is closed and
        cleared, and the ``on_connection_error`` handler is called; the
        exception is not re-raised.
        """
        try:
            if self._websocket and self._websocket.state is State.OPEN:
                return

            headers = {"x-api-key": self._api_key, "x-api-source": "pipecat"}
            self._websocket = await websocket_connect(self._url, additional_headers=headers)

            setup_msg = {
                "type": "setup",
                "output_format": "pcm",
                "voice_id": self._voice_id,
            }
            if self._json_config is not None:
                setup_msg["json_config"] = self._json_config
            await self._websocket.send(json.dumps(setup_msg))

            ready_msg = json.loads(await self._websocket.recv())
            ready_type = ready_msg.get("type")
            if ready_type == "error":
                raise Exception(f"received error {ready_msg.get('message')}")
            if ready_type != "ready":
                raise Exception(f"unexpected first message type {ready_type}")

            await self._call_event_handler("on_connected")
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
            # Close the socket before dropping the reference so a connection
            # that failed during the handshake is not left open.
            failed_ws, self._websocket = self._websocket, None
            if failed_ws is not None:
                try:
                    await failed_ws.close()
                except Exception as close_error:
                    logger.debug(f"{self}: error closing failed websocket: {close_error}")
            await self._call_event_handler("on_connection_error", f"{e}")

    async def _disconnect_websocket(self):
        """Close websocket connection and reset state."""
        try:
            await self.stop_all_metrics()
            if self._websocket:
                await self._websocket.close()
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
        finally:
            self._websocket = None
            await self._call_event_handler("on_disconnected")

    def _get_websocket(self):
        """Get active websocket connection or raise exception."""
        if self._websocket:
            return self._websocket
        raise Exception("Websocket not connected")

    async def flush_audio(self):
        """Flush any pending audio synthesis.

        Sends an ``end_of_stream`` message; does nothing when no websocket is set.
        """
        if not self._websocket:
            return
        try:
            msg = {"type": "end_of_stream"}
            await self._websocket.send(json.dumps(msg))
        except ConnectionClosedOK:
            logger.debug(f"{self}: connection closed normally during flush")
        except Exception as e:
            logger.error(f"{self} exception: {e}")

    async def _receive_messages(self):
        """Process incoming websocket messages.

        Non-JSON messages are logged and skipped; messages with an unhandled
        ``type`` are ignored.

        Raises:
            ConnectionClosedOK: If the websocket is already closed on entry.
        """
        # TODO(laurent): This should not be necessary as it should happen when
        # receiving the messages but this does not seem to always be the case
        # and that may lead to a busy polling loop.
        if self._websocket and self._websocket.state is State.CLOSED:
            raise ConnectionClosedOK(None, None)

        async for message in self._get_websocket():
            try:
                msg = json.loads(message)
            except json.JSONDecodeError:
                logger.warning(f"{self}: received non-JSON message: {message}")
                continue

            msg_type = msg.get("type")
            if msg_type == "audio":
                # Process audio chunk
                await self.stop_ttfb_metrics()
                self.start_word_timestamps()
                frame = TTSAudioRawFrame(
                    audio=base64.b64decode(msg["audio"]),
                    sample_rate=self.sample_rate,
                    num_channels=1,
                )
                await self.push_frame(frame)
            elif msg_type == "text":
                await self.add_word_timestamps([(msg["text"], msg["start_s"])])
            elif msg_type == "end_of_stream":
                await self.push_frame(TTSStoppedFrame())
                await self.stop_all_metrics()
            elif msg_type == "error":
                await self.push_frame(TTSStoppedFrame())
                await self.stop_all_metrics()
                await self.push_error(error_msg=f"Error: {msg.get('message', msg)}")

    async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
        """Push a frame by delegating to the parent implementation unchanged.

        Args:
            frame: The frame to push.
            direction: The direction to push the frame.
        """
        await super().push_frame(frame, direction)

    @traced_tts
    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Generate speech from text using Gradium's streaming API.

        The text is sent over the websocket; audio arrives asynchronously and
        is pushed from ``_receive_messages``.

        Args:
            text: The text to convert to speech.

        Yields:
            Frame: ``TTSStartedFrame`` before sending; ``ErrorFrame`` and
            ``TTSStoppedFrame`` if sending fails; otherwise ``None``.
        """
        _state = self._websocket.state if self._websocket is not None else None
        logger.debug(f"{self}: Generating TTS [{text}] {_state}")

        try:
            if not self._websocket or self._websocket.state is State.CLOSED:
                self._websocket = None
                await self._connect()

            try:
                yield TTSStartedFrame()
                msg = self._build_msg(text=text)
                await self._get_websocket().send(json.dumps(msg))
                await self.start_tts_usage_metrics(text)
            except Exception as e:
                yield ErrorFrame(error=f"Unknown error occurred: {e}")
                yield TTSStoppedFrame()
                # Reset the connection so the next request starts clean.
                await self._disconnect()
                await self._connect()
                return
            yield None
        except Exception as e:
            yield ErrorFrame(error=f"Unknown error occurred: {e}")