services(cartesia): change to subclass of AsyncWordTTSService

This commit is contained in:
Aleix Conchillo Flaqué
2024-09-12 00:19:37 -07:00
parent 02d926e9bd
commit 80f6d74e80

View File

@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.transcriptions.language import Language
from pipecat.services.ai_services import TTSService
from pipecat.services.ai_services import AsyncWordTTSService
from loguru import logger
@@ -60,7 +60,7 @@ def language_to_cartesia_language(language: Language) -> str | None:
return None
class CartesiaTTSService(TTSService):
class CartesiaTTSService(AsyncWordTTSService):
def __init__(
self,
@@ -74,19 +74,17 @@ class CartesiaTTSService(TTSService):
sample_rate: int = 16000,
language: str = "en",
**kwargs):
super().__init__(**kwargs)
# Aggregating sentences still gives cleaner-sounding results and fewer
# artifacts than streaming one word at a time. On average, waiting for
# a full sentence should only "cost" us 15ms or so with GPT-4o or a Llama 3
# model, and it's worth it for the better audio quality.
self._aggregate_sentences = True
# we don't want to automatically push LLM response text frames, because the
# context aggregators will add them to the LLM context even if we're
# interrupted. cartesia gives us word-by-word timestamps. we can use those
# to generate text frames ourselves aligned with the playout timing of the audio!
self._push_text_frames = False
# artifacts than streaming one word at a time. On average, waiting for a
# full sentence should only "cost" us 15ms or so with GPT-4o or a Llama
# 3 model, and it's worth it for the better audio quality.
#
# We also don't want to automatically push LLM response text frames,
# because the context aggregators will add them to the LLM context even
# if we're interrupted. Cartesia gives us word-by-word timestamps. We
# can use those to generate text frames ourselves aligned with the
# playout timing of the audio!
super().__init__(aggregate_sentences=True, push_text_frames=False, **kwargs)
self._api_key = api_key
self._cartesia_version = cartesia_version
@@ -102,10 +100,7 @@ class CartesiaTTSService(TTSService):
self._websocket = None
self._context_id = None
self._context_id_start_timestamp = None
self._timestamped_words_buffer = []
self._receive_task = None
self._context_appending_task = None
def can_generate_metrics(self) -> bool:
return True
@@ -140,7 +135,6 @@ class CartesiaTTSService(TTSService):
f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}"
)
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
self._context_appending_task = self.get_event_loop().create_task(self._context_appending_task_handler())
except Exception as e:
logger.exception(f"{self} initialization error: {e}")
self._websocket = None
@@ -149,10 +143,6 @@ class CartesiaTTSService(TTSService):
try:
await self.stop_all_metrics()
if self._context_appending_task:
self._context_appending_task.cancel()
await self._context_appending_task
self._context_appending_task = None
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
@@ -162,18 +152,14 @@ class CartesiaTTSService(TTSService):
self._websocket = None
self._context_id = None
self._context_id_start_timestamp = None
self._timestamped_words_buffer = []
except Exception as e:
logger.exception(f"{self} error closing websocket: {e}")
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
self._context_id = None
self._context_id_start_timestamp = None
self._timestamped_words_buffer = []
await self.stop_all_metrics()
await self.push_frame(LLMFullResponseEndFrame())
self._context_id = None
async def _receive_task_handler(self):
try:
@@ -188,16 +174,14 @@ class CartesiaTTSService(TTSService):
# because we are likely still playing out audio and need the
# timestamp to send context frames.
self._context_id = None
self._timestamped_words_buffer.append(("LLMFullResponseEndFrame", 0))
await self.add_word_timestamps([("LLMFullResponseEndFrame", 0)])
elif msg["type"] == "timestamps":
# logger.debug(f"TIMESTAMPS: {msg}")
self._timestamped_words_buffer.extend(
await self.add_word_timestamps(
list(zip(msg["word_timestamps"]["words"], msg["word_timestamps"]["end"]))
)
elif msg["type"] == "chunk":
await self.stop_ttfb_metrics()
if not self._context_id_start_timestamp:
self._context_id_start_timestamp = time.time()
self.init_word_timestamps()
frame = AudioRawFrame(
audio=base64.b64decode(msg["data"]),
sample_rate=self._output_format["sample_rate"],
@@ -216,27 +200,6 @@ class CartesiaTTSService(TTSService):
except Exception as e:
logger.exception(f"{self} exception: {e}")
async def _context_appending_task_handler(self):
try:
while True:
await asyncio.sleep(0.1)
if not self._context_id_start_timestamp:
continue
elapsed_seconds = time.time() - self._context_id_start_timestamp
# Pop all entries from self._timestamped_words_buffer whose timestamp is
# at or before the elapsed time and push a frame for each one: an
# LLMFullResponseEndFrame for the end marker, otherwise a TextFrame.
while self._timestamped_words_buffer and self._timestamped_words_buffer[0][1] <= elapsed_seconds:
word, timestamp = self._timestamped_words_buffer.pop(0)
if word == "LLMFullResponseEndFrame" and timestamp == 0:
await self.push_frame(LLMFullResponseEndFrame())
continue
await self.push_frame(TextFrame(word))
except asyncio.CancelledError:
pass
except Exception as e:
logger.exception(f"{self} exception: {e}")
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]")