Merge pull request #1761 from pipecat-ai/mb/elevenlabs-context-id
Update ElevenLabsTTSService to use the new websocket API
This commit is contained in:
@@ -28,6 +28,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated `ElevenLabsTTSService` to use the beta websocket API
|
||||
(multi-stream-input). This new API supports context_ids and cancelling those
|
||||
contexts, which greatly improves interruption handling.
|
||||
|
||||
- Observers `on_push_frame()` now take a single argument `FramePushed` instead
|
||||
of multiple arguments.
|
||||
|
||||
|
||||
@@ -7,11 +7,12 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
@@ -26,7 +27,10 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import InterruptibleWordTTSService, WordTTSService
|
||||
from pipecat.services.tts_service import (
|
||||
AudioContextWordTTSService,
|
||||
WordTTSService,
|
||||
)
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
# See .env.example for ElevenLabs configuration needed
|
||||
@@ -159,10 +163,9 @@ def calculate_word_times(
|
||||
return word_times
|
||||
|
||||
|
||||
class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = None
|
||||
optimize_streaming_latency: Optional[str] = None
|
||||
stability: Optional[float] = None
|
||||
similarity_boost: Optional[float] = None
|
||||
style: Optional[float] = None
|
||||
@@ -172,16 +175,6 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
enable_ssml_parsing: Optional[bool] = None
|
||||
enable_logging: Optional[bool] = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_voice_settings(self):
|
||||
stability = self.stability
|
||||
similarity_boost = self.similarity_boost
|
||||
if (stability is None) != (similarity_boost is None):
|
||||
raise ValueError(
|
||||
"Both 'stability' and 'similarity_boost' must be provided when using voice settings"
|
||||
)
|
||||
return self
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -222,7 +215,6 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else None,
|
||||
"optimize_streaming_latency": params.optimize_streaming_latency,
|
||||
"stability": params.stability,
|
||||
"similarity_boost": params.similarity_boost,
|
||||
"style": params.style,
|
||||
@@ -242,6 +234,8 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
self._started = False
|
||||
self._cumulative_time = 0
|
||||
|
||||
# Context management for v1 multi API
|
||||
self._context_id = None
|
||||
self._receive_task = None
|
||||
self._keepalive_task = None
|
||||
|
||||
@@ -257,15 +251,13 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
async def set_model(self, model: str):
|
||||
await super().set_model(model)
|
||||
logger.info(f"Switching TTS model to: [{model}]")
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
# No need to disconnect/reconnect for model changes with multi-context API
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
prev_voice = self._voice_id
|
||||
await super()._update_settings(settings)
|
||||
# If voice changes, we don't need to reconnect, just use a new context
|
||||
if not prev_voice == self._voice_id:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
logger.info(f"Switching TTS voice to: [{self._voice_id}]")
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
@@ -282,8 +274,8 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
await self._disconnect()
|
||||
|
||||
async def flush_audio(self):
|
||||
if self._websocket:
|
||||
msg = {"text": " ", "flush": True}
|
||||
if self._websocket and self._context_id:
|
||||
msg = {"context_id": self._context_id, "flush": True}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
@@ -323,10 +315,7 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
voice_id = self._voice_id
|
||||
model = self.model_name
|
||||
output_format = self._output_format
|
||||
url = f"{self._url}/v1/text-to-speech/{voice_id}/stream-input?model_id={model}&output_format={output_format}&auto_mode={self._settings['auto_mode']}"
|
||||
|
||||
if self._settings["optimize_streaming_latency"]:
|
||||
url += f"&optimize_streaming_latency={self._settings['optimize_streaming_latency']}"
|
||||
url = f"{self._url}/v1/text-to-speech/{voice_id}/multi-stream-input?model_id={model}&output_format={output_format}&auto_mode={self._settings['auto_mode']}"
|
||||
|
||||
if self._settings["enable_ssml_parsing"]:
|
||||
url += f"&enable_ssml_parsing={self._settings['enable_ssml_parsing']}"
|
||||
@@ -347,14 +336,6 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
# Set max websocket message size to 16MB for large audio responses
|
||||
self._websocket = await websockets.connect(url, max_size=16 * 1024 * 1024)
|
||||
|
||||
# According to ElevenLabs, we should always start with a single space.
|
||||
msg: Dict[str, Any] = {
|
||||
"text": " ",
|
||||
"xi_api_key": self._api_key,
|
||||
}
|
||||
if self._voice_settings:
|
||||
msg["voice_settings"] = self._voice_settings
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -366,12 +347,15 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from ElevenLabs")
|
||||
await self._websocket.send(json.dumps({"text": ""}))
|
||||
# Close all contexts and the socket
|
||||
if self._context_id:
|
||||
await self._websocket.send(json.dumps({"close_socket": True}))
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
finally:
|
||||
self._started = False
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
|
||||
def _get_websocket(self):
|
||||
@@ -379,9 +363,35 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
|
||||
# Close the current context when interrupted without closing the websocket
|
||||
if self._context_id and self._websocket:
|
||||
logger.trace(f"Closing context {self._context_id} due to interruption")
|
||||
try:
|
||||
await self._websocket.send(
|
||||
json.dumps({"context_id": self._context_id, "close_context": True})
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing context on interruption: {e}")
|
||||
self._context_id = None
|
||||
self._started = False
|
||||
|
||||
async def _receive_messages(self):
|
||||
async for message in self._get_websocket():
|
||||
msg = json.loads(message)
|
||||
# Check if this message belongs to the current context
|
||||
# The default context may return null/None for context_id
|
||||
received_ctx_id = msg.get("context_id")
|
||||
if (
|
||||
self._context_id is not None
|
||||
and received_ctx_id is not None
|
||||
and received_ctx_id != self._context_id
|
||||
):
|
||||
logger.trace(f"Ignoring message from different context: {received_ctx_id}")
|
||||
continue
|
||||
|
||||
if msg.get("audio"):
|
||||
await self.stop_ttfb_metrics()
|
||||
self.start_word_timestamps()
|
||||
@@ -393,20 +403,45 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
|
||||
await self.add_word_timestamps(word_times)
|
||||
self._cumulative_time = word_times[-1][1]
|
||||
if msg.get("is_final"):
|
||||
logger.trace(f"Received final message for context {received_ctx_id}")
|
||||
# Context has finished
|
||||
if self._context_id == received_ctx_id:
|
||||
self._context_id = None
|
||||
self._started = False
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
while True:
|
||||
await asyncio.sleep(10)
|
||||
try:
|
||||
await self._send_text("")
|
||||
# Send an empty message to keep the connection alive
|
||||
if self._websocket and self._websocket.open:
|
||||
await self._websocket.send(json.dumps({}))
|
||||
except websockets.ConnectionClosed as e:
|
||||
logger.warning(f"{self} keepalive error: {e}")
|
||||
break
|
||||
|
||||
async def _send_text(self, text: str):
|
||||
if self._websocket:
|
||||
msg = {"text": text + " "}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
if not self._context_id:
|
||||
# First message for a new context - need a space to initialize
|
||||
msg = {"text": " ", "context_id": str(uuid.uuid4()), "xi_api_key": self._api_key}
|
||||
|
||||
# Add voice settings only in first message for a context
|
||||
if self._voice_settings:
|
||||
msg["voice_settings"] = self._voice_settings
|
||||
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
self._context_id = msg["context_id"]
|
||||
logger.trace(f"Created new context {self._context_id}")
|
||||
|
||||
# Now send the actual text content
|
||||
msg = {"text": text, "context_id": self._context_id}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
else:
|
||||
# Continuing with an existing context
|
||||
msg = {"text": text, "context_id": self._context_id}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
@@ -416,6 +451,13 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
# Close previous context if there was one
|
||||
if self._context_id and not self._started:
|
||||
await self._websocket.send(
|
||||
json.dumps({"context_id": self._context_id, "close_context": True})
|
||||
)
|
||||
self._context_id = None
|
||||
|
||||
if not self._started:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
@@ -427,8 +469,8 @@ class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
self._started = False
|
||||
self._context_id = None
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user