Merge pull request #1761 from pipecat-ai/mb/elevenlabs-context-id

Update ElevenLabsTTSService to use the new websocket API
This commit is contained in:
Mark Backman
2025-05-07 17:12:06 -04:00
committed by GitHub
2 changed files with 85 additions and 39 deletions

View File

@@ -28,6 +28,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Updated `ElevenLabsTTSService` to use the beta websocket API
(multi-stream-input). This new API supports context_ids and cancelling those
contexts, which greatly improves interruption handling.
- Observer `on_push_frame()` methods now take a single `FramePushed` argument
instead of multiple arguments.

View File

@@ -7,11 +7,12 @@
import asyncio
import base64
import json
import uuid
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Tuple, Union
import aiohttp
from loguru import logger
from pydantic import BaseModel, model_validator
from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
@@ -26,7 +27,10 @@ from pipecat.frames.frames import (
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import InterruptibleWordTTSService, WordTTSService
from pipecat.services.tts_service import (
AudioContextWordTTSService,
WordTTSService,
)
from pipecat.transcriptions.language import Language
# See .env.example for ElevenLabs configuration needed
@@ -159,10 +163,9 @@ def calculate_word_times(
return word_times
class ElevenLabsTTSService(InterruptibleWordTTSService):
class ElevenLabsTTSService(AudioContextWordTTSService):
class InputParams(BaseModel):
language: Optional[Language] = None
optimize_streaming_latency: Optional[str] = None
stability: Optional[float] = None
similarity_boost: Optional[float] = None
style: Optional[float] = None
@@ -172,16 +175,6 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
enable_ssml_parsing: Optional[bool] = None
enable_logging: Optional[bool] = None
@model_validator(mode="after")
def validate_voice_settings(self):
    """Require `stability` and `similarity_boost` to be supplied together.

    Returns:
        The validated model instance, unchanged.

    Raises:
        ValueError: If exactly one of the two fields is set.
    """
    stability_missing = self.stability is None
    similarity_boost_missing = self.similarity_boost is None
    # Both set or both unset is fine; only the one-without-the-other case is rejected.
    if stability_missing == similarity_boost_missing:
        return self
    raise ValueError(
        "Both 'stability' and 'similarity_boost' must be provided when using voice settings"
    )
def __init__(
self,
*,
@@ -222,7 +215,6 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
"language": self.language_to_service_language(params.language)
if params.language
else None,
"optimize_streaming_latency": params.optimize_streaming_latency,
"stability": params.stability,
"similarity_boost": params.similarity_boost,
"style": params.style,
@@ -242,6 +234,8 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
self._started = False
self._cumulative_time = 0
# Context management for the v1 multi-stream-input (multi-context) websocket API
self._context_id = None
self._receive_task = None
self._keepalive_task = None
@@ -257,15 +251,13 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
async def set_model(self, model: str):
await super().set_model(model)
logger.info(f"Switching TTS model to: [{model}]")
await self._disconnect()
await self._connect()
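# No need to disconnect/reconnect for model changes with multi-context API
# NOTE(review): model_id is passed as a query parameter of the websocket connection URL, so confirm that a model change actually takes effect on an already-open connection.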
async def _update_settings(self, settings: Mapping[str, Any]):
prev_voice = self._voice_id
await super()._update_settings(settings)
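# If voice changes, we don't need to reconnect, just use a new context
# NOTE(review): voice_id is part of the websocket connection URL path, so confirm that a new context on the existing connection really uses the new voice.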
if not prev_voice == self._voice_id:
await self._disconnect()
await self._connect()
logger.info(f"Switching TTS voice to: [{self._voice_id}]")
async def start(self, frame: StartFrame):
@@ -282,8 +274,8 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
await self._disconnect()
async def flush_audio(self):
if self._websocket:
msg = {"text": " ", "flush": True}
if self._websocket and self._context_id:
msg = {"context_id": self._context_id, "flush": True}
await self._websocket.send(json.dumps(msg))
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
@@ -323,10 +315,7 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
voice_id = self._voice_id
model = self.model_name
output_format = self._output_format
url = f"{self._url}/v1/text-to-speech/{voice_id}/stream-input?model_id={model}&output_format={output_format}&auto_mode={self._settings['auto_mode']}"
if self._settings["optimize_streaming_latency"]:
url += f"&optimize_streaming_latency={self._settings['optimize_streaming_latency']}"
url = f"{self._url}/v1/text-to-speech/{voice_id}/multi-stream-input?model_id={model}&output_format={output_format}&auto_mode={self._settings['auto_mode']}"
if self._settings["enable_ssml_parsing"]:
url += f"&enable_ssml_parsing={self._settings['enable_ssml_parsing']}"
@@ -347,14 +336,6 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
# Set max websocket message size to 16MB for large audio responses
self._websocket = await websockets.connect(url, max_size=16 * 1024 * 1024)
# According to ElevenLabs, we should always start with a single space.
msg: Dict[str, Any] = {
"text": " ",
"xi_api_key": self._api_key,
}
if self._voice_settings:
msg["voice_settings"] = self._voice_settings
await self._websocket.send(json.dumps(msg))
except Exception as e:
logger.error(f"{self} initialization error: {e}")
self._websocket = None
@@ -366,12 +347,15 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
if self._websocket:
logger.debug("Disconnecting from ElevenLabs")
await self._websocket.send(json.dumps({"text": ""}))
# Close all contexts and the socket
if self._context_id:
await self._websocket.send(json.dumps({"close_socket": True}))
await self._websocket.close()
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")
finally:
self._started = False
self._context_id = None
self._websocket = None
def _get_websocket(self):
@@ -379,9 +363,35 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
return self._websocket
raise Exception("Websocket not connected")
# Handle an interruption: run the base-class handling first, then ask the
# server to close the active context (if one exists) while leaving the
# websocket itself connected.
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
# Close the current context when interrupted without closing the websocket
if self._context_id and self._websocket:
logger.trace(f"Closing context {self._context_id} due to interruption")
try:
await self._websocket.send(
json.dumps({"context_id": self._context_id, "close_context": True})
)
except Exception as e:
# Best-effort: a failed close_context send is logged and not re-raised.
logger.error(f"Error closing context on interruption: {e}")
# Forget the context and clear the started flag; _send_text creates a new
# context whenever _context_id is unset.
# NOTE(review): indentation is lost in this view, so it is unclear whether
# this reset runs only inside the `if` above or unconditionally — confirm.
self._context_id = None
self._started = False
async def _receive_messages(self):
async for message in self._get_websocket():
msg = json.loads(message)
# Check if this message belongs to the current context
# The default context may return null/None for context_id
received_ctx_id = msg.get("context_id")
if (
self._context_id is not None
and received_ctx_id is not None
and received_ctx_id != self._context_id
):
logger.trace(f"Ignoring message from different context: {received_ctx_id}")
continue
if msg.get("audio"):
await self.stop_ttfb_metrics()
self.start_word_timestamps()
@@ -393,20 +403,45 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
await self.add_word_timestamps(word_times)
self._cumulative_time = word_times[-1][1]
if msg.get("is_final"):
logger.trace(f"Received final message for context {received_ctx_id}")
# Context has finished
if self._context_id == received_ctx_id:
self._context_id = None
self._started = False
async def _keepalive_task_handler(self):
while True:
await asyncio.sleep(10)
try:
await self._send_text("")
# Send an empty message to keep the connection alive
if self._websocket and self._websocket.open:
await self._websocket.send(json.dumps({}))
except websockets.ConnectionClosed as e:
logger.warning(f"{self} keepalive error: {e}")
break
async def _send_text(self, text: str):
if self._websocket:
msg = {"text": text + " "}
await self._websocket.send(json.dumps(msg))
if not self._context_id:
# First message for a new context - need a space to initialize
msg = {"text": " ", "context_id": str(uuid.uuid4()), "xi_api_key": self._api_key}
# Add voice settings only in first message for a context
if self._voice_settings:
msg["voice_settings"] = self._voice_settings
await self._websocket.send(json.dumps(msg))
self._context_id = msg["context_id"]
logger.trace(f"Created new context {self._context_id}")
# Now send the actual text content
msg = {"text": text, "context_id": self._context_id}
await self._websocket.send(json.dumps(msg))
else:
# Continuing with an existing context
msg = {"text": text, "context_id": self._context_id}
await self._websocket.send(json.dumps(msg))
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"{self}: Generating TTS [{text}]")
@@ -416,6 +451,13 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
await self._connect()
try:
# Close previous context if there was one
if self._context_id and not self._started:
await self._websocket.send(
json.dumps({"context_id": self._context_id, "close_context": True})
)
self._context_id = None
if not self._started:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
@@ -427,8 +469,8 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
except Exception as e:
logger.error(f"{self} error sending message: {e}")
yield TTSStoppedFrame()
await self._disconnect()
await self._connect()
self._started = False
self._context_id = None
return
yield None
except Exception as e: