From 8f8a3ae7f9040398595fda7262acde8557ef9449 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Thu, 6 Feb 2025 16:17:38 -0500 Subject: [PATCH 1/5] Add RimeTTSService --- CHANGELOG.md | 3 + .../foundational/07q-interruptible-rime.py | 5 +- pyproject.toml | 1 + src/pipecat/services/rime.py | 322 +++++++++++++++++- 4 files changed, 326 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7869da8b2..48d6773f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `RimeTTSService`, a new `WordTTSService`. Updated the foundational + example `07q-interruptible-rime.py` to use `RimeTTSService`. + - Added support for Groq's Whisper API through the new `GroqSTTService` and OpenAI's Whisper API through the new `OpenAISTTService`. Introduced a new base class `BaseWhisperSTTService` to handle common Whisper API diff --git a/examples/foundational/07q-interruptible-rime.py b/examples/foundational/07q-interruptible-rime.py index 2465e2596..7bbc4cb25 100644 --- a/examples/foundational/07q-interruptible-rime.py +++ b/examples/foundational/07q-interruptible-rime.py @@ -19,7 +19,7 @@ from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.services.openai import OpenAILLMService -from pipecat.services.rime import RimeHttpTTSService +from pipecat.services.rime import RimeTTSService from pipecat.transports.services.daily import DailyParams, DailyTransport load_dotenv(override=True) @@ -44,10 +44,9 @@ async def main(): ), ) - tts = RimeHttpTTSService( + tts = RimeTTSService( api_key=os.getenv("RIME_API_KEY", ""), voice_id="rex", - params=RimeHttpTTSService.InputParams(reduce_latency=True), ) llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") diff --git a/pyproject.toml b/pyproject.toml index 46b64c6ba..8d26eab2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ openai = [ "openai~=1.59.6", "websockets~=13.1", "python-deepcompare~=2.1.0" ] openpipe = [ "openpipe~=4.45.0" ] perplexity = [ "openai~=1.59.6" ] playht = [ "pyht~=0.1.6", "websockets~=13.1" ] +rime = [ "websockets~=13.1" ] riva = [ "nvidia-riva-client~=2.18.0" ] sentry = [ "sentry-sdk~=2.20.0" ] silero = [ "onnxruntime~=1.20.1" ] diff --git a/src/pipecat/services/rime.py b/src/pipecat/services/rime.py index 51bc7253d..8e903b395 100644 --- a/src/pipecat/services/rime.py +++ b/src/pipecat/services/rime.py @@ -4,20 +4,338 @@ # SPDX-License-Identifier: BSD 2-Clause License # -from typing import AsyncGenerator, Optional +import base64 +import json +import uuid +from typing import AsyncGenerator, Optional, Union import aiohttp from loguru import logger from pydantic import BaseModel from pipecat.frames.frames import ( + BotStoppedSpeakingFrame, + CancelFrame, + EndFrame, ErrorFrame, Frame, + LLMFullResponseEndFrame, + StartFrame, + StartInterruptionFrame, TTSAudioRawFrame, + TTSSpeakFrame, TTSStartedFrame, TTSStoppedFrame, ) -from pipecat.services.ai_services import TTSService +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import TTSService, WordTTSService +from pipecat.services.websocket_service import WebsocketService +from pipecat.transcriptions.language import Language + +try: + import websockets +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use Rime, you need to `pip install pipecat-ai[rime]`. Also, set `RIME_API_KEY` environment variable." + ) + raise Exception(f"Missing module: {e}") + + +def language_to_rime_language(language: Language) -> str: + """Convert pipecat Language to Rime language code. + + Args: + language: The pipecat Language enum value. + + Returns: + str: Three-letter language code used by Rime (e.g., 'eng' for English). + """ + LANGUAGE_MAP = { + Language.EN: "eng", + Language.ES: "spa", + } + return LANGUAGE_MAP.get(language, "eng") + + +class RimeTTSService(WordTTSService, WebsocketService): + """Text-to-Speech service using Rime's websocket API. + + Uses Rime's websocket JSON API to convert text to speech with word-level timing + information. Supports interruptions and maintains context across multiple messages + within a turn. + """ + + class InputParams(BaseModel): + """Configuration parameters for Rime TTS service.""" + + language: Optional[Language] = Language.EN + speed_alpha: Optional[float] = 1.0 + reduce_latency: Optional[bool] = False + + def __init__( + self, + *, + api_key: str, + voice_id: str, + url: str = "wss://users-ws.rime.ai/ws2", + model: str = "mistv2", + sample_rate: Optional[int] = None, + params: InputParams = InputParams(), + **kwargs, + ): + """Initialize Rime TTS service. + + Args: + api_key: Rime API key for authentication. + voice_id: ID of the voice to use. + url: Rime websocket API endpoint. + model: Model ID to use for synthesis. + sample_rate: Audio sample rate in Hz. + params: Additional configuration parameters. + """ + # Initialize with parent class settings for proper frame handling + WordTTSService.__init__( + self, + aggregate_sentences=True, + push_text_frames=False, + push_stop_frames=True, + stop_frame_timeout_s=2.0, + sample_rate=sample_rate, + **kwargs, + ) + WebsocketService.__init__(self) + + # Store service configuration + self._api_key = api_key + self._url = url + self._voice_id = voice_id + self._model = model + self._settings = { + "speaker": voice_id, + "modelId": model, + "audioFormat": "pcm", + "samplingRate": sample_rate, + "lang": self.language_to_service_language(params.language) + if params.language + else "eng", + "speedAlpha": params.speed_alpha, + "reduceLatency": params.reduce_latency, + } + + # State tracking + self._context_id = None # Tracks current turn + self._receive_task = None + self._started = False + self._cumulative_time = 0 # Accumulates time across messages + + def can_generate_metrics(self) -> bool: + return True + + def language_to_service_language(self, language: Language) -> str | None: + """Convert pipecat language to Rime language code.""" + return language_to_rime_language(language) + + async def set_model(self, model: str): + """Update the TTS model.""" + self._model = model + await super().set_model(model) + + def _build_msg(self, text: str = "") -> dict: + """Build JSON message for Rime API.""" + return {"text": text, "contextId": self._context_id} + + def _build_clear_msg(self) -> dict: + """Build clear operation message.""" + return {"operation": "clear"} + + def _build_eos_msg(self) -> dict: + """Build end-of-stream operation message.""" + return {"operation": "eos"} + + async def start(self, frame: StartFrame): + """Start the service and establish websocket connection.""" + await super().start(frame) + self._settings["samplingRate"] = self.sample_rate + await self._connect() + + async def stop(self, frame: EndFrame): + """Stop the service and close connection.""" + await super().stop(frame) + await self._disconnect() + + async def cancel(self, frame: CancelFrame): + """Cancel current operation and clean up.""" + await super().cancel(frame) + await self._disconnect() + + async def _connect(self): + """Establish websocket connection and start receive task.""" + await self._connect_websocket() + self._receive_task = self.create_task(self._receive_task_handler(self.push_error)) + + async def _disconnect(self): + """Close websocket connection and clean up tasks.""" + await self._disconnect_websocket() + if self._receive_task: + await self.cancel_task(self._receive_task) + self._receive_task = None + + async def _connect_websocket(self): + """Connect to Rime websocket API with configured settings.""" + try: + settings = {k: str(v) for k, v in self._settings.items() if v is not None} + params = "&".join(f"{k}={v}" for k, v in settings.items()) + url = f"{self._url}?{params}" + headers = {"Authorization": f"Bearer {self._api_key}"} + self._websocket = await websockets.connect(url, extra_headers=headers) + except Exception as e: + logger.error(f"{self} initialization error: {e}") + self._websocket = None + + async def _disconnect_websocket(self): + """Close websocket connection and reset state.""" + try: + await self.stop_all_metrics() + if self._websocket: + await self._websocket.send(json.dumps(self._build_eos_msg())) + await self._websocket.close() + self._websocket = None + self._started = False + self._context_id = None + except Exception as e: + logger.error(f"{self} error closing websocket: {e}") + + def _get_websocket(self): + """Get active websocket connection or raise exception.""" + if self._websocket: + return self._websocket + raise Exception("Websocket not connected") + + async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection): + """Handle interruption by clearing current context.""" + await super()._handle_interruption(frame, direction) + await self.stop_all_metrics() + if self._context_id: + await self._get_websocket().send(json.dumps(self._build_clear_msg())) + self._started = False + self._context_id = None + + async def _receive_messages(self): + """Process incoming websocket messages. + + Handles audio chunks and word timestamps, maintaining proper timing and + text alignment for the current context. + """ + async for message in self._get_websocket(): + msg = json.loads(message) + if not msg or msg["contextId"] != self._context_id: + continue + + if msg["type"] == "chunk": + # Process audio chunk + await self.stop_ttfb_metrics() + self.start_word_timestamps() + frame = TTSAudioRawFrame( + audio=base64.b64decode(msg["data"]), + sample_rate=self.sample_rate, + num_channels=1, + ) + await self.push_frame(frame) + + elif msg["type"] == "timestamps": + # Process word timing information + timestamps = msg.get("word_timestamps", {}) + words = timestamps.get("words", []) + starts = timestamps.get("start", []) + ends = timestamps.get("end", []) + + if words and starts: + word_pairs = [] + for i, (word, start_time, end_time) in enumerate(zip(words, starts, ends)): + if not word.strip(): + continue + + # Adjust timing by adding cumulative time + adjusted_start = start_time + self._cumulative_time + + # Handle spacing and punctuation + is_punctuation = bool(word.strip(",.!?") == "") + if is_punctuation: + # Append punctuation to previous word + if word_pairs: + prev_word, prev_time = word_pairs[-1] + word_pairs[-1] = (prev_word + word, prev_time) + else: + # Add space between words (not before punctuation) + needs_space = word_pairs and not words[i - 1].strip(",.!?") == "" + if needs_space: + word = " " + word + word_pairs.append((word, adjusted_start)) + + if word_pairs: + await self.add_word_timestamps(word_pairs) + self._cumulative_time = ends[-1] + self._cumulative_time + + elif msg["type"] == "error": + logger.error(f"{self} error: {msg}") + await self.push_frame(TTSStoppedFrame()) + await self.stop_all_metrics() + await self.push_error(ErrorFrame(f"{self} error: {msg['message']}")) + + async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): + """Push frame and handle end-of-turn conditions.""" + await super().push_frame(frame, direction) + if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)): + self._started = False + if isinstance(frame, TTSStoppedFrame): + await self.add_word_timestamps([("LLMFullResponseEndFrame", 0), ("Reset", 0)]) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames and manage turn state.""" + await super().process_frame(frame, direction) + + if isinstance(frame, TTSSpeakFrame): + await self.pause_processing_frames() + elif isinstance(frame, LLMFullResponseEndFrame) and self._started: + await self.pause_processing_frames() + elif isinstance(frame, BotStoppedSpeakingFrame): + await self.resume_processing_frames() + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + """Generate speech from text. + + Args: + text: The text to convert to speech. + + Yields: + Frames containing audio data and timing information. + """ + logger.debug(f"Generating TTS: [{text}]") + try: + if not self._websocket: + await self._connect() + + try: + if not self._started: + await self.start_ttfb_metrics() + yield TTSStartedFrame() + self._started = True + self._cumulative_time = 0 + self._context_id = str(uuid.uuid4()) + + msg = self._build_msg(text=text) + await self._get_websocket().send(json.dumps(msg)) + await self.start_tts_usage_metrics(text) + except Exception as e: + logger.error(f"{self} error sending message: {e}") + yield TTSStoppedFrame() + await self._disconnect() + await self._connect() + return + yield None + except Exception as e: + logger.error(f"{self} exception: {e}") class RimeHttpTTSService(TTSService): From 54f64b8dad9ad1f8d41183122003b716addaf691 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Fri, 7 Feb 2025 09:07:18 -0500 Subject: [PATCH 2/5] Code review feedback --- src/pipecat/services/rime.py | 68 +++++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 29 deletions(-) diff --git a/src/pipecat/services/rime.py b/src/pipecat/services/rime.py index 8e903b395..6cdec14dc 100644 --- a/src/pipecat/services/rime.py +++ b/src/pipecat/services/rime.py @@ -115,7 +115,7 @@ class RimeTTSService(WordTTSService, WebsocketService): "speaker": voice_id, "modelId": model, "audioFormat": "pcm", - "samplingRate": sample_rate, + "samplingRate": 0, "lang": self.language_to_service_language(params.language) if params.language else "eng", @@ -184,8 +184,7 @@ class RimeTTSService(WordTTSService, WebsocketService): async def _connect_websocket(self): """Connect to Rime websocket API with configured settings.""" try: - settings = {k: str(v) for k, v in self._settings.items() if v is not None} - params = "&".join(f"{k}={v}" for k, v in settings.items()) + params = "&".join(f"{k}={v}" for k, v in self._settings.items()) url = f"{self._url}?{params}" headers = {"Authorization": f"Bearer {self._api_key}"} self._websocket = await websockets.connect(url, extra_headers=headers) @@ -221,12 +220,43 @@ class RimeTTSService(WordTTSService, WebsocketService): self._started = False self._context_id = None - async def _receive_messages(self): - """Process incoming websocket messages. + def _calculate_word_times(self, words: list, starts: list, ends: list) -> list: + """Calculate word timing pairs with proper spacing and punctuation. - Handles audio chunks and word timestamps, maintaining proper timing and - text alignment for the current context. + Args: + words: List of words from Rime. + starts: List of start times for each word. + ends: List of end times for each word. + + Returns: + List of (word, timestamp) pairs with proper spacing and timing. """ + word_pairs = [] + for i, (word, start_time, end_time) in enumerate(zip(words, starts, ends)): + if not word.strip(): + continue + + # Adjust timing by adding cumulative time + adjusted_start = start_time + self._cumulative_time + + # Handle spacing and punctuation + is_punctuation = bool(word.strip(",.!?") == "") + if is_punctuation: + # Append punctuation to previous word + if word_pairs: + prev_word, prev_time = word_pairs[-1] + word_pairs[-1] = (prev_word + word, prev_time) + else: + # Add space between words (not before punctuation) + needs_space = word_pairs and not words[i - 1].strip(",.!?") == "" + if needs_space: + word = " " + word + word_pairs.append((word, adjusted_start)) + + return word_pairs + + async def _receive_messages(self): + """Process incoming websocket messages.""" async for message in self._get_websocket(): msg = json.loads(message) if not msg or msg["contextId"] != self._context_id: @@ -251,28 +281,8 @@ class RimeTTSService(WordTTSService, WebsocketService): ends = timestamps.get("end", []) if words and starts: - word_pairs = [] - for i, (word, start_time, end_time) in enumerate(zip(words, starts, ends)): - if not word.strip(): - continue - - # Adjust timing by adding cumulative time - adjusted_start = start_time + self._cumulative_time - - # Handle spacing and punctuation - is_punctuation = bool(word.strip(",.!?") == "") - if is_punctuation: - # Append punctuation to previous word - if word_pairs: - prev_word, prev_time = word_pairs[-1] - word_pairs[-1] = (prev_word + word, prev_time) - else: - # Add space between words (not before punctuation) - needs_space = word_pairs and not words[i - 1].strip(",.!?") == "" - if needs_space: - word = " " + word - word_pairs.append((word, adjusted_start)) - + # Calculate word timing pairs + word_pairs = self._calculate_word_times(words, starts, ends) if word_pairs: await self.add_word_timestamps(word_pairs) self._cumulative_time = ends[-1] + self._cumulative_time From 8020db350e89850ddb3e11ab187337e12fbc41f1 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Fri, 7 Feb 2025 09:20:28 -0500 Subject: [PATCH 3/5] Update RimeHttpTTSService to use mistv2 model by default --- CHANGELOG.md | 2 ++ src/pipecat/services/rime.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 48d6773f4..15082f49d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- `RimeHttpTTSService` now uses the `mistv2` model by default. + - Improved error handling in `AzureTTSService` to properly detect and log synthesis cancellation errors. diff --git a/src/pipecat/services/rime.py b/src/pipecat/services/rime.py index 6cdec14dc..caf04df6b 100644 --- a/src/pipecat/services/rime.py +++ b/src/pipecat/services/rime.py @@ -361,7 +361,7 @@ class RimeHttpTTSService(TTSService): *, api_key: str, voice_id: str = "eva", - model: str = "mist", + model: str = "mistv2", sample_rate: Optional[int] = None, params: InputParams = InputParams(), **kwargs, From 97586b132d120e7e4ca4d532dd4b265f53df6e38 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Mon, 10 Feb 2025 18:22:46 -0500 Subject: [PATCH 4/5] Simplify _calculate_word_times --- src/pipecat/services/rime.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/pipecat/services/rime.py b/src/pipecat/services/rime.py index caf04df6b..b05fe5441 100644 --- a/src/pipecat/services/rime.py +++ b/src/pipecat/services/rime.py @@ -229,7 +229,7 @@ class RimeTTSService(WordTTSService, WebsocketService): ends: List of end times for each word. Returns: - List of (word, timestamp) pairs with proper spacing and timing. + List of (word, timestamp) pairs with proper timing. """ word_pairs = [] for i, (word, start_time, end_time) in enumerate(zip(words, starts, ends)): @@ -239,18 +239,12 @@ class RimeTTSService(WordTTSService, WebsocketService): # Adjust timing by adding cumulative time adjusted_start = start_time + self._cumulative_time - # Handle spacing and punctuation + # Handle punctuation by appending to previous word is_punctuation = bool(word.strip(",.!?") == "") - if is_punctuation: - # Append punctuation to previous word - if word_pairs: - prev_word, prev_time = word_pairs[-1] - word_pairs[-1] = (prev_word + word, prev_time) + if is_punctuation and word_pairs: + prev_word, prev_time = word_pairs[-1] + word_pairs[-1] = (prev_word + word, prev_time) else: - # Add space between words (not before punctuation) - needs_space = word_pairs and not words[i - 1].strip(",.!?") == "" - if needs_space: - word = " " + word word_pairs.append((word, adjusted_start)) return word_pairs @@ -259,6 +253,7 @@ class RimeTTSService(WordTTSService, WebsocketService): """Process incoming websocket messages.""" async for message in self._get_websocket(): msg = json.loads(message) + if not msg or msg["contextId"] != self._context_id: continue @@ -286,6 +281,7 @@ class RimeTTSService(WordTTSService, WebsocketService): if word_pairs: await self.add_word_timestamps(word_pairs) self._cumulative_time = ends[-1] + self._cumulative_time + logger.debug(f"Updated cumulative time to: {self._cumulative_time}") elif msg["type"] == "error": logger.error(f"{self} error: {msg}") From 69b0d9035f7b89a538bac029cf926a0b7ed5ec0c Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Tue, 11 Feb 2025 17:44:52 -0500 Subject: [PATCH 5/5] Mark end_time as unused --- src/pipecat/services/rime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pipecat/services/rime.py b/src/pipecat/services/rime.py index b05fe5441..0210f8de6 100644 --- a/src/pipecat/services/rime.py +++ b/src/pipecat/services/rime.py @@ -232,7 +232,7 @@ class RimeTTSService(WordTTSService, WebsocketService): List of (word, timestamp) pairs with proper timing. """ word_pairs = [] - for i, (word, start_time, end_time) in enumerate(zip(words, starts, ends)): + for i, (word, start_time, _) in enumerate(zip(words, starts, ends)): if not word.strip(): continue