diff --git a/CHANGELOG.md b/CHANGELOG.md index 217474f5f..848123fd8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `RimeTTSService`, a new `WordTTSService`. Updated the foundational + example `07q-interruptible-rime.py` to use `RimeTTSService`. + - Added support for Groq's Whisper API through the new `GroqSTTService` and OpenAI's Whisper API through the new `OpenAISTTService`. Introduced a new base class `BaseWhisperSTTService` to handle common Whisper API @@ -22,6 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- `RimeHttpTTSService` now uses the `mistv2` model by default. + - Improved error handling in `AzureTTSService` to properly detect and log synthesis cancellation errors. diff --git a/examples/foundational/07q-interruptible-rime.py b/examples/foundational/07q-interruptible-rime.py index 2465e2596..7bbc4cb25 100644 --- a/examples/foundational/07q-interruptible-rime.py +++ b/examples/foundational/07q-interruptible-rime.py @@ -19,7 +19,7 @@ from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.services.openai import OpenAILLMService -from pipecat.services.rime import RimeHttpTTSService +from pipecat.services.rime import RimeTTSService from pipecat.transports.services.daily import DailyParams, DailyTransport load_dotenv(override=True) @@ -44,10 +44,9 @@ async def main(): ), ) - tts = RimeHttpTTSService( + tts = RimeTTSService( api_key=os.getenv("RIME_API_KEY", ""), voice_id="rex", - params=RimeHttpTTSService.InputParams(reduce_latency=True), ) llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") diff --git a/pyproject.toml b/pyproject.toml index b8328dd1d..82af4df2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ openai = [ "openai~=1.59.6", "websockets~=13.1", "python-deepcompare~=2.1.0" ] openpipe = [ "openpipe~=4.45.0" ] perplexity = [ "openai~=1.59.6" ] playht = [ "pyht~=0.1.6", "websockets~=13.1" ] +rime = [ "websockets~=13.1" ] riva = [ "nvidia-riva-client~=2.18.0" ] sentry = [ "sentry-sdk~=2.20.0" ] silero = [ "onnxruntime~=1.20.1" ] diff --git a/src/pipecat/services/rime.py b/src/pipecat/services/rime.py index 51bc7253d..0210f8de6 100644 --- a/src/pipecat/services/rime.py +++ b/src/pipecat/services/rime.py @@ -4,20 +4,344 @@ # SPDX-License-Identifier: BSD 2-Clause License # -from typing import AsyncGenerator, Optional +import base64 +import json +import uuid +from typing import AsyncGenerator, Optional, Union import aiohttp from loguru import logger from pydantic import BaseModel from pipecat.frames.frames import ( + BotStoppedSpeakingFrame, + CancelFrame, + EndFrame, ErrorFrame, Frame, + LLMFullResponseEndFrame, + StartFrame, + StartInterruptionFrame, TTSAudioRawFrame, + TTSSpeakFrame, TTSStartedFrame, TTSStoppedFrame, ) -from pipecat.services.ai_services import TTSService +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import TTSService, WordTTSService +from pipecat.services.websocket_service import WebsocketService +from pipecat.transcriptions.language import Language + +try: + import websockets +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use Rime, you need to `pip install pipecat-ai[rime]`. Also, set `RIME_API_KEY` environment variable." + ) + raise Exception(f"Missing module: {e}") + + +def language_to_rime_language(language: Language) -> str: + """Convert pipecat Language to Rime language code. + + Args: + language: The pipecat Language enum value. + + Returns: + str: Three-letter language code used by Rime (e.g., 'eng' for English). + """ + LANGUAGE_MAP = { + Language.EN: "eng", + Language.ES: "spa", + } + return LANGUAGE_MAP.get(language, "eng") + + +class RimeTTSService(WordTTSService, WebsocketService): + """Text-to-Speech service using Rime's websocket API. + + Uses Rime's websocket JSON API to convert text to speech with word-level timing + information. Supports interruptions and maintains context across multiple messages + within a turn. + """ + + class InputParams(BaseModel): + """Configuration parameters for Rime TTS service.""" + + language: Optional[Language] = Language.EN + speed_alpha: Optional[float] = 1.0 + reduce_latency: Optional[bool] = False + + def __init__( + self, + *, + api_key: str, + voice_id: str, + url: str = "wss://users-ws.rime.ai/ws2", + model: str = "mistv2", + sample_rate: Optional[int] = None, + params: InputParams = InputParams(), + **kwargs, + ): + """Initialize Rime TTS service. + + Args: + api_key: Rime API key for authentication. + voice_id: ID of the voice to use. + url: Rime websocket API endpoint. + model: Model ID to use for synthesis. + sample_rate: Audio sample rate in Hz. + params: Additional configuration parameters. + """ + # Initialize with parent class settings for proper frame handling + WordTTSService.__init__( + self, + aggregate_sentences=True, + push_text_frames=False, + push_stop_frames=True, + stop_frame_timeout_s=2.0, + sample_rate=sample_rate, + **kwargs, + ) + WebsocketService.__init__(self) + + # Store service configuration + self._api_key = api_key + self._url = url + self._voice_id = voice_id + self._model = model + self._settings = { + "speaker": voice_id, + "modelId": model, + "audioFormat": "pcm", + "samplingRate": 0, + "lang": self.language_to_service_language(params.language) + if params.language + else "eng", + "speedAlpha": params.speed_alpha, + "reduceLatency": params.reduce_latency, + } + + # State tracking + self._context_id = None # Tracks current turn + self._receive_task = None + self._started = False + self._cumulative_time = 0 # Accumulates time across messages + + def can_generate_metrics(self) -> bool: + return True + + def language_to_service_language(self, language: Language) -> str | None: + """Convert pipecat language to Rime language code.""" + return language_to_rime_language(language) + + async def set_model(self, model: str): + """Update the TTS model.""" + self._model = model + await super().set_model(model) + + def _build_msg(self, text: str = "") -> dict: + """Build JSON message for Rime API.""" + return {"text": text, "contextId": self._context_id} + + def _build_clear_msg(self) -> dict: + """Build clear operation message.""" + return {"operation": "clear"} + + def _build_eos_msg(self) -> dict: + """Build end-of-stream operation message.""" + return {"operation": "eos"} + + async def start(self, frame: StartFrame): + """Start the service and establish websocket connection.""" + await super().start(frame) + self._settings["samplingRate"] = self.sample_rate + await self._connect() + + async def stop(self, frame: EndFrame): + """Stop the service and close connection.""" + await super().stop(frame) + await self._disconnect() + + async def cancel(self, frame: CancelFrame): + """Cancel current operation and clean up.""" + await super().cancel(frame) + await self._disconnect() + + async def _connect(self): + """Establish websocket connection and start receive task.""" + await self._connect_websocket() + self._receive_task = self.create_task(self._receive_task_handler(self.push_error)) + + async def _disconnect(self): + """Close websocket connection and clean up tasks.""" + await self._disconnect_websocket() + if self._receive_task: + await self.cancel_task(self._receive_task) + self._receive_task = None + + async def _connect_websocket(self): + """Connect to Rime websocket API with configured settings.""" + try: + params = "&".join(f"{k}={v}" for k, v in self._settings.items()) + url = f"{self._url}?{params}" + headers = {"Authorization": f"Bearer {self._api_key}"} + self._websocket = await websockets.connect(url, extra_headers=headers) + except Exception as e: + logger.error(f"{self} initialization error: {e}") + self._websocket = None + + async def _disconnect_websocket(self): + """Close websocket connection and reset state.""" + try: + await self.stop_all_metrics() + if self._websocket: + await self._websocket.send(json.dumps(self._build_eos_msg())) + await self._websocket.close() + self._websocket = None + self._started = False + self._context_id = None + except Exception as e: + logger.error(f"{self} error closing websocket: {e}") + + def _get_websocket(self): + """Get active websocket connection or raise exception.""" + if self._websocket: + return self._websocket + raise Exception("Websocket not connected") + + async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection): + """Handle interruption by clearing current context.""" + await super()._handle_interruption(frame, direction) + await self.stop_all_metrics() + if self._context_id: + await self._get_websocket().send(json.dumps(self._build_clear_msg())) + self._started = False + self._context_id = None + + def _calculate_word_times(self, words: list, starts: list, ends: list) -> list: + """Calculate word timing pairs with proper spacing and punctuation. + + Args: + words: List of words from Rime. + starts: List of start times for each word. + ends: List of end times for each word. + + Returns: + List of (word, timestamp) pairs with proper timing. + """ + word_pairs = [] + for i, (word, start_time, _) in enumerate(zip(words, starts, ends)): + if not word.strip(): + continue + + # Adjust timing by adding cumulative time + adjusted_start = start_time + self._cumulative_time + + # Handle punctuation by appending to previous word + is_punctuation = bool(word.strip(",.!?") == "") + if is_punctuation and word_pairs: + prev_word, prev_time = word_pairs[-1] + word_pairs[-1] = (prev_word + word, prev_time) + else: + word_pairs.append((word, adjusted_start)) + + return word_pairs + + async def _receive_messages(self): + """Process incoming websocket messages.""" + async for message in self._get_websocket(): + msg = json.loads(message) + + if not msg or msg["contextId"] != self._context_id: + continue + + if msg["type"] == "chunk": + # Process audio chunk + await self.stop_ttfb_metrics() + self.start_word_timestamps() + frame = TTSAudioRawFrame( + audio=base64.b64decode(msg["data"]), + sample_rate=self.sample_rate, + num_channels=1, + ) + await self.push_frame(frame) + + elif msg["type"] == "timestamps": + # Process word timing information + timestamps = msg.get("word_timestamps", {}) + words = timestamps.get("words", []) + starts = timestamps.get("start", []) + ends = timestamps.get("end", []) + + if words and starts: + # Calculate word timing pairs + word_pairs = self._calculate_word_times(words, starts, ends) + if word_pairs: + await self.add_word_timestamps(word_pairs) + self._cumulative_time = ends[-1] + self._cumulative_time + logger.debug(f"Updated cumulative time to: {self._cumulative_time}") + + elif msg["type"] == "error": + logger.error(f"{self} error: {msg}") + await self.push_frame(TTSStoppedFrame()) + await self.stop_all_metrics() + await self.push_error(ErrorFrame(f"{self} error: {msg['message']}")) + + async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): + """Push frame and handle end-of-turn conditions.""" + await super().push_frame(frame, direction) + if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)): + self._started = False + if isinstance(frame, TTSStoppedFrame): + await self.add_word_timestamps([("LLMFullResponseEndFrame", 0), ("Reset", 0)]) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames and manage turn state.""" + await super().process_frame(frame, direction) + + if isinstance(frame, TTSSpeakFrame): + await self.pause_processing_frames() + elif isinstance(frame, LLMFullResponseEndFrame) and self._started: + await self.pause_processing_frames() + elif isinstance(frame, BotStoppedSpeakingFrame): + await self.resume_processing_frames() + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + """Generate speech from text. + + Args: + text: The text to convert to speech. + + Yields: + Frames containing audio data and timing information. + """ + logger.debug(f"Generating TTS: [{text}]") + try: + if not self._websocket: + await self._connect() + + try: + if not self._started: + await self.start_ttfb_metrics() + yield TTSStartedFrame() + self._started = True + self._cumulative_time = 0 + self._context_id = str(uuid.uuid4()) + + msg = self._build_msg(text=text) + await self._get_websocket().send(json.dumps(msg)) + await self.start_tts_usage_metrics(text) + except Exception as e: + logger.error(f"{self} error sending message: {e}") + yield TTSStoppedFrame() + await self._disconnect() + await self._connect() + return + yield None + except Exception as e: + logger.error(f"{self} exception: {e}") class RimeHttpTTSService(TTSService): @@ -33,7 +357,7 @@ class RimeHttpTTSService(TTSService): *, api_key: str, voice_id: str = "eva", - model: str = "mist", + model: str = "mistv2", sample_rate: Optional[int] = None, params: InputParams = InputParams(), **kwargs,