From bafb867ffc81bb786b55857d690dd9825facd2e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 6 Dec 2024 09:25:04 -0800 Subject: [PATCH] services: improve Cartesia, 11Labs, PlayHT and LMNT TTS reconnection --- CHANGELOG.md | 3 + src/pipecat/services/cartesia.py | 110 +++++++++++++++++------------ src/pipecat/services/deepgram.py | 9 ++- src/pipecat/services/elevenlabs.py | 74 +++++++++++-------- src/pipecat/services/lmnt.py | 75 ++++++++++++-------- src/pipecat/services/playht.py | 80 ++++++++++++--------- 6 files changed, 204 insertions(+), 147 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e7939bf64..ef47a6f82 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -68,6 +68,9 @@ async def on_audio_data(processor, audio, sample_rate, num_channels): ### Fixed +- Fixed Cartesia, ElevenLabs, LMNT and PlayHT TTS websocket + reconnection. Previously, no reconnection was attempted after an error. + - Fixed a `BaseOutputTransport` issue that was causing audio to be discarded after an `EndFrame` was received. 
diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py index 1d2e6e9a1..68d9f666d 100644 --- a/src/pipecat/services/cartesia.py +++ b/src/pipecat/services/cartesia.py @@ -184,28 +184,37 @@ class CartesiaTTSService(WordTTSService): await self._disconnect() async def _connect(self): + await self._connect_websocket() + + self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) + + async def _disconnect(self): + await self._disconnect_websocket() + + if self._receive_task: + self._receive_task.cancel() + await self._receive_task + self._receive_task = None + + async def _connect_websocket(self): try: + logger.debug("Connecting to Cartesia") self._websocket = await websockets.connect( f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}" ) - self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None - async def _disconnect(self): + async def _disconnect_websocket(self): try: await self.stop_all_metrics() if self._websocket: + logger.debug("Disconnecting from Cartesia") await self._websocket.close() self._websocket = None - if self._receive_task: - self._receive_task.cancel() - await self._receive_task - self._receive_task = None - self._context_id = None except Exception as e: logger.error(f"{self} error closing websocket: {e}") @@ -228,44 +237,51 @@ class CartesiaTTSService(WordTTSService): await self._websocket.send(msg) async def _receive_task_handler(self): - try: - async for message in self._get_websocket(): - msg = json.loads(message) - if not msg or msg["context_id"] != self._context_id: - continue - if msg["type"] == "done": - await self.stop_ttfb_metrics() - # Unset _context_id but not the _context_id_start_timestamp - # because we are likely still playing out audio and need the - # timestamp to set send context frames. 
- self._context_id = None - await self.add_word_timestamps( - [("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)] - ) - elif msg["type"] == "timestamps": - await self.add_word_timestamps( - list(zip(msg["word_timestamps"]["words"], msg["word_timestamps"]["start"])) - ) - elif msg["type"] == "chunk": - await self.stop_ttfb_metrics() - self.start_word_timestamps() - frame = TTSAudioRawFrame( - audio=base64.b64decode(msg["data"]), - sample_rate=self._settings["output_format"]["sample_rate"], - num_channels=1, - ) - await self.push_frame(frame) - elif msg["type"] == "error": - logger.error(f"{self} error: {msg}") - await self.push_frame(TTSStoppedFrame()) - await self.stop_all_metrics() - await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) - else: - logger.error(f"Cartesia error, unknown message type: {msg}") - except asyncio.CancelledError: - pass - except Exception as e: - logger.error(f"{self} exception: {e}") + while True: + try: + async for message in self._get_websocket(): + msg = json.loads(message) + if not msg or msg["context_id"] != self._context_id: + continue + if msg["type"] == "done": + await self.stop_ttfb_metrics() + # Unset _context_id but not the _context_id_start_timestamp + # because we are likely still playing out audio and need the + # timestamp to set send context frames. 
+ self._context_id = None + await self.add_word_timestamps( + [("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)] + ) + elif msg["type"] == "timestamps": + await self.add_word_timestamps( + list( + zip( + msg["word_timestamps"]["words"], msg["word_timestamps"]["start"] + ) + ) + ) + elif msg["type"] == "chunk": + await self.stop_ttfb_metrics() + self.start_word_timestamps() + frame = TTSAudioRawFrame( + audio=base64.b64decode(msg["data"]), + sample_rate=self._settings["output_format"]["sample_rate"], + num_channels=1, + ) + await self.push_frame(frame) + elif msg["type"] == "error": + logger.error(f"{self} error: {msg}") + await self.push_frame(TTSStoppedFrame()) + await self.stop_all_metrics() + await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) + else: + logger.error(f"{self} error, unknown message type: {msg}") + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"{self} exception: {e}") + await self._disconnect_websocket() + await self._connect_websocket() async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -386,8 +402,6 @@ class CartesiaHttpTTSService(TTSService): _experimental_voice_controls=voice_controls, ) - await self.stop_ttfb_metrics() - frame = TTSAudioRawFrame( audio=output["audio"], sample_rate=self._settings["output_format"]["sample_rate"], @@ -398,4 +412,6 @@ class CartesiaHttpTTSService(TTSService): logger.error(f"{self} exception: {e}") await self.start_tts_usage_metrics(text) + + await self.stop_ttfb_metrics() yield TTSStoppedFrame() diff --git a/src/pipecat/services/deepgram.py b/src/pipecat/services/deepgram.py index f322211d8..6578f2873 100644 --- a/src/pipecat/services/deepgram.py +++ b/src/pipecat/services/deepgram.py @@ -192,15 +192,14 @@ class DeepgramSTTService(STTService): yield None async def _connect(self): - if await self._connection.start(self._settings): - logger.info(f"{self}: Connected to Deepgram") 
- else: - logger.error(f"{self}: Unable to connect to Deepgram") + logger.debug("Connecting to Deepgram") + if not await self._connection.start(self._settings): + logger.error(f"{self}: unable to connect to Deepgram") async def _disconnect(self): if self._connection.is_connected: + logger.debug("Disconnecting from Deepgram") await self._connection.finish() - logger.info(f"{self}: Disconnected from Deepgram") async def _on_speech_started(self, *args, **kwargs): await self.start_ttfb_metrics() diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs.py index 2707df92a..87d48b4fa 100644 --- a/src/pipecat/services/elevenlabs.py +++ b/src/pipecat/services/elevenlabs.py @@ -281,7 +281,28 @@ class ElevenLabsTTSService(WordTTSService): await self.resume_processing_frames() async def _connect(self): + await self._connect_websocket() + + self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) + self._keepalive_task = self.get_event_loop().create_task(self._keepalive_task_handler()) + + async def _disconnect(self): + if self._receive_task: + self._receive_task.cancel() + await self._receive_task + self._receive_task = None + + if self._keepalive_task: + self._keepalive_task.cancel() + await self._keepalive_task + self._keepalive_task = None + + await self._disconnect_websocket() + + async def _connect_websocket(self): try: + logger.debug("Connecting to ElevenLabs") + voice_id = self._voice_id model = self.model_name output_format = self._settings["output_format"] @@ -300,8 +321,6 @@ class ElevenLabsTTSService(WordTTSService): ) self._websocket = await websockets.connect(url) - self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) - self._keepalive_task = self.get_event_loop().create_task(self._keepalive_task_handler()) # According to ElevenLabs, we should always start with a single space. 
msg: Dict[str, Any] = { @@ -315,49 +334,42 @@ class ElevenLabsTTSService(WordTTSService): logger.error(f"{self} initialization error: {e}") self._websocket = None - async def _disconnect(self): + async def _disconnect_websocket(self): try: await self.stop_all_metrics() if self._websocket: + logger.debug("Disconnecting from ElevenLabs") await self._websocket.send(json.dumps({"text": ""})) await self._websocket.close() self._websocket = None - if self._receive_task: - self._receive_task.cancel() - await self._receive_task - self._receive_task = None - - if self._keepalive_task: - self._keepalive_task.cancel() - await self._keepalive_task - self._keepalive_task = None - self._started = False except Exception as e: logger.error(f"{self} error closing websocket: {e}") async def _receive_task_handler(self): - try: - async for message in self._websocket: - msg = json.loads(message) - if msg.get("audio"): - await self.stop_ttfb_metrics() - self.start_word_timestamps() + while True: + try: + async for message in self._websocket: + msg = json.loads(message) + if msg.get("audio"): + await self.stop_ttfb_metrics() + self.start_word_timestamps() - audio = base64.b64decode(msg["audio"]) - frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1) - await self.push_frame(frame) - - if msg.get("alignment"): - word_times = calculate_word_times(msg["alignment"], self._cumulative_time) - await self.add_word_timestamps(word_times) - self._cumulative_time = word_times[-1][1] - except asyncio.CancelledError: - pass - except Exception as e: - logger.error(f"{self} exception: {e}") + audio = base64.b64decode(msg["audio"]) + frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1) + await self.push_frame(frame) + if msg.get("alignment"): + word_times = calculate_word_times(msg["alignment"], self._cumulative_time) + await self.add_word_timestamps(word_times) + self._cumulative_time = word_times[-1][1] + except asyncio.CancelledError: + break + except Exception as e: + 
logger.error(f"{self} exception: {e}") + await self._disconnect_websocket() + await self._connect_websocket() async def _keepalive_task_handler(self): while True: diff --git a/src/pipecat/services/lmnt.py b/src/pipecat/services/lmnt.py index a5c929094..04223a1a1 100644 --- a/src/pipecat/services/lmnt.py +++ b/src/pipecat/services/lmnt.py @@ -116,7 +116,22 @@ class LmntTTSService(TTSService): self._started = False async def _connect(self): + await self._connect_lmnt() + + self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) + + async def _disconnect(self): + await self._disconnect_lmnt() + + if self._receive_task: + self._receive_task.cancel() + await self._receive_task + self._receive_task = None + + async def _connect_lmnt(self): try: + logger.debug("Connecting to LMNT") + self._speech = Speech() self._connection = await self._speech.synthesize_streaming( self._voice_id, @@ -124,51 +139,51 @@ class LmntTTSService(TTSService): sample_rate=self._settings["output_format"]["sample_rate"], language=self._settings["language"], ) - self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) except Exception as e: - logger.exception(f"{self} initialization error: {e}") + logger.error(f"{self} initialization error: {e}") self._connection = None - async def _disconnect(self): + async def _disconnect_lmnt(self): try: await self.stop_all_metrics() - if self._receive_task: - self._receive_task.cancel() - await self._receive_task - self._receive_task = None if self._connection: + logger.debug("Disconnecting from LMNT") await self._connection.socket.close() self._connection = None if self._speech: await self._speech.close() self._speech = None + self._started = False except Exception as e: - logger.exception(f"{self} error closing websocket: {e}") + logger.error(f"{self} error closing connection: {e}") async def _receive_task_handler(self): - try: - async for msg in self._connection: - if "error" in msg: - 
logger.error(f'{self} error: {msg["error"]}') - await self.push_frame(TTSStoppedFrame()) - await self.stop_all_metrics() - await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) - elif "audio" in msg: - await self.stop_ttfb_metrics() - frame = TTSAudioRawFrame( - audio=msg["audio"], - sample_rate=self._settings["output_format"]["sample_rate"], - num_channels=1, - ) - await self.push_frame(frame) - else: - logger.error(f"LMNT error, unknown message type: {msg}") - except asyncio.CancelledError: - pass - except Exception as e: - logger.exception(f"{self} exception: {e}") + while True: + try: + async for msg in self._connection: + if "error" in msg: + logger.error(f'{self} error: {msg["error"]}') + await self.push_frame(TTSStoppedFrame()) + await self.stop_all_metrics() + await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) + elif "audio" in msg: + await self.stop_ttfb_metrics() + frame = TTSAudioRawFrame( + audio=msg["audio"], + sample_rate=self._settings["output_format"]["sample_rate"], + num_channels=1, + ) + await self.push_frame(frame) + else: + logger.error(f"{self}: LMNT error, unknown message type: {msg}") + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"{self} exception: {e}") + await self._disconnect_lmnt() + await self._connect_lmnt() async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: logger.debug(f"Generating TTS: [{text}]") @@ -194,4 +209,4 @@ class LmntTTSService(TTSService): return yield None except Exception as e: - logger.exception(f"{self} exception: {e}") + logger.error(f"{self} exception: {e}") diff --git a/src/pipecat/services/playht.py b/src/pipecat/services/playht.py index 9b5890554..d272ee361 100644 --- a/src/pipecat/services/playht.py +++ b/src/pipecat/services/playht.py @@ -145,7 +145,22 @@ class PlayHTTTSService(TTSService): await self._disconnect() async def _connect(self): + await self._connect_websocket() + + self._receive_task = 
self.get_event_loop().create_task(self._receive_task_handler()) + + async def _disconnect(self): + await self._disconnect_websocket() + + if self._receive_task: + self._receive_task.cancel() + await self._receive_task + self._receive_task = None + + async def _connect_websocket(self): try: + logger.debug("Connecting to PlayHT") + if not self._websocket_url: await self._get_websocket_url() @@ -153,8 +168,6 @@ class PlayHTTTSService(TTSService): raise ValueError("WebSocket URL is not a string") self._websocket = await websockets.connect(self._websocket_url) - self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) - logger.debug("Connected to TTS WebSocket") except ValueError as ve: logger.error(f"{self} initialization error: {ve}") self._websocket = None @@ -162,19 +175,15 @@ class PlayHTTTSService(TTSService): logger.error(f"{self} initialization error: {e}") self._websocket = None - async def _disconnect(self): + async def _disconnect_websocket(self): try: await self.stop_all_metrics() if self._websocket: + logger.debug("Disconnecting from PlayHT") await self._websocket.close() self._websocket = None - if self._receive_task: - self._receive_task.cancel() - await self._receive_task - self._receive_task = None - self._request_id = None except Exception as e: logger.error(f"{self} error closing websocket: {e}") @@ -209,31 +218,34 @@ class PlayHTTTSService(TTSService): self._request_id = None async def _receive_task_handler(self): - try: - async for message in self._get_websocket(): - if isinstance(message, bytes): - # Skip the WAV header message - if message.startswith(b"RIFF"): - continue - await self.stop_ttfb_metrics() - frame = TTSAudioRawFrame(message, self._settings["sample_rate"], 1) - await self.push_frame(frame) - else: - logger.debug(f"Received text message: {message}") - try: - msg = json.loads(message) - if "request_id" in msg and msg["request_id"] == self._request_id: - await self.push_frame(TTSStoppedFrame()) - 
self._request_id = None - elif "error" in msg: - logger.error(f"{self} error: {msg}") - await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) - except json.JSONDecodeError: - logger.error(f"Invalid JSON message: {message}") - except asyncio.CancelledError: - pass - except Exception as e: - logger.error(f"{self} exception in receive task: {e}") + while True: + try: + async for message in self._get_websocket(): + if isinstance(message, bytes): + # Skip the WAV header message + if message.startswith(b"RIFF"): + continue + await self.stop_ttfb_metrics() + frame = TTSAudioRawFrame(message, self._settings["sample_rate"], 1) + await self.push_frame(frame) + else: + logger.debug(f"Received text message: {message}") + try: + msg = json.loads(message) + if "request_id" in msg and msg["request_id"] == self._request_id: + await self.push_frame(TTSStoppedFrame()) + self._request_id = None + elif "error" in msg: + logger.error(f"{self} error: {msg}") + await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) + except json.JSONDecodeError: + logger.error(f"Invalid JSON message: {message}") + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"{self} exception in receive task: {e}") + await self._disconnect_websocket() + await self._connect_websocket() async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -381,4 +393,4 @@ class PlayHTHttpTTSService(TTSService): yield frame yield TTSStoppedFrame() except Exception as e: - logger.exception(f"{self} error generating TTS: {e}") + logger.error(f"{self} error generating TTS: {e}")