Compare commits

...

1 Commits

Author SHA1 Message Date
Aleix Conchillo Flaqué
653fbb7e3e services: fix infinite websocket-bases TTS services retries
Fixes #871
2024-12-16 15:14:22 -08:00
6 changed files with 173 additions and 99 deletions

View File

@@ -5,6 +5,13 @@ All notable changes to **Pipecat** will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.0.51] - 2024-12-16
### Fixed
- Fixed an issue in websocket-based TTS services that was causing infinite
reconnections (Cartesia, ElevenLabs, PlayHT and LMNT).
## [0.0.50] - 2024-12-11 ## [0.0.50] - 2024-12-11
### Added ### Added

View File

@@ -29,6 +29,7 @@ dependencies = [
"pydantic~=2.8.2", "pydantic~=2.8.2",
"pyloudnorm~=0.1.1", "pyloudnorm~=0.1.1",
"resampy~=0.4.3", "resampy~=0.4.3",
"tenacity~=9.0.0"
] ]
[project.urls] [project.urls]
@@ -55,7 +56,7 @@ gstreamer = [ "pygobject~=3.48.2" ]
fireworks = [ "openai~=1.50.2" ] fireworks = [ "openai~=1.50.2" ]
krisp = [ "pipecat-ai-krisp~=0.3.0" ] krisp = [ "pipecat-ai-krisp~=0.3.0" ]
langchain = [ "langchain~=0.2.14", "langchain-community~=0.2.12", "langchain-openai~=0.1.20" ] langchain = [ "langchain~=0.2.14", "langchain-community~=0.2.12", "langchain-openai~=0.1.20" ]
livekit = [ "livekit~=0.17.5", "livekit-api~=0.7.1", "tenacity~=8.5.0" ] livekit = [ "livekit~=0.17.5", "livekit-api~=0.7.1" ]
lmnt = [ "lmnt~=1.1.4" ] lmnt = [ "lmnt~=1.1.4" ]
local = [ "pyaudio~=0.2.14" ] local = [ "pyaudio~=0.2.14" ]
moondream = [ "einops~=0.8.0", "timm~=1.0.8", "transformers~=4.44.0" ] moondream = [ "einops~=0.8.0", "timm~=1.0.8", "transformers~=4.44.0" ]

View File

