services: improve Cartesia, 11Labs, PlayHT and LMNT TTS reconnection
This commit is contained in:
@@ -68,6 +68,9 @@ async def on_audio_data(processor, audio, sample_rate, num_channels):
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed Cartesia, ElevenLabs, LMNT and PlayHT TTS websocket
|
||||
reconnection. Previously, no reconnection was attempted after an error occurred.
|
||||
|
||||
- Fixed a `BaseOutputTransport` issue that was causing audio to be discarded
|
||||
after an `EndFrame` was received.
|
||||
|
||||
|
||||
@@ -184,28 +184,37 @@ class CartesiaTTSService(WordTTSService):
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
    """Open the Cartesia websocket, then start the background receive task.

    The task running ``_receive_task_handler()`` is scheduled on the loop
    returned by ``get_event_loop()`` and stored in ``self._receive_task``.
    """
    await self._connect_websocket()

    loop = self.get_event_loop()
    self._receive_task = loop.create_task(self._receive_task_handler())
|
||||
|
||||
async def _disconnect(self):
    """Stop the receive task, then close the Cartesia websocket.

    The receive task is cancelled and awaited *before* the websocket is
    closed. The receive handler reacts to an exception on the socket by
    disconnecting and reconnecting, so closing the socket while that task is
    still running could trigger a reconnect in the middle of a shutdown.
    """
    try:
        if self._receive_task:
            self._receive_task.cancel()
            await self._receive_task
            self._receive_task = None
    finally:
        # Always close the websocket, even if awaiting the task raised.
        await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
