Merge pull request #3171 from LaurentMazare/gradium
Gradium integration.
This commit is contained in:
@@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
ensures all function calls complete before returning results to the LLM (i.e.,
|
||||
before running a new inference with those results).
|
||||
|
||||
- Added new Gradium services, `GradiumSTTService` and `GradiumTTSService`, for
|
||||
speech-to-text and text-to-speech functionality using Gradium's API.
|
||||
|
||||
### Changed
|
||||
|
||||
- If an unexpected exception is caught, or if `FrameProcessor.push_error()` is
|
||||
|
||||
@@ -74,9 +74,9 @@ Catch new features, interviews, and how-tos on our [Pipecat TV](https://www.yout
|
||||
|
||||
| Category | Services |
|
||||
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/stt/elevenlabs), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Sarvam](https://docs.pipecat.ai/server/services/stt/sarvam), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/stt/elevenlabs), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Gradium](https://docs.pipecat.ai/server/services/stt/gradium), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Sarvam](https://docs.pipecat.ai/server/services/stt/sarvam), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [Mistral](https://docs.pipecat.ai/server/services/llm/mistral), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [SambaNova](https://docs.pipecat.ai/server/services/llm/sambanova) [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
|
||||
| Text-to-Speech | [Async](https://docs.pipecat.ai/server/services/tts/asyncai), [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [Groq](https://docs.pipecat.ai/server/services/tts/groq), [Hume](https://docs.pipecat.ai/server/services/tts/hume), [Inworld](https://docs.pipecat.ai/server/services/tts/inworld), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [Speechmatics](https://docs.pipecat.ai/server/services/tts/speechmatics), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Text-to-Speech | [Async](https://docs.pipecat.ai/server/services/tts/asyncai), [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [Gradium](https://docs.pipecat.ai/server/services/tts/gradium), [Groq](https://docs.pipecat.ai/server/services/tts/groq), [Hume](https://docs.pipecat.ai/server/services/tts/hume), [Inworld](https://docs.pipecat.ai/server/services/tts/inworld), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [Speechmatics](https://docs.pipecat.ai/server/services/tts/speechmatics), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Speech-to-Speech | [AWS Nova Sonic](https://docs.pipecat.ai/server/services/s2s/aws), [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
|
||||
| Serializers | [Plivo](https://docs.pipecat.ai/server/utilities/serializers/plivo), [Twilio](https://docs.pipecat.ai/server/utilities/serializers/twilio), [Telnyx](https://docs.pipecat.ai/server/utilities/serializers/telnyx) |
|
||||
|
||||
@@ -73,6 +73,9 @@ GOOGLE_CLOUD_PROJECT_ID=...
|
||||
GOOGLE_CLOUD_LOCATION=...
|
||||
GOOGLE_TEST_CREDENTIALS=...
|
||||
|
||||
# Gradium
|
||||
GRAPDIUM_API_KEY=...
|
||||
|
||||
# Grok
|
||||
GROK_API_KEY=...
|
||||
|
||||
@@ -191,4 +194,4 @@ TWILIO_AUTH_TOKEN=...
|
||||
WHATSAPP_TOKEN=...
|
||||
WHATSAPP_WEBHOOK_VERIFICATION_TOKEN=...
|
||||
WHATSAPP_PHONE_NUMBER_ID=...
|
||||
WHATSAPP_APP_SECRET=...
|
||||
WHATSAPP_APP_SECRET=...
|
||||
|
||||
127
examples/foundational/07af-interruptible-gradium.py
Normal file
127
examples/foundational/07af-interruptible-gradium.py
Normal file
@@ -0,0 +1,127 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.gradium.stt import GradiumSTTService
|
||||
from pipecat.services.gradium.tts import GradiumTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = GradiumSTTService(api_key=os.getenv("GRADIUM_API_KEY"))
|
||||
|
||||
tts = GradiumTTSService(
|
||||
api_key=os.getenv("GRADIUM_API_KEY"),
|
||||
voice_id="YTpq7expH9539ERJ",
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -63,6 +63,7 @@ fireworks = []
|
||||
fish = [ "ormsgpack~=1.7.0", "pipecat-ai[websockets-base]" ]
|
||||
gladia = [ "pipecat-ai[websockets-base]" ]
|
||||
google = [ "google-cloud-speech>=2.33.0,<3", "google-cloud-texttospeech>=2.31.0,<3", "google-genai>=1.41.0,<2", "pipecat-ai[websockets-base]" ]
|
||||
gradium = [ "pipecat-ai[websockets-base]" ]
|
||||
grok = []
|
||||
groq = [ "groq~=0.23.0" ]
|
||||
gstreamer = [ "pygobject~=3.50.0" ]
|
||||
|
||||
5
src/pipecat/services/gradium/__init__.py
Normal file
5
src/pipecat/services/gradium/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
239
src/pipecat/services/gradium/stt.py
Normal file
239
src/pipecat/services/gradium/stt.py
Normal file
@@ -0,0 +1,239 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Gradium's speech-to-text service implementation.
|
||||
|
||||
This module provides integration with Gradium's real-time speech-to-text
|
||||
WebSocket API for streaming audio transcription.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error('In order to use Gradium, you need to `pip install "pipecat-ai[gradium]"`.')
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
SAMPLE_RATE = 24000
|
||||
|
||||
|
||||
class GradiumSTTService(WebsocketSTTService):
|
||||
"""Gradium real-time speech-to-text service.
|
||||
|
||||
Provides real-time speech transcription using Gradium's WebSocket API.
|
||||
Supports both interim and final transcriptions with configurable parameters
|
||||
for audio processing and connection management.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr",
|
||||
json_config: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Gradium STT service.
|
||||
|
||||
Args:
|
||||
api_key: Gradium API key for authentication.
|
||||
api_endpoint_base_url: WebSocket endpoint URL. Defaults to Gradium's streaming endpoint.
|
||||
json_config: Optional JSON configuration string for additional model settings.
|
||||
**kwargs: Additional arguments passed to parent STTService class.
|
||||
"""
|
||||
super().__init__(sample_rate=SAMPLE_RATE, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._api_endpoint_base_url = api_endpoint_base_url
|
||||
self._websocket = None
|
||||
self._json_config = json_config
|
||||
|
||||
self._receive_task = None
|
||||
|
||||
self._audio_buffer = bytearray()
|
||||
self._chunk_size_ms = 80
|
||||
self._chunk_size_bytes = 0
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate metrics.
|
||||
|
||||
Returns:
|
||||
True if metrics generation is supported.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the speech-to-text service.
|
||||
|
||||
Args:
|
||||
frame: Start frame to begin processing.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._chunk_size_bytes = int(self._chunk_size_ms * self.sample_rate * 2 / 1000)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the speech-to-text service.
|
||||
|
||||
Args:
|
||||
frame: End frame to stop processing.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the speech-to-text service.
|
||||
|
||||
Args:
|
||||
frame: Cancel frame to abort processing.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process audio data for speech-to-text conversion.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to process.
|
||||
|
||||
Yields:
|
||||
None (processing handled via WebSocket messages).
|
||||
"""
|
||||
self._audio_buffer.extend(audio)
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
while len(self._audio_buffer) >= self._chunk_size_bytes:
|
||||
chunk = bytes(self._audio_buffer[: self._chunk_size_bytes])
|
||||
self._audio_buffer = self._audio_buffer[self._chunk_size_bytes :]
|
||||
chunk = base64.b64encode(chunk).decode("utf-8")
|
||||
msg = {"type": "audio", "audio": chunk}
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
|
||||
yield None
|
||||
|
||||
@traced_stt
|
||||
async def _trace_transcription(self, transcript: str, is_final: bool, language: Language):
|
||||
"""Record transcription event for tracing."""
|
||||
pass
|
||||
|
||||
async def _connect(self):
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _connect_websocket(self):
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
ws_url = self._api_endpoint_base_url
|
||||
headers = {
|
||||
"x-api-key": self._api_key,
|
||||
"x-api-source": "pipecat",
|
||||
}
|
||||
self._websocket = await websocket_connect(
|
||||
ws_url,
|
||||
additional_headers=headers,
|
||||
)
|
||||
await self._call_event_handler("on_connected")
|
||||
setup_msg = {
|
||||
"type": "setup",
|
||||
"input_format": "pcm",
|
||||
}
|
||||
if self._json_config is not None:
|
||||
setup_msg["json_config"] = self._json_config
|
||||
await self._websocket.send(json.dumps(setup_msg))
|
||||
ready_msg = await self._websocket.recv()
|
||||
ready_msg = json.loads(ready_msg)
|
||||
if ready_msg["type"] == "error":
|
||||
raise Exception(f"received error {ready_msg['message']}")
|
||||
if ready_msg["type"] != "ready":
|
||||
raise Exception(f"unexpected first message type {ready_msg['type']}")
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
raise
|
||||
|
||||
async def _disconnect(self):
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
logger.debug("Disconnecting from Gradium STT")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _process_messages(self):
|
||||
async for message in self._get_websocket():
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._process_response(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Received non-JSON message: {message}")
|
||||
|
||||
async def _receive_messages(self):
|
||||
while True:
|
||||
await self._process_messages()
|
||||
logger.debug(f"{self} Gradium connection was disconnected (timeout?), reconnecting")
|
||||
await self._connect_websocket()
|
||||
|
||||
async def _process_response(self, msg):
|
||||
type_ = msg.get("type", "")
|
||||
if type_ == "text":
|
||||
await self._handle_text(msg["text"])
|
||||
elif type_ == "end_of_stream":
|
||||
await self._handle_end_of_stream()
|
||||
elif type_ == "error":
|
||||
await self.push_error(error_msg=f"Error: {msg}")
|
||||
|
||||
async def _handle_end_of_stream(self):
|
||||
"""Handle termination message."""
|
||||
logger.debug("Received end_of_stream message from server")
|
||||
|
||||
async def _handle_text(self, text: str):
|
||||
"""Handle transcription results."""
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
)
|
||||
)
|
||||
315
src/pipecat/services/gradium/tts.py
Normal file
315
src/pipecat/services/gradium/tts.py
Normal file
@@ -0,0 +1,315 @@
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
|
||||
"""Gradium Text-to-Speech service implementation."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Mapping, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import InterruptibleWordTTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
from websockets import ConnectionClosedOK
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Gradium, you need to `pip install pipecat-ai[gradium]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
SAMPLE_RATE = 48000
|
||||
|
||||
|
||||
class GradiumTTSService(InterruptibleWordTTSService):
|
||||
"""Text-to-Speech service using Gradium's websocket API."""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for Gradium TTS service.
|
||||
|
||||
Parameters:
|
||||
temp: Temperature to be used for generation, defaults to 0.6.
|
||||
"""
|
||||
|
||||
temp: Optional[float] = 0.6
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str = "YTpq7expH9539ERJ",
|
||||
url: str = "wss://eu.api.gradium.ai/api/speech/tts",
|
||||
model: str = "default",
|
||||
json_config: Optional[str] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Gradium TTS service.
|
||||
|
||||
Args:
|
||||
api_key: Gradium API key for authentication.
|
||||
voice_id: the voice identifier.
|
||||
url: Gradium websocket API endpoint.
|
||||
model: Model ID to use for synthesis.
|
||||
json_config: Optional JSON configuration string for additional model settings.
|
||||
params: Additional configuration parameters.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
# Initialize with parent class settings for proper frame handling
|
||||
super().__init__(
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=SAMPLE_RATE,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
params = params or GradiumTTSService.InputParams()
|
||||
|
||||
# Store service configuration
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
self._voice_id = voice_id
|
||||
self._json_config = json_config
|
||||
self._model = model
|
||||
self._settings = {
|
||||
"voice_id": voice_id,
|
||||
"model_name": model,
|
||||
"output_format": "pcm",
|
||||
}
|
||||
|
||||
# State tracking
|
||||
self._receive_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Gradium service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Update the TTS model.
|
||||
|
||||
Args:
|
||||
model: The model name to use for synthesis.
|
||||
"""
|
||||
self._model = model
|
||||
await super().set_model(model)
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
"""Update service settings and reconnect if voice changed."""
|
||||
prev_voice = self._voice_id
|
||||
await super()._update_settings(settings)
|
||||
if not prev_voice == self._voice_id:
|
||||
self._settings["voice_id"] = self._voice_id
|
||||
logger.info(f"Switching TTS voice to: [{self._voice_id}]")
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
def _build_msg(self, text: str = "") -> dict:
|
||||
"""Build JSON message for Gradium API."""
|
||||
return {"text": text, "type": "text"}
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the service and establish websocket connection.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the service and close connection.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel current operation and clean up.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
"""Establish websocket connection and start receive task."""
|
||||
logger.debug(f"{self}: connecting")
|
||||
|
||||
# If the server disconnected, cancel the receive-task so that it can be reset below.
|
||||
if self._websocket is None or self._websocket.state is not State.OPEN:
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
logger.debug(f"{self}: setting receive task")
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Close websocket connection and clean up tasks."""
|
||||
logger.debug(f"{self}: disconnecting")
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Connect to Gradium websocket API with configured settings."""
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
headers = {"x-api-key": self._api_key, "x-api-source": "pipecat"}
|
||||
self._websocket = await websocket_connect(self._url, additional_headers=headers)
|
||||
|
||||
setup_msg = {
|
||||
"type": "setup",
|
||||
"output_format": "pcm",
|
||||
"voice_id": self._voice_id,
|
||||
}
|
||||
if self._json_config is not None:
|
||||
setup_msg["json_config"] = self._json_config
|
||||
await self._websocket.send(json.dumps(setup_msg))
|
||||
ready_msg = await self._websocket.recv()
|
||||
ready_msg = json.loads(ready_msg)
|
||||
if ready_msg["type"] == "error":
|
||||
raise Exception(f"received error {ready_msg['message']}")
|
||||
if ready_msg["type"] != "ready":
|
||||
raise Exception(f"unexpected first message type {ready_msg['type']}")
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Close websocket connection and reset state."""
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
def _get_websocket(self):
|
||||
"""Get active websocket connection or raise exception."""
|
||||
if self._websocket:
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio synthesis."""
|
||||
if not self._websocket:
|
||||
return
|
||||
try:
|
||||
msg = {"type": "end_of_stream"}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
except ConnectionClosedOK:
|
||||
logger.debug(f"{self}: connection closed normally during flush")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Process incoming websocket messages."""
|
||||
# TODO(laurent): This should not be necessary as it should happen when
|
||||
# receiving the messages but this does not seem to always be the case
|
||||
# and that may lead to a busy polling loop.
|
||||
if self._websocket and self._websocket.state is State.CLOSED:
|
||||
raise ConnectionClosedOK(None, None)
|
||||
async for message in self._get_websocket():
|
||||
msg = json.loads(message)
|
||||
|
||||
if msg["type"] == "audio":
|
||||
# Process audio chunk
|
||||
await self.stop_ttfb_metrics()
|
||||
self.start_word_timestamps()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["audio"]),
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
elif msg["type"] == "text":
|
||||
await self.add_word_timestamps([(msg["text"], msg["start_s"])])
|
||||
elif msg["type"] == "end_of_stream":
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
|
||||
elif msg["type"] == "error":
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg['message']}")
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push frame and handle end-of-turn conditions.
|
||||
|
||||
Args:
|
||||
frame: The frame to push.
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Gradium's streaming API.
|
||||
|
||||
Args:
|
||||
text: The text to convert to speech.
|
||||
|
||||
Yields:
|
||||
Frame: Audio frames containing the synthesized speech.
|
||||
"""
|
||||
_state = self._websocket.state if self._websocket is not None else None
|
||||
logger.debug(f"{self}: Generating TTS [{text}] {_state}")
|
||||
try:
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
self._websocket = None
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
yield TTSStartedFrame()
|
||||
|
||||
msg = self._build_msg(text=text)
|
||||
await self._get_websocket().send(json.dumps(msg))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
Reference in New Issue
Block a user