From 9e16e3d614ce9ee4eef72dd7d744dbb584c836f6 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Wed, 7 May 2025 11:00:55 -0400 Subject: [PATCH] Update ElevenLabsTTSService to use the new websocket API --- CHANGELOG.md | 4 + src/pipecat/services/elevenlabs/tts.py | 120 +++++++++++++++++-------- 2 files changed, 85 insertions(+), 39 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 101cc7f58..0bde64116 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Updated `ElevenLabsTTSService` to use the beta websocket API + (multi-stream-input). This new API supports context_ids and cancelling those + contexts, which greatly improves interruption handling. + - Observers `on_push_frame()` now take a single argument `FramePushed` instead of multiple arguments. diff --git a/src/pipecat/services/elevenlabs/tts.py b/src/pipecat/services/elevenlabs/tts.py index 0a3d5d0d1..324e8099e 100644 --- a/src/pipecat/services/elevenlabs/tts.py +++ b/src/pipecat/services/elevenlabs/tts.py @@ -7,11 +7,12 @@ import asyncio import base64 import json +import uuid from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Tuple, Union import aiohttp from loguru import logger -from pydantic import BaseModel, model_validator +from pydantic import BaseModel from pipecat.frames.frames import ( CancelFrame, @@ -26,7 +27,10 @@ from pipecat.frames.frames import ( TTSStoppedFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.tts_service import InterruptibleWordTTSService, WordTTSService +from pipecat.services.tts_service import ( + AudioContextWordTTSService, + WordTTSService, +) from pipecat.transcriptions.language import Language # See .env.example for ElevenLabs configuration needed @@ -159,10 +163,9 @@ def calculate_word_times( return word_times -class ElevenLabsTTSService(InterruptibleWordTTSService): +class ElevenLabsTTSService(AudioContextWordTTSService): class InputParams(BaseModel): language: Optional[Language] = None - optimize_streaming_latency: Optional[str] = None stability: Optional[float] = None similarity_boost: Optional[float] = None style: Optional[float] = None @@ -172,16 +175,6 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): enable_ssml_parsing: Optional[bool] = None enable_logging: Optional[bool] = None - @model_validator(mode="after") - def validate_voice_settings(self): - stability = self.stability - similarity_boost = self.similarity_boost - if (stability is None) != (similarity_boost is None): - raise ValueError( - "Both 'stability' and 'similarity_boost' must be provided when using voice settings" - ) - return self - def __init__( self, *, @@ -222,7 +215,6 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): "language": self.language_to_service_language(params.language) if params.language else None, - "optimize_streaming_latency": params.optimize_streaming_latency, "stability": params.stability, "similarity_boost": params.similarity_boost, "style": params.style, @@ -242,6 +234,8 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): self._started = False self._cumulative_time = 0 + # Context management for v1 multi API + self._context_id = None self._receive_task = None self._keepalive_task = None @@ -257,15 +251,13 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): async def set_model(self, model: str): await super().set_model(model) logger.info(f"Switching TTS model to: [{model}]") - await self._disconnect() - await self._connect() + # No need to disconnect/reconnect for model changes with multi-context API async def _update_settings(self, settings: Mapping[str, Any]): prev_voice = self._voice_id await super()._update_settings(settings) + # If voice changes, we don't need to reconnect, just use a new context if not prev_voice == self._voice_id: - await self._disconnect() - await self._connect() logger.info(f"Switching TTS voice to: [{self._voice_id}]") async def start(self, frame: StartFrame): @@ -282,8 +274,8 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): await self._disconnect() async def flush_audio(self): - if self._websocket: - msg = {"text": " ", "flush": True} + if self._websocket and self._context_id: + msg = {"context_id": self._context_id, "flush": True} await self._websocket.send(json.dumps(msg)) async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): @@ -323,10 +315,7 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): voice_id = self._voice_id model = self.model_name output_format = self._output_format - url = f"{self._url}/v1/text-to-speech/{voice_id}/stream-input?model_id={model}&output_format={output_format}&auto_mode={self._settings['auto_mode']}" - - if self._settings["optimize_streaming_latency"]: - url += f"&optimize_streaming_latency={self._settings['optimize_streaming_latency']}" + url = f"{self._url}/v1/text-to-speech/{voice_id}/multi-stream-input?model_id={model}&output_format={output_format}&auto_mode={self._settings['auto_mode']}" if self._settings["enable_ssml_parsing"]: url += f"&enable_ssml_parsing={self._settings['enable_ssml_parsing']}" @@ -347,14 +336,6 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): # Set max websocket message size to 16MB for large audio responses self._websocket = await websockets.connect(url, max_size=16 * 1024 * 1024) - # According to ElevenLabs, we should always start with a single space. - msg: Dict[str, Any] = { - "text": " ", - "xi_api_key": self._api_key, - } - if self._voice_settings: - msg["voice_settings"] = self._voice_settings - await self._websocket.send(json.dumps(msg)) except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None @@ -366,12 +347,15 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): if self._websocket: logger.debug("Disconnecting from ElevenLabs") - await self._websocket.send(json.dumps({"text": ""})) + # Close all contexts and the socket + if self._context_id: + await self._websocket.send(json.dumps({"close_socket": True})) await self._websocket.close() except Exception as e: logger.error(f"{self} error closing websocket: {e}") finally: self._started = False + self._context_id = None self._websocket = None def _get_websocket(self): @@ -379,9 +363,35 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): return self._websocket raise Exception("Websocket not connected") + async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection): + await super()._handle_interruption(frame, direction) + + # Close the current context when interrupted without closing the websocket + if self._context_id and self._websocket: + logger.trace(f"Closing context {self._context_id} due to interruption") + try: + await self._websocket.send( + json.dumps({"context_id": self._context_id, "close_context": True}) + ) + except Exception as e: + logger.error(f"Error closing context on interruption: {e}") + self._context_id = None + self._started = False + async def _receive_messages(self): async for message in self._get_websocket(): msg = json.loads(message) + # Check if this message belongs to the current context + # The default context may return null/None for context_id + received_ctx_id = msg.get("context_id") + if ( + self._context_id is not None + and received_ctx_id is not None + and received_ctx_id != self._context_id + ): + logger.trace(f"Ignoring message from different context: {received_ctx_id}") + continue + if msg.get("audio"): await self.stop_ttfb_metrics() self.start_word_timestamps() @@ -393,20 +403,45 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): word_times = calculate_word_times(msg["alignment"], self._cumulative_time) await self.add_word_timestamps(word_times) self._cumulative_time = word_times[-1][1] + if msg.get("is_final"): + logger.trace(f"Received final message for context {received_ctx_id}") + # Context has finished + if self._context_id == received_ctx_id: + self._context_id = None + self._started = False async def _keepalive_task_handler(self): while True: await asyncio.sleep(10) try: - await self._send_text("") + # Send an empty message to keep the connection alive + if self._websocket and self._websocket.open: + await self._websocket.send(json.dumps({})) except websockets.ConnectionClosed as e: logger.warning(f"{self} keepalive error: {e}") break async def _send_text(self, text: str): if self._websocket: - msg = {"text": text + " "} - await self._websocket.send(json.dumps(msg)) + if not self._context_id: + # First message for a new context - need a space to initialize + msg = {"text": " ", "context_id": str(uuid.uuid4()), "xi_api_key": self._api_key} + + # Add voice settings only in first message for a context + if self._voice_settings: + msg["voice_settings"] = self._voice_settings + + await self._websocket.send(json.dumps(msg)) + self._context_id = msg["context_id"] + logger.trace(f"Created new context {self._context_id}") + + # Now send the actual text content + msg = {"text": text, "context_id": self._context_id} + await self._websocket.send(json.dumps(msg)) + else: + # Continuing with an existing context + msg = {"text": text, "context_id": self._context_id} + await self._websocket.send(json.dumps(msg)) async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: logger.debug(f"{self}: Generating TTS [{text}]") @@ -416,6 +451,13 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): await self._connect() try: + # Close previous context if there was one + if self._context_id and not self._started: + await self._websocket.send( + json.dumps({"context_id": self._context_id, "close_context": True}) + ) + self._context_id = None + if not self._started: await self.start_ttfb_metrics() yield TTSStartedFrame() @@ -427,8 +469,8 @@ class ElevenLabsTTSService(InterruptibleWordTTSService): except Exception as e: logger.error(f"{self} error sending message: {e}") yield TTSStoppedFrame() - await self._disconnect() - await self._connect() + self._started = False + self._context_id = None return yield None except Exception as e: