services: improve Cartesia, 11Labs, PlayHT and LMNT TTS reconnection

This commit is contained in:
Aleix Conchillo Flaqué
2024-12-06 09:25:04 -08:00
parent b05809be2e
commit bafb867ffc
6 changed files with 204 additions and 147 deletions

View File

@@ -68,6 +68,9 @@ async def on_audio_data(processor, audio, sample_rate, num_channels):
### Fixed
- Fixed Cartesia, ElevenLabs, LMNT and PlayHT TTS websocket
reconnection. Before, if an error occurred no reconnection was happening.
- Fixed a `BaseOutputTransport` issue that was causing audio to be discarded
after an `EndFrame` was received.

View File

@@ -184,28 +184,37 @@ class CartesiaTTSService(WordTTSService):
await self._disconnect()
async def _connect(self):
await self._connect_websocket()
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
async def _disconnect(self):
await self._disconnect_websocket()
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
self._receive_task = None
async def _connect_websocket(self):
try:
logger.debug("Connecting to Cartesia")
self._websocket = await websockets.connect(
f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}"
)
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
except Exception as e:
logger.error(f"{self} initialization error: {e}")
self._websocket = None
async def _disconnect(self):
async def _disconnect_websocket(self):
try:
await self.stop_all_metrics()
if self._websocket:
logger.debug("Disconnecting from Cartesia")
await self._websocket.close()
self._websocket = None
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
self._receive_task = None
self._context_id = None
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")
@@ -228,44 +237,51 @@ class CartesiaTTSService(WordTTSService):
await self._websocket.send(msg)
async def _receive_task_handler(self):
try:
async for message in self._get_websocket():
msg = json.loads(message)
if not msg or msg["context_id"] != self._context_id:
continue
if msg["type"] == "done":
await self.stop_ttfb_metrics()
# Unset _context_id but not the _context_id_start_timestamp
# because we are likely still playing out audio and need the
# timestamp to set send context frames.
self._context_id = None
await self.add_word_timestamps(
[("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)]
)
elif msg["type"] == "timestamps":
await self.add_word_timestamps(
list(zip(msg["word_timestamps"]["words"], msg["word_timestamps"]["start"]))
)
elif msg["type"] == "chunk":
await self.stop_ttfb_metrics()
self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["data"]),
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
elif msg["type"] == "error":
logger.error(f"{self} error: {msg}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
else:
logger.error(f"Cartesia error, unknown message type: {msg}")
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"{self} exception: {e}")
while True:
try:
async for message in self._get_websocket():
msg = json.loads(message)
if not msg or msg["context_id"] != self._context_id:
continue
if msg["type"] == "done":
await self.stop_ttfb_metrics()
# Unset _context_id but not the _context_id_start_timestamp
# because we are likely still playing out audio and need the
# timestamp to set send context frames.
self._context_id = None
await self.add_word_timestamps(
[("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)]
)
elif msg["type"] == "timestamps":
await self.add_word_timestamps(
list(
zip(
msg["word_timestamps"]["words"], msg["word_timestamps"]["start"]
)
)
)
elif msg["type"] == "chunk":
await self.stop_ttfb_metrics()
self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["data"]),
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
elif msg["type"] == "error":
logger.error(f"{self} error: {msg}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
else:
logger.error(f"{self} error, unknown message type: {msg}")
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self._disconnect_websocket()
await self._connect_websocket()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -386,8 +402,6 @@ class CartesiaHttpTTSService(TTSService):
_experimental_voice_controls=voice_controls,
)
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=output["audio"],
sample_rate=self._settings["output_format"]["sample_rate"],
@@ -398,4 +412,6 @@ class CartesiaHttpTTSService(TTSService):
logger.error(f"{self} exception: {e}")
await self.start_tts_usage_metrics(text)
await self.stop_ttfb_metrics()
yield TTSStoppedFrame()

View File

@@ -192,15 +192,14 @@ class DeepgramSTTService(STTService):
yield None
async def _connect(self):
if await self._connection.start(self._settings):
logger.info(f"{self}: Connected to Deepgram")
else:
logger.error(f"{self}: Unable to connect to Deepgram")
logger.debug("Connecting to Deepgram")
if not await self._connection.start(self._settings):
logger.error(f"{self}: unable to connect to Deepgram")
async def _disconnect(self):
if self._connection.is_connected:
logger.debug("Disconnecting from Deepgram")
await self._connection.finish()
logger.info(f"{self}: Disconnected from Deepgram")
async def _on_speech_started(self, *args, **kwargs):
await self.start_ttfb_metrics()

View File

@@ -281,7 +281,28 @@ class ElevenLabsTTSService(WordTTSService):
await self.resume_processing_frames()
async def _connect(self):
await self._connect_websocket()
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
self._keepalive_task = self.get_event_loop().create_task(self._keepalive_task_handler())
async def _disconnect(self):
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
self._receive_task = None
if self._keepalive_task:
self._keepalive_task.cancel()
await self._keepalive_task
self._keepalive_task = None
await self._disconnect_websocket()
async def _connect_websocket(self):
try:
logger.debug("Connecting to ElevenLabs")
voice_id = self._voice_id
model = self.model_name
output_format = self._settings["output_format"]
@@ -300,8 +321,6 @@ class ElevenLabsTTSService(WordTTSService):
)
self._websocket = await websockets.connect(url)
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
self._keepalive_task = self.get_event_loop().create_task(self._keepalive_task_handler())
# According to ElevenLabs, we should always start with a single space.
msg: Dict[str, Any] = {
@@ -315,49 +334,42 @@ class ElevenLabsTTSService(WordTTSService):
logger.error(f"{self} initialization error: {e}")
self._websocket = None
async def _disconnect(self):
async def _disconnect_websocket(self):
try:
await self.stop_all_metrics()
if self._websocket:
logger.debug("Disconnecting from ElevenLabs")
await self._websocket.send(json.dumps({"text": ""}))
await self._websocket.close()
self._websocket = None
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
self._receive_task = None
if self._keepalive_task:
self._keepalive_task.cancel()
await self._keepalive_task
self._keepalive_task = None
self._started = False
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")
async def _receive_task_handler(self):
try:
async for message in self._websocket:
msg = json.loads(message)
if msg.get("audio"):
await self.stop_ttfb_metrics()
self.start_word_timestamps()
while True:
try:
async for message in self._websocket:
msg = json.loads(message)
if msg.get("audio"):
await self.stop_ttfb_metrics()
self.start_word_timestamps()
audio = base64.b64decode(msg["audio"])
frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1)
await self.push_frame(frame)
if msg.get("alignment"):
word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
await self.add_word_timestamps(word_times)
self._cumulative_time = word_times[-1][1]
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"{self} exception: {e}")
audio = base64.b64decode(msg["audio"])
frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1)
await self.push_frame(frame)
if msg.get("alignment"):
word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
await self.add_word_timestamps(word_times)
self._cumulative_time = word_times[-1][1]
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self._disconnect_websocket()
await self._connect_websocket()
async def _keepalive_task_handler(self):
while True:

View File

@@ -116,7 +116,22 @@ class LmntTTSService(TTSService):
self._started = False
async def _connect(self):
await self._connect_lmnt()
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
async def _disconnect(self):
await self._disconnect_lmnt()
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
self._receive_task = None
async def _connect_lmnt(self):
try:
logger.debug("Connecting to LMNT")
self._speech = Speech()
self._connection = await self._speech.synthesize_streaming(
self._voice_id,
@@ -124,51 +139,51 @@ class LmntTTSService(TTSService):
sample_rate=self._settings["output_format"]["sample_rate"],
language=self._settings["language"],
)
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
except Exception as e:
logger.exception(f"{self} initialization error: {e}")
logger.error(f"{self} initialization error: {e}")
self._connection = None
async def _disconnect(self):
async def _disconnect_lmnt(self):
try:
await self.stop_all_metrics()
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
self._receive_task = None
if self._connection:
logger.debug("Disconnecting from LMNT")
await self._connection.socket.close()
self._connection = None
if self._speech:
await self._speech.close()
self._speech = None
self._started = False
except Exception as e:
logger.exception(f"{self} error closing websocket: {e}")
logger.error(f"{self} error closing connection: {e}")
async def _receive_task_handler(self):
try:
async for msg in self._connection:
if "error" in msg:
logger.error(f'{self} error: {msg["error"]}')
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
elif "audio" in msg:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=msg["audio"],
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
else:
logger.error(f"LMNT error, unknown message type: {msg}")
except asyncio.CancelledError:
pass
except Exception as e:
logger.exception(f"{self} exception: {e}")
while True:
try:
async for msg in self._connection:
if "error" in msg:
logger.error(f'{self} error: {msg["error"]}')
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
elif "audio" in msg:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=msg["audio"],
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
else:
logger.error(f"{self}: LMNT error, unknown message type: {msg}")
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self._disconnect_lmnt()
await self._connect_lmnt()
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]")
@@ -194,4 +209,4 @@ class LmntTTSService(TTSService):
return
yield None
except Exception as e:
logger.exception(f"{self} exception: {e}")
logger.error(f"{self} exception: {e}")