try:
|
||||
logger.debug("Connecting to Cartesia")
|
||||
self._websocket = await websockets.connect(
|
||||
f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}"
|
||||
)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
|
||||
async def _disconnect(self):
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from Cartesia")
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
await self._receive_task
|
||||
self._receive_task = None
|
||||
|
||||
self._context_id = None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
@@ -228,44 +237,51 @@ class CartesiaTTSService(WordTTSService):
|
||||
await self._websocket.send(msg)
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
try:
|
||||
async for message in self._get_websocket():
|
||||
msg = json.loads(message)
|
||||
if not msg or msg["context_id"] != self._context_id:
|
||||
continue
|
||||
if msg["type"] == "done":
|
||||
await self.stop_ttfb_metrics()
|
||||
# Unset _context_id but not the _context_id_start_timestamp
|
||||
# because we are likely still playing out audio and need the
|
||||
# timestamp to set send context frames.
|
||||
self._context_id = None
|
||||
await self.add_word_timestamps(
|
||||
[("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)]
|
||||
)
|
||||
elif msg["type"] == "timestamps":
|
||||
await self.add_word_timestamps(
|
||||
list(zip(msg["word_timestamps"]["words"], msg["word_timestamps"]["start"]))
|
||||
)
|
||||
elif msg["type"] == "chunk":
|
||||
await self.stop_ttfb_metrics()
|
||||
self.start_word_timestamps()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["data"]),
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
elif msg["type"] == "error":
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
|
||||
else:
|
||||
logger.error(f"Cartesia error, unknown message type: {msg}")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
while True:
|
||||
try:
|
||||
async for message in self._get_websocket():
|
||||
msg = json.loads(message)
|
||||
if not msg or msg["context_id"] != self._context_id:
|
||||
continue
|
||||
if msg["type"] == "done":
|
||||
await self.stop_ttfb_metrics()
|
||||
# Unset _context_id but not the _context_id_start_timestamp
|
||||
# because we are likely still playing out audio and need the
|
||||
# timestamp to set send context frames.
|
||||
self._context_id = None
|
||||
await self.add_word_timestamps(
|
||||
[("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)]
|
||||
)
|
||||
elif msg["type"] == "timestamps":
|
||||
await self.add_word_timestamps(
|
||||
list(
|
||||
zip(
|
||||
msg["word_timestamps"]["words"], msg["word_timestamps"]["start"]
|
||||
)
|
||||
)
|
||||
)
|
||||
elif msg["type"] == "chunk":
|
||||
await self.stop_ttfb_metrics()
|
||||
self.start_word_timestamps()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["data"]),
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
elif msg["type"] == "error":
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
|
||||
else:
|
||||
logger.error(f"{self} error, unknown message type: {msg}")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self._disconnect_websocket()
|
||||
await self._connect_websocket()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -386,8 +402,6 @@ class CartesiaHttpTTSService(TTSService):
|
||||
_experimental_voice_controls=voice_controls,
|
||||
)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=output["audio"],
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
@@ -398,4 +412,6 @@ class CartesiaHttpTTSService(TTSService):
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -192,15 +192,14 @@ class DeepgramSTTService(STTService):
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
if await self._connection.start(self._settings):
|
||||
logger.info(f"{self}: Connected to Deepgram")
|
||||
else:
|
||||
logger.error(f"{self}: Unable to connect to Deepgram")
|
||||
logger.debug("Connecting to Deepgram")
|
||||
if not await self._connection.start(self._settings):
|
||||
logger.error(f"{self}: unable to connect to Deepgram")
|
||||
|
||||
async def _disconnect(self):
|
||||
if self._connection.is_connected:
|
||||
logger.debug("Disconnecting from Deepgram")
|
||||
await self._connection.finish()
|
||||
logger.info(f"{self}: Disconnected from Deepgram")
|
||||
|
||||
async def _on_speech_started(self, *args, **kwargs):
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
@@ -281,7 +281,28 @@ class ElevenLabsTTSService(WordTTSService):
|
||||
await self.resume_processing_frames()
|
||||
|
||||
async def _connect(self):
    """Open the ElevenLabs websocket and start both background tasks.

    After the websocket connection attempt, a receive task and a keepalive
    task are created (in that order) and stored in ``self._receive_task``
    and ``self._keepalive_task``.
    """
    await self._connect_websocket()

    receive_loop = self.get_event_loop()
    self._receive_task = receive_loop.create_task(self._receive_task_handler())

    keepalive_loop = self.get_event_loop()
    self._keepalive_task = keepalive_loop.create_task(self._keepalive_task_handler())
|
||||
|
||||
async def _disconnect(self):
    """Cancel the background tasks, then close the ElevenLabs websocket.

    The receive task and then the keepalive task are each cancelled, awaited
    and cleared (when set) before ``_disconnect_websocket()`` is called.
    """
    for task_attr in ("_receive_task", "_keepalive_task"):
        task = getattr(self, task_attr)
        if not task:
            continue
        task.cancel()
        await task
        setattr(self, task_attr, None)

    await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
try:
|
||||
logger.debug("Connecting to ElevenLabs")
|
||||
|
||||
voice_id = self._voice_id
|
||||
model = self.model_name
|
||||
output_format = self._settings["output_format"]
|
||||
@@ -300,8 +321,6 @@ class ElevenLabsTTSService(WordTTSService):
|
||||
)
|
||||
|
||||
self._websocket = await websockets.connect(url)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
self._keepalive_task = self.get_event_loop().create_task(self._keepalive_task_handler())
|
||||
|
||||
# According to ElevenLabs, we should always start with a single space.
|
||||
msg: Dict[str, Any] = {
|
||||
@@ -315,49 +334,42 @@ class ElevenLabsTTSService(WordTTSService):
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
|
||||
async def _disconnect(self):
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from ElevenLabs")
|
||||
await self._websocket.send(json.dumps({"text": ""}))
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
await self._receive_task
|
||||
self._receive_task = None
|
||||
|
||||
if self._keepalive_task:
|
||||
self._keepalive_task.cancel()
|
||||
await self._keepalive_task
|
||||
self._keepalive_task = None
|
||||
|
||||
self._started = False
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
try:
|
||||
async for message in self._websocket:
|
||||
msg = json.loads(message)
|
||||
if msg.get("audio"):
|
||||
await self.stop_ttfb_metrics()
|
||||
self.start_word_timestamps()
|
||||
while True:
|
||||
try:
|
||||
async for message in self._websocket:
|
||||
msg = json.loads(message)
|
||||
if msg.get("audio"):
|
||||
await self.stop_ttfb_metrics()
|
||||
self.start_word_timestamps()
|
||||
|
||||
audio = base64.b64decode(msg["audio"])
|
||||
frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1)
|
||||
await self.push_frame(frame)
|
||||
|
||||
if msg.get("alignment"):
|
||||
word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
|
||||
await self.add_word_timestamps(word_times)
|
||||
self._cumulative_time = word_times[-1][1]
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
audio = base64.b64decode(msg["audio"])
|
||||
frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1)
|
||||
await self.push_frame(frame)
|
||||
if msg.get("alignment"):
|
||||
word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
|
||||
await self.add_word_timestamps(word_times)
|
||||
self._cumulative_time = word_times[-1][1]
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self._disconnect_websocket()
|
||||
await self._connect_websocket()
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
while True:
|
||||
|
||||
@@ -116,7 +116,22 @@ class LmntTTSService(TTSService):
|
||||
self._started = False
|
||||
|
||||
async def _connect(self):
    """Open the LMNT connection, then start the background receive task.

    The task running ``_receive_task_handler()`` is scheduled on the loop
    returned by ``get_event_loop()`` and stored in ``self._receive_task``.
    """
    await self._connect_lmnt()

    loop = self.get_event_loop()
    self._receive_task = loop.create_task(self._receive_task_handler())
|
||||
|
||||
async def _disconnect(self):
    """Stop the receive task, then close the LMNT connection.

    The receive task is cancelled and awaited *before* the connection is
    closed. The receive handler reacts to an exception on the connection by
    disconnecting and reconnecting, so closing it while that task is still
    running could trigger a reconnect in the middle of a shutdown.
    """
    try:
        if self._receive_task:
            self._receive_task.cancel()
            await self._receive_task
            self._receive_task = None
    finally:
        # Always close the LMNT connection, even if awaiting the task raised.
        await self._disconnect_lmnt()