@@ -12,6 +12,8 @@ from typing import AsyncGenerator, List, Optional, Union
from loguru import logger from loguru import logger
from pydantic import BaseModel from pydantic import BaseModel
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential
from pipecat.frames.frames import ( from pipecat.frames.frames import (
BotStoppedSpeakingFrame, BotStoppedSpeakingFrame,
@@ -239,52 +241,64 @@ class CartesiaTTSService(WordTTSService):
msg = self._build_msg(text="", continue_transcript=False) msg = self._build_msg(text="", continue_transcript=False)
await self._websocket.send(msg) await self._websocket.send(msg)
async def _receive_messages(self):
async for message in self._get_websocket():
msg = json.loads(message)
if not msg or msg["context_id"] != self._context_id:
continue
if msg["type"] == "done":
await self.stop_ttfb_metrics()
# Unset _context_id but not the _context_id_start_timestamp
# because we are likely still playing out audio and need the
# timestamp to set send context frames.
self._context_id = None
await self.add_word_timestamps(
[("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)]
)
elif msg["type"] == "timestamps":
await self.add_word_timestamps(
list(zip(msg["word_timestamps"]["words"], msg["word_timestamps"]["start"]))
)
elif msg["type"] == "chunk":
await self.stop_ttfb_metrics()
self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["data"]),
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
elif msg["type"] == "error":
logger.error(f"{self} error: {msg}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
else:
logger.error(f"{self} error, unknown message type: {msg}")
async def _reconnect_websocket(self, retry_state: RetryCallState):
logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})")
await self._disconnect_websocket()
await self._connect_websocket()
async def _receive_task_handler(self): async def _receive_task_handler(self):
while True: while True:
try: try:
async for message in self._get_websocket(): async for attempt in AsyncRetrying(
msg = json.loads(message) stop=stop_after_attempt(3),
if not msg or msg["context_id"] != self._context_id: wait=wait_exponential(multiplier=1, min=4, max=10),
continue before_sleep=self._reconnect_websocket,
if msg["type"] == "done": reraise=True,
await self.stop_ttfb_metrics() ):
# Unset _context_id but not the _context_id_start_timestamp with attempt:
# because we are likely still playing out audio and need the await self._receive_messages()
# timestamp to set send context frames.
self._context_id = None
await self.add_word_timestamps(
[("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)]
)
elif msg["type"] == "timestamps":
await self.add_word_timestamps(
list(
zip(
msg["word_timestamps"]["words"], msg["word_timestamps"]["start"]
)
)
)
elif msg["type"] == "chunk":
await self.stop_ttfb_metrics()
self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["data"]),
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
elif msg["type"] == "error":
logger.error(f"{self} error: {msg}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
else:
logger.error(f"{self} error, unknown message type: {msg}")
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
logger.error(f"{self} exception: {e}") message = f"{self} error receiving messages: {e}"
await self._disconnect_websocket() logger.error(message)
await self._connect_websocket() await self.push_error(ErrorFrame(message, fatal=True))
break
async def process_frame(self, frame: Frame, direction: FrameDirection): async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction) await super().process_frame(frame, direction)

View File

@@ -11,11 +11,13 @@ from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional,
from loguru import logger from loguru import logger
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential
from pipecat.frames.frames import ( from pipecat.frames.frames import (
BotStoppedSpeakingFrame, BotStoppedSpeakingFrame,
CancelFrame, CancelFrame,
EndFrame, EndFrame,
ErrorFrame,
Frame, Frame,
LLMFullResponseEndFrame, LLMFullResponseEndFrame,
StartFrame, StartFrame,
@@ -348,28 +350,44 @@ class ElevenLabsTTSService(WordTTSService):
except Exception as e: except Exception as e:
logger.error(f"{self} error closing websocket: {e}") logger.error(f"{self} error closing websocket: {e}")
async def _receive_messages(self):
async for message in self._websocket:
msg = json.loads(message)
if msg.get("audio"):
await self.stop_ttfb_metrics()
self.start_word_timestamps()
audio = base64.b64decode(msg["audio"])
frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1)
await self.push_frame(frame)
if msg.get("alignment"):
word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
await self.add_word_timestamps(word_times)
self._cumulative_time = word_times[-1][1]
async def _reconnect_websocket(self, retry_state: RetryCallState):
logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})")
await self._disconnect_websocket()
await self._connect_websocket()
async def _receive_task_handler(self): async def _receive_task_handler(self):
while True: while True:
try: try:
async for message in self._websocket: async for attempt in AsyncRetrying(
msg = json.loads(message) stop=stop_after_attempt(3),
if msg.get("audio"): wait=wait_exponential(multiplier=1, min=4, max=10),
await self.stop_ttfb_metrics() before_sleep=self._reconnect_websocket,
self.start_word_timestamps() reraise=True,
):
audio = base64.b64decode(msg["audio"]) with attempt:
frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1) await self._receive_messages()
await self.push_frame(frame)
if msg.get("alignment"):
word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
await self.add_word_timestamps(word_times)
self._cumulative_time = word_times[-1][1]
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
logger.error(f"{self} exception: {e}") message = f"{self} error receiving messages: {e}"
await self._disconnect_websocket() logger.error(message)
await self._connect_websocket() await self.push_error(ErrorFrame(message, fatal=True))
break
async def _keepalive_task_handler(self): async def _keepalive_task_handler(self):
while True: while True:

View File

@@ -8,6 +8,7 @@ import asyncio
from typing import AsyncGenerator from typing import AsyncGenerator
from loguru import logger from loguru import logger
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential
from pipecat.frames.frames import ( from pipecat.frames.frames import (
CancelFrame, CancelFrame,
@@ -159,31 +160,47 @@ class LmntTTSService(TTSService):
except Exception as e: except Exception as e:
logger.error(f"{self} error closing connection: {e}") logger.error(f"{self} error closing connection: {e}")
async def _receive_messages(self):
async for msg in self._connection:
if "error" in msg:
logger.error(f'{self} error: {msg["error"]}')
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
elif "audio" in msg:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=msg["audio"],
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
else:
logger.error(f"{self}: LMNT error, unknown message type: {msg}")
async def _reconnect_websocket(self, retry_state: RetryCallState):
logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})")
await self._disconnect_lmnt()
await self._connect_lmnt()
async def _receive_task_handler(self): async def _receive_task_handler(self):
while True: while True:
try: try:
async for msg in self._connection: async for attempt in AsyncRetrying(
if "error" in msg: stop=stop_after_attempt(3),
logger.error(f'{self} error: {msg["error"]}') wait=wait_exponential(multiplier=1, min=4, max=10),
await self.push_frame(TTSStoppedFrame()) before_sleep=self._reconnect_websocket,
await self.stop_all_metrics() reraise=True,
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) ):
elif "audio" in msg: with attempt:
await self.stop_ttfb_metrics() await self._receive_messages()
frame = TTSAudioRawFrame(
audio=msg["audio"],
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
else:
logger.error(f"{self}: LMNT error, unknown message type: {msg}")
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
logger.error(f"{self} exception: {e}") message = f"{self} error receiving messages: {e}"
await self._disconnect_lmnt() logger.error(message)
await self._connect_lmnt() await self.push_error(ErrorFrame(message, fatal=True))
break
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]") logger.debug(f"Generating TTS: [{text}]")

