CartesiaTTSService: use AudioContextWordTTSService
By supporting multiple audio requests we fix an issue that was causing audio overlapping.
This commit is contained in:
@@ -105,6 +105,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a `CartesiaTTSService` service issue that would cause audio overlapping
|
||||
in some cases.
|
||||
|
||||
- Fixed a websocket-based service issue (e.g. `CartesiaTTSService`) that was
|
||||
preventing a reconnection after the server disconnected cleanly, which was
|
||||
causing an inifite loop instead.
|
||||
|
||||
@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import TTSService, WordTTSService
|
||||
from pipecat.services.ai_services import AudioContextWordTTSService, TTSService
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
@@ -75,7 +75,7 @@ def language_to_cartesia_language(language: Language) -> Optional[str]:
|
||||
return result
|
||||
|
||||
|
||||
class CartesiaTTSService(WordTTSService, WebsocketService):
|
||||
class CartesiaTTSService(AudioContextWordTTSService, WebsocketService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN
|
||||
speed: Optional[Union[str, float]] = ""
|
||||
@@ -105,7 +105,7 @@ class CartesiaTTSService(WordTTSService, WebsocketService):
|
||||
# if we're interrupted. Cartesia gives us word-by-word timestamps. We
|
||||
# can use those to generate text frames ourselves aligned with the
|
||||
# playout timing of the audio!
|
||||
WordTTSService.__init__(
|
||||
AudioContextWordTTSService.__init__(
|
||||
self,
|
||||
aggregate_sentences=True,
|
||||
push_text_frames=False,
|
||||
@@ -191,12 +191,12 @@ class CartesiaTTSService(WordTTSService, WebsocketService):
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
await self._disconnect_websocket()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
try:
|
||||
logger.debug("Connecting to Cartesia")
|
||||
@@ -239,21 +239,19 @@ class CartesiaTTSService(WordTTSService, WebsocketService):
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
msg = self._build_msg(text="", continue_transcript=False)
|
||||
await self._websocket.send(msg)
|
||||
self._context_id = None
|
||||
|
||||
async def _receive_messages(self):
|
||||
async for message in self._get_websocket():
|
||||
msg = json.loads(message)
|
||||
if not msg or msg["context_id"] != self._context_id:
|
||||
if not msg or not self.audio_context_available(msg["context_id"]):
|
||||
continue
|
||||
if msg["type"] == "done":
|
||||
await self.stop_ttfb_metrics()
|
||||
# Unset _context_id but not the _context_id_start_timestamp
|
||||
# because we are likely still playing out audio and need the
|
||||
# timestamp to set send context frames.
|
||||
self._context_id = None
|
||||
await self.add_word_timestamps(
|
||||
[("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)]
|
||||
)
|
||||
await self.remove_audio_context(msg["context_id"])
|
||||
elif msg["type"] == "timestamps":
|
||||
await self.add_word_timestamps(
|
||||
list(zip(msg["word_timestamps"]["words"], msg["word_timestamps"]["start"]))
|
||||
@@ -266,12 +264,13 @@ class CartesiaTTSService(WordTTSService, WebsocketService):
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
await self.append_to_audio_context(msg["context_id"], frame)
|
||||
elif msg["type"] == "error":
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(ErrorFrame(f"{self} error: {msg['error']}"))
|
||||
self._context_id = None
|
||||
else:
|
||||
logger.error(f"{self} error, unknown message type: {msg}")
|
||||
|
||||
@@ -299,6 +298,7 @@ class CartesiaTTSService(WordTTSService, WebsocketService):
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
self._context_id = str(uuid.uuid4())
|
||||
await self.create_audio_context(self._context_id)
|
||||
|
||||
msg = self._build_msg(text=text or " ") # Text must contain at least one character
|
||||
|
||||
|
||||
Reference in New Issue
Block a user