|
||||
|
||||
async def _connect_lmnt(self):
|
||||
try:
|
||||
logger.debug("Connecting to LMNT")
|
||||
|
||||
self._speech = Speech()
|
||||
self._connection = await self._speech.synthesize_streaming(
|
||||
self._voice_id,
|
||||
@@ -124,51 +139,51 @@ class LmntTTSService(TTSService):
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
language=self._settings["language"],
|
||||
)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} initialization error: {e}")
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._connection = None
|
||||
|
||||
async def _disconnect(self):
|
||||
async def _disconnect_lmnt(self):
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
await self._receive_task
|
||||
self._receive_task = None
|
||||
if self._connection:
|
||||
logger.debug("Disconnecting from LMNT")
|
||||
await self._connection.socket.close()
|
||||
self._connection = None
|
||||
if self._speech:
|
||||
await self._speech.close()
|
||||
self._speech = None
|
||||
|
||||
self._started = False
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error closing websocket: {e}")
|
||||
logger.error(f"{self} error closing connection: {e}")
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
try:
|
||||
async for msg in self._connection:
|
||||
if "error" in msg:
|
||||
logger.error(f'{self} error: {msg["error"]}')
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
|
||||
elif "audio" in msg:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=msg["audio"],
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
logger.error(f"LMNT error, unknown message type: {msg}")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
while True:
|
||||
try:
|
||||
async for msg in self._connection:
|
||||
if "error" in msg:
|
||||
logger.error(f'{self} error: {msg["error"]}')
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
|
||||
elif "audio" in msg:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=msg["audio"],
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
logger.error(f"{self}: LMNT error, unknown message type: {msg}")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self._disconnect_lmnt()
|
||||
await self._connect_lmnt()
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
@@ -194,4 +209,4 @@ class LmntTTSService(TTSService):
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
@@ -145,7 +145,22 @@ class PlayHTTTSService(TTSService):
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
    """Open the PlayHT websocket, then start the background receive task.

    The task running ``_receive_task_handler()`` is scheduled on the loop
    returned by ``get_event_loop()`` and stored in ``self._receive_task``.
    """
    await self._connect_websocket()

    loop = self.get_event_loop()
    self._receive_task = loop.create_task(self._receive_task_handler())
|
||||
|
||||
async def _disconnect(self):
    """Stop the receive task, then close the PlayHT websocket.

    The receive task is cancelled and awaited *before* the websocket is
    closed. The receive handler reacts to an exception on the socket by
    disconnecting and reconnecting, so closing the socket while that task is
    still running could trigger a reconnect in the middle of a shutdown.
    """
    try:
        if self._receive_task:
            self._receive_task.cancel()
            await self._receive_task
            self._receive_task = None
    finally:
        # Always close the websocket, even if awaiting the task raised.
        await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
try:
|
||||
logger.debug("Connecting to PlayHT")
|
||||
|
||||
if not self._websocket_url:
|
||||
await self._get_websocket_url()
|
||||
|
||||
@@ -153,8 +168,6 @@ class PlayHTTTSService(TTSService):
|
||||
raise ValueError("WebSocket URL is not a string")
|
||||
|
||||
self._websocket = await websockets.connect(self._websocket_url)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
logger.debug("Connected to TTS WebSocket")
|
||||
except ValueError as ve:
|
||||
logger.error(f"{self} initialization error: {ve}")
|
||||
self._websocket = None
|
||||
@@ -162,19 +175,15 @@ class PlayHTTTSService(TTSService):
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
|
||||
async def _disconnect(self):
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from PlayHT")
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
await self._receive_task
|
||||
self._receive_task = None
|
||||
|
||||
self._request_id = None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
@@ -209,31 +218,34 @@ class PlayHTTTSService(TTSService):
|
||||
self._request_id = None
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
try:
|
||||
async for message in self._get_websocket():
|
||||
if isinstance(message, bytes):
|
||||
# Skip the WAV header message
|
||||
if message.startswith(b"RIFF"):
|
||||
continue
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(message, self._settings["sample_rate"], 1)
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
logger.debug(f"Received text message: {message}")
|
||||
try:
|
||||
msg = json.loads(message)
|
||||
if "request_id" in msg and msg["request_id"] == self._request_id:
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
self._request_id = None
|
||||
elif "error" in msg:
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Invalid JSON message: {message}")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception in receive task: {e}")
|
||||
while True:
|
||||
try:
|
||||
async for message in self._get_websocket():
|
||||
if isinstance(message, bytes):
|
||||
# Skip the WAV header message
|
||||
if message.startswith(b"RIFF"):
|
||||
continue
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(message, self._settings["sample_rate"], 1)
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
logger.debug(f"Received text message: {message}")
|
||||
try:
|
||||
msg = json.loads(message)
|
||||
if "request_id" in msg and msg["request_id"] == self._request_id:
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
self._request_id = None
|
||||
elif "error" in msg:
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Invalid JSON message: {message}")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception in receive task: {e}")
|
||||
await self._disconnect_websocket()
|
||||
await self._connect_websocket()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -381,4 +393,4 @@ class PlayHTHttpTTSService(TTSService):
|
||||
yield frame
|
||||
yield TTSStoppedFrame()
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
logger.error(f"{self} error generating TTS: {e}")
|
||||
|
||||
Reference in New Issue
Block a user