services(cartesia): change to subclass of AsyncWordTTSService
This commit is contained in:
@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.services.ai_services import TTSService
|
||||
from pipecat.services.ai_services import AsyncWordTTSService
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -60,7 +60,7 @@ def language_to_cartesia_language(language: Language) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
class CartesiaTTSService(TTSService):
|
||||
class CartesiaTTSService(AsyncWordTTSService):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -74,19 +74,17 @@ class CartesiaTTSService(TTSService):
|
||||
sample_rate: int = 16000,
|
||||
language: str = "en",
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Aggregating sentences still gives cleaner-sounding results and fewer
|
||||
# artifacts than streaming one word at a time. On average, waiting for
|
||||
# a full sentence should only "cost" us 15ms or so with GPT-4o or a Llama 3
|
||||
# model, and it's worth it for the better audio quality.
|
||||
self._aggregate_sentences = True
|
||||
|
||||
# we don't want to automatically push LLM response text frames, because the
|
||||
# context aggregators will add them to the LLM context even if we're
|
||||
# interrupted. cartesia gives us word-by-word timestamps. we can use those
|
||||
# to generate text frames ourselves aligned with the playout timing of the audio!
|
||||
self._push_text_frames = False
|
||||
# artifacts than streaming one word at a time. On average, waiting for a
|
||||
# full sentence should only "cost" us 15ms or so with GPT-4o or a Llama
|
||||
# 3 model, and it's worth it for the better audio quality.
|
||||
#
|
||||
# We also don't want to automatically push LLM response text frames,
|
||||
# because the context aggregators will add them to the LLM context even
|
||||
# if we're interrupted. Cartesia gives us word-by-word timestamps. We
|
||||
# can use those to generate text frames ourselves aligned with the
|
||||
# playout timing of the audio!
|
||||
super().__init__(aggregate_sentences=True, push_text_frames=False, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._cartesia_version = cartesia_version
|
||||
@@ -102,10 +100,7 @@ class CartesiaTTSService(TTSService):
|
||||
|
||||
self._websocket = None
|
||||
self._context_id = None
|
||||
self._context_id_start_timestamp = None
|
||||
self._timestamped_words_buffer = []
|
||||
self._receive_task = None
|
||||
self._context_appending_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
@@ -140,7 +135,6 @@ class CartesiaTTSService(TTSService):
|
||||
f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}"
|
||||
)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
self._context_appending_task = self.get_event_loop().create_task(self._context_appending_task_handler())
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -149,10 +143,6 @@ class CartesiaTTSService(TTSService):
|
||||
try:
|
||||
await self.stop_all_metrics()
|
||||
|
||||
if self._context_appending_task:
|
||||
self._context_appending_task.cancel()
|
||||
await self._context_appending_task
|
||||
self._context_appending_task = None
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
await self._receive_task
|
||||
@@ -162,18 +152,14 @@ class CartesiaTTSService(TTSService):
|
||||
self._websocket = None
|
||||
|
||||
self._context_id = None
|
||||
self._context_id_start_timestamp = None
|
||||
self._timestamped_words_buffer = []
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error closing websocket: {e}")
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
self._context_id = None
|
||||
self._context_id_start_timestamp = None
|
||||
self._timestamped_words_buffer = []
|
||||
await self.stop_all_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
self._context_id = None
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
try:
|
||||
@@ -188,16 +174,14 @@ class CartesiaTTSService(TTSService):
|
||||
# because we are likely still playing out audio and need the
|
||||
# timestamp to set send context frames.
|
||||
self._context_id = None
|
||||
self._timestamped_words_buffer.append(("LLMFullResponseEndFrame", 0))
|
||||
await self.add_word_timestamps([("LLMFullResponseEndFrame", 0)])
|
||||
elif msg["type"] == "timestamps":
|
||||
# logger.debug(f"TIMESTAMPS: {msg}")
|
||||
self._timestamped_words_buffer.extend(
|
||||
await self.add_word_timestamps(
|
||||
list(zip(msg["word_timestamps"]["words"], msg["word_timestamps"]["end"]))
|
||||
)
|
||||
elif msg["type"] == "chunk":
|
||||
await self.stop_ttfb_metrics()
|
||||
if not self._context_id_start_timestamp:
|
||||
self._context_id_start_timestamp = time.time()
|
||||
self.init_word_timestamps()
|
||||
frame = AudioRawFrame(
|
||||
audio=base64.b64decode(msg["data"]),
|
||||
sample_rate=self._output_format["sample_rate"],
|
||||
@@ -216,27 +200,6 @@ class CartesiaTTSService(TTSService):
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
|
||||
async def _context_appending_task_handler(self):
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(0.1)
|
||||
if not self._context_id_start_timestamp:
|
||||
continue
|
||||
elapsed_seconds = time.time() - self._context_id_start_timestamp
|
||||
# Pop all words from self._timestamped_words_buffer that are
|
||||
# older than the elapsed time and print a message about them to
|
||||
# the console.
|
||||
while self._timestamped_words_buffer and self._timestamped_words_buffer[0][1] <= elapsed_seconds:
|
||||
word, timestamp = self._timestamped_words_buffer.pop(0)
|
||||
if word == "LLMFullResponseEndFrame" and timestamp == 0:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
continue
|
||||
await self.push_frame(TextFrame(word))
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user