View File

@@ -145,7 +145,22 @@ class PlayHTTTSService(TTSService):
await self._disconnect()
async def _connect(self):
await self._connect_websocket()
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
async def _disconnect(self):
await self._disconnect_websocket()
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
self._receive_task = None
async def _connect_websocket(self):
try:
logger.debug("Connecting to PlayHT")
if not self._websocket_url:
await self._get_websocket_url()
@@ -153,8 +168,6 @@ class PlayHTTTSService(TTSService):
raise ValueError("WebSocket URL is not a string")
self._websocket = await websockets.connect(self._websocket_url)
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
logger.debug("Connected to TTS WebSocket")
except ValueError as ve:
logger.error(f"{self} initialization error: {ve}")
self._websocket = None
@@ -162,19 +175,15 @@ class PlayHTTTSService(TTSService):
logger.error(f"{self} initialization error: {e}")
self._websocket = None
async def _disconnect(self):
async def _disconnect_websocket(self):
try:
await self.stop_all_metrics()
if self._websocket:
logger.debug("Disconnecting from PlayHT")
await self._websocket.close()
self._websocket = None
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
self._receive_task = None
self._request_id = None
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")
@@ -209,31 +218,34 @@ class PlayHTTTSService(TTSService):
self._request_id = None
async def _receive_task_handler(self):
try:
async for message in self._get_websocket():
if isinstance(message, bytes):
# Skip the WAV header message
if message.startswith(b"RIFF"):
continue
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(message, self._settings["sample_rate"], 1)
await self.push_frame(frame)
else:
logger.debug(f"Received text message: {message}")
try:
msg = json.loads(message)
if "request_id" in msg and msg["request_id"] == self._request_id:
await self.push_frame(TTSStoppedFrame())
self._request_id = None
elif "error" in msg:
logger.error(f"{self} error: {msg}")
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"{self} exception in receive task: {e}")
while True:
try:
async for message in self._get_websocket():
if isinstance(message, bytes):
# Skip the WAV header message
if message.startswith(b"RIFF"):
continue
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(message, self._settings["sample_rate"], 1)
await self.push_frame(frame)
else:
logger.debug(f"Received text message: {message}")
try:
msg = json.loads(message)
if "request_id" in msg and msg["request_id"] == self._request_id:
await self.push_frame(TTSStoppedFrame())
self._request_id = None
elif "error" in msg:
logger.error(f"{self} error: {msg}")
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"{self} exception in receive task: {e}")
await self._disconnect_websocket()
await self._connect_websocket()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -381,4 +393,4 @@ class PlayHTHttpTTSService(TTSService):
yield frame
yield TTSStoppedFrame()
except Exception as e:
logger.exception(f"{self} error generating TTS: {e}")
logger.error(f"{self} error generating TTS: {e}")