diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py index 429e7b3e4..f263db60d 100644 --- a/src/pipecat/services/cartesia.py +++ b/src/pipecat/services/cartesia.py @@ -27,7 +27,7 @@ from pipecat.frames.frames import ( ) from pipecat.processors.frame_processor import FrameDirection from pipecat.transcriptions.language import Language -from pipecat.services.ai_services import TTSService +from pipecat.services.ai_services import AsyncWordTTSService from loguru import logger @@ -60,7 +60,7 @@ def language_to_cartesia_language(language: Language) -> str | None: return None -class CartesiaTTSService(TTSService): +class CartesiaTTSService(AsyncWordTTSService): def __init__( self, @@ -74,19 +74,17 @@ class CartesiaTTSService(TTSService): sample_rate: int = 16000, language: str = "en", **kwargs): - super().__init__(**kwargs) - # Aggregating sentences still gives cleaner-sounding results and fewer - # artifacts than streaming one word at a time. On average, waiting for - # a full sentence should only "cost" us 15ms or so with GPT-4o or a Llama 3 - # model, and it's worth it for the better audio quality. - self._aggregate_sentences = True - - # we don't want to automatically push LLM response text frames, because the - # context aggregators will add them to the LLM context even if we're - # interrupted. cartesia gives us word-by-word timestamps. we can use those - # to generate text frames ourselves aligned with the playout timing of the audio! - self._push_text_frames = False + # artifacts than streaming one word at a time. On average, waiting for a + # full sentence should only "cost" us 15ms or so with GPT-4o or a Llama + # 3 model, and it's worth it for the better audio quality. + # + # We also don't want to automatically push LLM response text frames, + # because the context aggregators will add them to the LLM context even + # if we're interrupted. Cartesia gives us word-by-word timestamps. We + # can use those to generate text frames ourselves aligned with the + # playout timing of the audio! + super().__init__(aggregate_sentences=True, push_text_frames=False, **kwargs) self._api_key = api_key self._cartesia_version = cartesia_version @@ -102,10 +100,7 @@ class CartesiaTTSService(TTSService): self._websocket = None self._context_id = None - self._context_id_start_timestamp = None - self._timestamped_words_buffer = [] self._receive_task = None - self._context_appending_task = None def can_generate_metrics(self) -> bool: return True @@ -140,7 +135,6 @@ class CartesiaTTSService(TTSService): f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}" ) self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) - self._context_appending_task = self.get_event_loop().create_task(self._context_appending_task_handler()) except Exception as e: logger.exception(f"{self} initialization error: {e}") self._websocket = None @@ -149,10 +143,6 @@ class CartesiaTTSService(TTSService): try: await self.stop_all_metrics() - if self._context_appending_task: - self._context_appending_task.cancel() - await self._context_appending_task - self._context_appending_task = None if self._receive_task: self._receive_task.cancel() await self._receive_task @@ -162,18 +152,14 @@ class CartesiaTTSService(TTSService): self._websocket = None self._context_id = None - self._context_id_start_timestamp = None - self._timestamped_words_buffer = [] except Exception as e: logger.exception(f"{self} error closing websocket: {e}") async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection): await super()._handle_interruption(frame, direction) - self._context_id = None - self._context_id_start_timestamp = None - self._timestamped_words_buffer = [] await self.stop_all_metrics() await self.push_frame(LLMFullResponseEndFrame()) + self._context_id = None async def _receive_task_handler(self): try: @@ -188,16 +174,14 @@ class CartesiaTTSService(TTSService): # because we are likely still playing out audio and need the # timestamp to set send context frames. self._context_id = None - self._timestamped_words_buffer.append(("LLMFullResponseEndFrame", 0)) + await self.add_word_timestamps([("LLMFullResponseEndFrame", 0)]) elif msg["type"] == "timestamps": - # logger.debug(f"TIMESTAMPS: {msg}") - self._timestamped_words_buffer.extend( + await self.add_word_timestamps( list(zip(msg["word_timestamps"]["words"], msg["word_timestamps"]["end"])) ) elif msg["type"] == "chunk": await self.stop_ttfb_metrics() - if not self._context_id_start_timestamp: - self._context_id_start_timestamp = time.time() + self.init_word_timestamps() frame = AudioRawFrame( audio=base64.b64decode(msg["data"]), sample_rate=self._output_format["sample_rate"], @@ -216,27 +200,6 @@ class CartesiaTTSService(TTSService): except Exception as e: logger.exception(f"{self} exception: {e}") - async def _context_appending_task_handler(self): - try: - while True: - await asyncio.sleep(0.1) - if not self._context_id_start_timestamp: - continue - elapsed_seconds = time.time() - self._context_id_start_timestamp - # Pop all words from self._timestamped_words_buffer that are - # older than the elapsed time and print a message about them to - # the console. - while self._timestamped_words_buffer and self._timestamped_words_buffer[0][1] <= elapsed_seconds: - word, timestamp = self._timestamped_words_buffer.pop(0) - if word == "LLMFullResponseEndFrame" and timestamp == 0: - await self.push_frame(LLMFullResponseEndFrame()) - continue - await self.push_frame(TextFrame(word)) - except asyncio.CancelledError: - pass - except Exception as e: - logger.exception(f"{self} exception: {e}") - async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: logger.debug(f"Generating TTS: [{text}]")