Merge pull request #1163 from pipecat-ai/mb/rime-websocket

Add RimeTTSService
This commit is contained in:
Mark Backman
2025-02-12 09:51:56 -05:00
committed by GitHub
4 changed files with 335 additions and 6 deletions

View File

@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added `RimeTTSService`, a new `WordTTSService`. Updated the foundational
example `07q-interruptible-rime.py` to use `RimeTTSService`.
- Added support for Groq's Whisper API through the new `GroqSTTService` and
OpenAI's Whisper API through the new `OpenAISTTService`. Introduced a new
base class `BaseWhisperSTTService` to handle common Whisper API
@@ -22,6 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- `RimeHttpTTSService` now uses the `mistv2` model by default.
- Improved error handling in `AzureTTSService` to properly detect and log
synthesis cancellation errors.

View File

@@ -19,7 +19,7 @@ from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai import OpenAILLMService
from pipecat.services.rime import RimeHttpTTSService
from pipecat.services.rime import RimeTTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport
load_dotenv(override=True)
@@ -44,10 +44,9 @@ async def main():
),
)
tts = RimeHttpTTSService(
tts = RimeTTSService(
api_key=os.getenv("RIME_API_KEY", ""),
voice_id="rex",
params=RimeHttpTTSService.InputParams(reduce_latency=True),
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")

View File

@@ -73,6 +73,7 @@ openai = [ "openai~=1.59.6", "websockets~=13.1", "python-deepcompare~=2.1.0" ]
openpipe = [ "openpipe~=4.45.0" ]
perplexity = [ "openai~=1.59.6" ]
playht = [ "pyht~=0.1.6", "websockets~=13.1" ]
rime = [ "websockets~=13.1" ]
riva = [ "nvidia-riva-client~=2.18.0" ]
sentry = [ "sentry-sdk~=2.20.0" ]
silero = [ "onnxruntime~=1.20.1" ]

View File

@@ -4,20 +4,344 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import AsyncGenerator, Optional
import base64
import json
import uuid
from typing import AsyncGenerator, Optional, Union
import aiohttp
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.ai_services import TTSService
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import TTSService, WordTTSService
from pipecat.services.websocket_service import WebsocketService
from pipecat.transcriptions.language import Language
try:
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Rime, you need to `pip install pipecat-ai[rime]`. Also, set `RIME_API_KEY` environment variable."
)
raise Exception(f"Missing module: {e}")
def language_to_rime_language(language: Language) -> str:
"""Convert pipecat Language to Rime language code.
Args:
language: The pipecat Language enum value.
Returns:
str: Three-letter language code used by Rime (e.g., 'eng' for English).
"""
LANGUAGE_MAP = {
Language.EN: "eng",
Language.ES: "spa",
}
return LANGUAGE_MAP.get(language, "eng")
class RimeTTSService(WordTTSService, WebsocketService):
"""Text-to-Speech service using Rime's websocket API.
Uses Rime's websocket JSON API to convert text to speech with word-level timing
information. Supports interruptions and maintains context across multiple messages
within a turn.
"""
class InputParams(BaseModel):
"""Configuration parameters for Rime TTS service."""
language: Optional[Language] = Language.EN
speed_alpha: Optional[float] = 1.0
reduce_latency: Optional[bool] = False
def __init__(
self,
*,
api_key: str,
voice_id: str,
url: str = "wss://users-ws.rime.ai/ws2",
model: str = "mistv2",
sample_rate: Optional[int] = None,
params: InputParams = InputParams(),
**kwargs,
):
"""Initialize Rime TTS service.
Args:
api_key: Rime API key for authentication.
voice_id: ID of the voice to use.
url: Rime websocket API endpoint.
model: Model ID to use for synthesis.
sample_rate: Audio sample rate in Hz.
params: Additional configuration parameters.
"""
# Initialize with parent class settings for proper frame handling
WordTTSService.__init__(
self,
aggregate_sentences=True,
push_text_frames=False,
push_stop_frames=True,
stop_frame_timeout_s=2.0,
sample_rate=sample_rate,
**kwargs,
)
WebsocketService.__init__(self)
# Store service configuration
self._api_key = api_key
self._url = url
self._voice_id = voice_id
self._model = model
self._settings = {
"speaker": voice_id,
"modelId": model,
"audioFormat": "pcm",
"samplingRate": 0,
"lang": self.language_to_service_language(params.language)
if params.language
else "eng",
"speedAlpha": params.speed_alpha,
"reduceLatency": params.reduce_latency,
}
# State tracking
self._context_id = None # Tracks current turn
self._receive_task = None
self._started = False
self._cumulative_time = 0 # Accumulates time across messages
def can_generate_metrics(self) -> bool:
return True
def language_to_service_language(self, language: Language) -> str | None:
"""Convert pipecat language to Rime language code."""
return language_to_rime_language(language)
async def set_model(self, model: str):
"""Update the TTS model."""
self._model = model
await super().set_model(model)
def _build_msg(self, text: str = "") -> dict:
"""Build JSON message for Rime API."""
return {"text": text, "contextId": self._context_id}
def _build_clear_msg(self) -> dict:
"""Build clear operation message."""
return {"operation": "clear"}
def _build_eos_msg(self) -> dict:
"""Build end-of-stream operation message."""
return {"operation": "eos"}
async def start(self, frame: StartFrame):
"""Start the service and establish websocket connection."""
await super().start(frame)
self._settings["samplingRate"] = self.sample_rate
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the service and close connection."""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel current operation and clean up."""
await super().cancel(frame)
await self._disconnect()
async def _connect(self):
"""Establish websocket connection and start receive task."""
await self._connect_websocket()
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
async def _disconnect(self):
"""Close websocket connection and clean up tasks."""
await self._disconnect_websocket()
if self._receive_task:
await self.cancel_task(self._receive_task)
self._receive_task = None
async def _connect_websocket(self):
"""Connect to Rime websocket API with configured settings."""
try:
params = "&".join(f"{k}={v}" for k, v in self._settings.items())
url = f"{self._url}?{params}"
headers = {"Authorization": f"Bearer {self._api_key}"}
self._websocket = await websockets.connect(url, extra_headers=headers)
except Exception as e:
logger.error(f"{self} initialization error: {e}")
self._websocket = None
async def _disconnect_websocket(self):
"""Close websocket connection and reset state."""
try:
await self.stop_all_metrics()
if self._websocket:
await self._websocket.send(json.dumps(self._build_eos_msg()))
await self._websocket.close()
self._websocket = None
self._started = False
self._context_id = None
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")
def _get_websocket(self):
"""Get active websocket connection or raise exception."""
if self._websocket:
return self._websocket
raise Exception("Websocket not connected")
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
"""Handle interruption by clearing current context."""
await super()._handle_interruption(frame, direction)
await self.stop_all_metrics()
if self._context_id:
await self._get_websocket().send(json.dumps(self._build_clear_msg()))
self._started = False
self._context_id = None
def _calculate_word_times(self, words: list, starts: list, ends: list) -> list:
"""Calculate word timing pairs with proper spacing and punctuation.
Args:
words: List of words from Rime.
starts: List of start times for each word.
ends: List of end times for each word.
Returns:
List of (word, timestamp) pairs with proper timing.
"""
word_pairs = []
for i, (word, start_time, _) in enumerate(zip(words, starts, ends)):
if not word.strip():
continue
# Adjust timing by adding cumulative time
adjusted_start = start_time + self._cumulative_time
# Handle punctuation by appending to previous word
is_punctuation = bool(word.strip(",.!?") == "")
if is_punctuation and word_pairs:
prev_word, prev_time = word_pairs[-1]
word_pairs[-1] = (prev_word + word, prev_time)
else:
word_pairs.append((word, adjusted_start))
return word_pairs
async def _receive_messages(self):
"""Process incoming websocket messages."""
async for message in self._get_websocket():
msg = json.loads(message)
if not msg or msg["contextId"] != self._context_id:
continue
if msg["type"] == "chunk":
# Process audio chunk
await self.stop_ttfb_metrics()
self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["data"]),
sample_rate=self.sample_rate,
num_channels=1,
)
await self.push_frame(frame)
elif msg["type"] == "timestamps":
# Process word timing information
timestamps = msg.get("word_timestamps", {})
words = timestamps.get("words", [])
starts = timestamps.get("start", [])
ends = timestamps.get("end", [])
if words and starts:
# Calculate word timing pairs
word_pairs = self._calculate_word_times(words, starts, ends)
if word_pairs:
await self.add_word_timestamps(word_pairs)
self._cumulative_time = ends[-1] + self._cumulative_time
logger.debug(f"Updated cumulative time to: {self._cumulative_time}")
elif msg["type"] == "error":
logger.error(f"{self} error: {msg}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f"{self} error: {msg['message']}"))
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Push frame and handle end-of-turn conditions."""
await super().push_frame(frame, direction)
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
self._started = False
if isinstance(frame, TTSStoppedFrame):
await self.add_word_timestamps([("LLMFullResponseEndFrame", 0), ("Reset", 0)])
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames and manage turn state."""
await super().process_frame(frame, direction)
if isinstance(frame, TTSSpeakFrame):
await self.pause_processing_frames()
elif isinstance(frame, LLMFullResponseEndFrame) and self._started:
await self.pause_processing_frames()
elif isinstance(frame, BotStoppedSpeakingFrame):
await self.resume_processing_frames()
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text.
Args:
text: The text to convert to speech.
Yields:
Frames containing audio data and timing information.
"""
logger.debug(f"Generating TTS: [{text}]")
try:
if not self._websocket:
await self._connect()
try:
if not self._started:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
self._started = True
self._cumulative_time = 0
self._context_id = str(uuid.uuid4())
msg = self._build_msg(text=text)
await self._get_websocket().send(json.dumps(msg))
await self.start_tts_usage_metrics(text)
except Exception as e:
logger.error(f"{self} error sending message: {e}")
yield TTSStoppedFrame()
await self._disconnect()
await self._connect()
return
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
class RimeHttpTTSService(TTSService):
@@ -33,7 +357,7 @@ class RimeHttpTTSService(TTSService):
*,
api_key: str,
voice_id: str = "eva",
model: str = "mist",
model: str = "mistv2",
sample_rate: Optional[int] = None,
params: InputParams = InputParams(),
**kwargs,