View File

@@ -15,6 +15,7 @@ import aiohttp
import websockets import websockets
from loguru import logger from loguru import logger
from pydantic import BaseModel from pydantic import BaseModel
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential
from pipecat.frames.frames import ( from pipecat.frames.frames import (
BotStoppedSpeakingFrame, BotStoppedSpeakingFrame,
@@ -217,35 +218,51 @@ class PlayHTTTSService(TTSService):
await self.stop_all_metrics() await self.stop_all_metrics()
self._request_id = None self._request_id = None
async def _receive_messages(self):
async for message in self._get_websocket():
if isinstance(message, bytes):
# Skip the WAV header message
if message.startswith(b"RIFF"):
continue
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(message, self._settings["sample_rate"], 1)
await self.push_frame(frame)
else:
logger.debug(f"Received text message: {message}")
try:
msg = json.loads(message)
if "request_id" in msg and msg["request_id"] == self._request_id:
await self.push_frame(TTSStoppedFrame())
self._request_id = None
elif "error" in msg:
logger.error(f"{self} error: {msg}")
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")
async def _reconnect_websocket(self, retry_state: RetryCallState):
logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})")
await self._disconnect_websocket()
await self._connect_websocket()
async def _receive_task_handler(self): async def _receive_task_handler(self):
while True: while True:
try: try:
async for message in self._get_websocket(): async for attempt in AsyncRetrying(
if isinstance(message, bytes): stop=stop_after_attempt(3),
# Skip the WAV header message wait=wait_exponential(multiplier=1, min=4, max=10),
if message.startswith(b"RIFF"): before_sleep=self._reconnect_websocket,
continue reraise=True,
await self.stop_ttfb_metrics() ):
frame = TTSAudioRawFrame(message, self._settings["sample_rate"], 1) with attempt:
await self.push_frame(frame) await self._receive_messages()
else:
logger.debug(f"Received text message: {message}")
try:
msg = json.loads(message)
if "request_id" in msg and msg["request_id"] == self._request_id:
await self.push_frame(TTSStoppedFrame())
self._request_id = None
elif "error" in msg:
logger.error(f"{self} error: {msg}")
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
logger.error(f"{self} exception in receive task: {e}") message = f"{self} error receiving messages: {e}"
await self._disconnect_websocket() logger.error(message)
await self._connect_websocket() await self.push_error(ErrorFrame(message, fatal=True))
break
async def process_frame(self, frame: Frame, direction: FrameDirection): async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction) await super().process_frame(frame, direction)