Merge pull request #3729 from pipecat-ai/filipi/elevenlabs_issue
TTS services fixes.
This commit is contained in:
1
changelog/3729.fixed.2.md
Normal file
1
changelog/3729.fixed.2.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed a context ID reuse issue in `ElevenLabsTTSService`, `InworldTTSService`, `RimeTTSService`, `CartesiaTTSService`, `AsyncAITTSService`, and `PlayHTTTSService`. Services now properly reuse the same context ID across multiple `run_tts()` invocations within a single LLM turn, preventing context tracking issues and incorrect lifecycle signaling.
|
||||
1
changelog/3729.fixed.md
Normal file
1
changelog/3729.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed a word timestamp interleaving issue in `ElevenLabsTTSService` when processing multiple sentences within a single LLM turn.
|
||||
@@ -9,6 +9,7 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
@@ -270,6 +271,20 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
def create_context_id(self) -> str:
    """Return the context ID to use for the next TTS request.

    An ID that is already in progress is handed back unchanged; a fresh
    UUID4 string is generated only when no ID is currently set. This
    method does not assign ``self._context_id`` itself.

    Returns:
        The current context ID if one is set, otherwise a new unique
        string identifier.
    """
    # Keep using the active ID while one exists. An interruption (caused
    # by user speech) resets the stored ID, after which a new one is
    # generated here.
    return self._context_id or str(uuid.uuid4())
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio."""
|
||||
if not self._context_id or not self._websocket:
|
||||
@@ -379,13 +394,14 @@ class AsyncAITTSService(AudioContextTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
if not self._context_id:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
self._context_id = context_id
|
||||
if not self.audio_context_available(self._context_id):
|
||||
await self.create_audio_context(self._context_id)
|
||||
|
||||
if not self.audio_context_available(self._context_id):
|
||||
await self.create_audio_context(self._context_id)
|
||||
|
||||
msg = self._build_msg(text=text, force=True, context_id=self._context_id)
|
||||
await self._get_websocket().send(msg)
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import AsyncGenerator, List, Literal, Optional
|
||||
@@ -539,6 +540,20 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
await self._get_websocket().send(cancel_msg)
|
||||
self._context_id = None
|
||||
|
||||
def create_context_id(self) -> str:
    """Provide the context ID for a TTS request, reusing the active one if any.

    When a context ID is already set on the service it is returned as-is;
    otherwise a new UUID4 string is produced. The stored
    ``self._context_id`` is left untouched by this method.

    Returns:
        The existing context ID, or a newly generated unique string
        identifier when none is set.
    """
    # No active ID (e.g. after an interruption from user speech reset
    # it): generate a brand-new one.
    active_id = self._context_id
    if active_id:
        return active_id
    return str(uuid.uuid4())
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio and finalize the current context."""
|
||||
if not self._context_id or not self._websocket:
|
||||
|
||||
@@ -13,6 +13,7 @@ with support for streaming audio, word timestamps, and voice customization.
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
@@ -680,6 +681,20 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
msg = {"text": text, "context_id": self._context_id}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
|
||||
def create_context_id(self) -> str:
    """Return the context ID to use for the next TTS request.

    If a context ID is already in progress it is reused; otherwise a new
    UUID4 string is generated. This method only returns a value and does
    not modify ``self._context_id``.

    Returns:
        The current context ID if one is set, otherwise a new unique
        string identifier.
    """
    # Continue with the current ID while it exists. When an interruption
    # happens (user speech triggers it), the context ID is reset, so the
    # next call generates a new one.
    return self._context_id if self._context_id else str(uuid.uuid4())
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using ElevenLabs' streaming WebSocket API.
|
||||
@@ -698,31 +713,28 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._cumulative_time = 0
|
||||
self._partial_word = ""
|
||||
self._partial_word_start_time = 0.0
|
||||
# If a context ID does not exist, use the provided one.
|
||||
# If an ID exists, that means the Pipeline doesn't allow
|
||||
# user interruptions, so continue using the current ID.
|
||||
# When interruptions are allowed, user speech results in
|
||||
# an interruption, which resets the context ID.
|
||||
if not self._context_id:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._context_id = context_id
|
||||
if not self.audio_context_available(self._context_id):
|
||||
await self.create_audio_context(self._context_id)
|
||||
self._cumulative_time = 0
|
||||
self._partial_word = ""
|
||||
self._partial_word_start_time = 0.0
|
||||
|
||||
# Initialize context with voice settings and pronunciation dictionaries
|
||||
msg = {"text": " ", "context_id": self._context_id}
|
||||
if self._voice_settings:
|
||||
msg["voice_settings"] = self._voice_settings
|
||||
if self._pronunciation_dictionary_locators:
|
||||
msg["pronunciation_dictionary_locators"] = [
|
||||
locator.model_dump() for locator in self._pronunciation_dictionary_locators
|
||||
]
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
logger.trace(f"Created new context {self._context_id}")
|
||||
if not self.audio_context_available(self._context_id):
|
||||
await self.create_audio_context(self._context_id)
|
||||
|
||||
# Initialize context with voice settings and pronunciation dictionaries
|
||||
msg = {"text": " ", "context_id": self._context_id}
|
||||
if self._voice_settings:
|
||||
msg["voice_settings"] = self._voice_settings
|
||||
if self._pronunciation_dictionary_locators:
|
||||
msg["pronunciation_dictionary_locators"] = [
|
||||
locator.model_dump()
|
||||
for locator in self._pronunciation_dictionary_locators
|
||||
]
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
logger.trace(f"Created new context {self._context_id}")
|
||||
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
@@ -924,6 +924,20 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
msg = {"close_context": {}, "contextId": context_id}
|
||||
await self.send_with_retry(json.dumps(msg), self._report_error)
|
||||
|
||||
def create_context_id(self) -> str:
    """Provide the context ID for a TTS request, reusing the active one if any.

    An already-set context ID is returned unchanged; a new UUID4 string
    is generated only when no ID is set. The method does not write to
    ``self._context_id``.

    Returns:
        The existing context ID, or a newly generated unique string
        identifier when none is set.
    """
    # Reuse the in-progress ID when there is one. Interruptions (caused
    # by user speech) reset the stored ID, which leads to a new ID being
    # generated on the following call.
    return self._context_id or str(uuid.uuid4())
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate TTS audio for the given text using the Inworld WebSocket TTS service.
|
||||
@@ -942,10 +956,9 @@ class InworldTTSService(AudioContextWordTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
if not self._context_id:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
self._context_id = context_id
|
||||
logger.trace(f"{self}: Creating new context {self._context_id}")
|
||||
await self.create_audio_context(self._context_id)
|
||||
|
||||
@@ -13,6 +13,7 @@ supporting both WebSocket streaming and HTTP-based synthesis.
|
||||
import io
|
||||
import json
|
||||
import struct
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
@@ -323,6 +324,20 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
def create_context_id(self) -> str:
    """Return the context ID to use for the next TTS request.

    A context ID that is already set is handed back as-is; when none is
    set, a new UUID4 string is generated. ``self._context_id`` is read
    but never assigned here.

    Returns:
        The current context ID if one is set, otherwise a new unique
        string identifier.
    """
    # With no ID in progress (an interruption from user speech resets
    # it), generate a new one; otherwise keep using the current ID.
    existing_id = self._context_id
    if existing_id:
        return existing_id
    return str(uuid.uuid4())
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by stopping metrics and clearing request ID."""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
|
||||
@@ -12,6 +12,7 @@ using Rime's API for streaming and batch audio synthesis.
|
||||
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Mapping, Optional
|
||||
|
||||
import aiohttp
|
||||
@@ -369,6 +370,20 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
|
||||
return word_pairs
|
||||
|
||||
def create_context_id(self) -> str:
    """Provide the context ID for a TTS request, reusing the active one if any.

    If the service already has a context ID it is returned unchanged;
    otherwise a new UUID4 string is generated. This method does not
    store the generated value on the instance.

    Returns:
        The existing context ID, or a newly generated unique string
        identifier when none is set.
    """
    # Keep the current ID while one exists. When user speech causes an
    # interruption the context ID is reset, and a new ID is generated
    # on the next call.
    return self._context_id if self._context_id else str(uuid.uuid4())
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio synthesis."""
|
||||
if not self._context_id or not self._websocket:
|
||||
|
||||
Reference in New Issue
Block a user