From be2858bfbb3df8b814bd4d73ed7316e1f35e4aa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Tue, 14 Oct 2025 14:09:45 -0700 Subject: [PATCH] CartesiaSTTService: inherit from WebsocketSTTService --- CHANGELOG.md | 4 + src/pipecat/services/cartesia/stt.py | 145 ++++++++++++++------------- src/pipecat/services/cartesia/tts.py | 2 +- 3 files changed, 80 insertions(+), 71 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 424d6a420..dd3bfa3c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - The runner `--folder` argument now supports downloading files from subdirectories. +### Changed + +- `CartesiaSTTService` now inherits from `WebsocketSTTService`. + ### Fixed - Fixed an issue where `RimeHttpTTSService` and `PiperTTSService` could generate diff --git a/src/pipecat/services/cartesia/stt.py b/src/pipecat/services/cartesia/stt.py index 5412c422c..97a4f7127 100644 --- a/src/pipecat/services/cartesia/stt.py +++ b/src/pipecat/services/cartesia/stt.py @@ -28,13 +28,12 @@ from pipecat.frames.frames import ( UserStoppedSpeakingFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.stt_service import STTService +from pipecat.services.stt_service import WebsocketSTTService from pipecat.transcriptions.language import Language from pipecat.utils.time import time_now_iso8601 from pipecat.utils.tracing.service_decorators import traced_stt try: - import websockets from websockets.asyncio.client import connect as websocket_connect from websockets.protocol import State except ModuleNotFoundError as e: @@ -124,7 +123,7 @@ class CartesiaLiveOptions: return cls(**json.loads(json_str)) -class CartesiaSTTService(STTService): +class CartesiaSTTService(WebsocketSTTService): """Speech-to-text service using Cartesia Live API. Provides real-time speech transcription through WebSocket connection @@ -176,8 +175,7 @@ class CartesiaSTTService(STTService): self.set_model_name(merged_options.model) self._api_key = api_key self._base_url = base_url or "api.cartesia.ai" - self._connection = None - self._receiver_task = None + self._receive_task = None def can_generate_metrics(self) -> bool: """Check if the service can generate processing metrics. @@ -214,6 +212,27 @@ class CartesiaSTTService(STTService): await super().cancel(frame) await self._disconnect() + async def start_metrics(self): + """Start performance metrics collection for transcription processing.""" + await self.start_ttfb_metrics() + await self.start_processing_metrics() + + async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process incoming frames and handle speech events. + + Args: + frame: The frame to process. + direction: Direction of frame flow in the pipeline. + """ + await super().process_frame(frame, direction) + + if isinstance(frame, UserStartedSpeakingFrame): + await self.start_metrics() + elif isinstance(frame, UserStoppedSpeakingFrame): + # Send finalize command to flush the transcription session + if self._websocket and self._websocket.state is State.OPEN: + await self._websocket.send("finalize") + async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: """Process audio data for speech-to-text transcription. @@ -224,45 +243,69 @@ class CartesiaSTTService(STTService): None - transcription results are handled via WebSocket responses. """ # If the connection is closed, due to timeout, we need to reconnect when the user starts speaking again - if not self._connection or self._connection.state is State.CLOSED: + if not self._websocket or self._websocket.state is State.CLOSED: await self._connect() - await self._connection.send(audio) + await self._websocket.send(audio) yield None async def _connect(self): - params = self._settings.to_dict() - ws_url = f"wss://{self._base_url}/stt/websocket?{urllib.parse.urlencode(params)}" - logger.debug(f"Connecting to Cartesia: {ws_url}") - headers = {"Cartesia-Version": "2025-04-16", "X-API-Key": self._api_key} + await self._connect_websocket() + if self._websocket and not self._receive_task: + self._receive_task = asyncio.create_task(self._receive_task_handler(self._report_error)) + + async def _disconnect(self): + if self._receive_task: + await self.cancel_task(self._receive_task) + self._receive_task = None + + await self._disconnect_websocket() + + async def _connect_websocket(self): try: - self._connection = await websocket_connect(ws_url, additional_headers=headers) - # Setup the receiver task to handle the incoming messages from the Cartesia server - if self._receiver_task is None or self._receiver_task.done(): - self._receiver_task = asyncio.create_task(self._receive_messages()) - logger.debug(f"Connected to Cartesia") + if self._websocket and self._websocket.state is State.OPEN: + return + logger.debug("Connecting to Cartesia STT") + + params = self._settings.to_dict() + ws_url = f"wss://{self._base_url}/stt/websocket?{urllib.parse.urlencode(params)}" + headers = {"Cartesia-Version": "2025-04-16", "X-API-Key": self._api_key} + + self._websocket = await websocket_connect(ws_url, additional_headers=headers) except Exception as e: logger.error(f"{self}: unable to connect to Cartesia: {e}") - async def _receive_messages(self): + async def _disconnect_websocket(self): try: - while True: - if not self._connection or self._connection.state is State.CLOSED: - break - - message = await self._connection.recv() - try: - data = json.loads(message) - await self._process_response(data) - except json.JSONDecodeError: - logger.warning(f"Received non-JSON message: {message}") - except asyncio.CancelledError: - pass - except websockets.exceptions.ConnectionClosed as e: - logger.debug(f"WebSocket connection closed: {e}") + if self._websocket and self._websocket.state is State.OPEN: + logger.debug("Disconnecting from Cartesia STT") + await self._websocket.close() except Exception as e: - logger.error(f"Error in message receiver: {e}") + logger.error(f"{self} error closing websocket: {e}") + finally: + self._websocket = None + + def _get_websocket(self): + if self._websocket: + return self._websocket + raise Exception("Websocket not connected") + + async def _process_messages(self): + async for message in self._get_websocket(): + try: + data = json.loads(message) + await self._process_response(data) + except json.JSONDecodeError: + logger.warning(f"Received non-JSON message: {message}") + + async def _receive_messages(self): + while True: + await self._process_messages() + # Cartesia times out after 5 minutes of innactivity (no keepalive + # mechanism is available). So, we try to reconnect. + logger.debug(f"{self} Cartesia connection was disconnected (timeout?), reconnecting") + await self._connect_websocket() async def _process_response(self, data): if "type" in data: @@ -316,41 +359,3 @@ class CartesiaSTTService(STTService): language, ) ) - - async def _disconnect(self): - if self._receiver_task: - self._receiver_task.cancel() - try: - await self._receiver_task - except asyncio.CancelledError: - pass - except Exception as e: - logger.exception(f"Unexpected exception while cancelling task: {e}") - self._receiver_task = None - - if self._connection and self._connection.state is State.OPEN: - logger.debug("Disconnecting from Cartesia") - - await self._connection.close() - self._connection = None - - async def start_metrics(self): - """Start performance metrics collection for transcription processing.""" - await self.start_ttfb_metrics() - await self.start_processing_metrics() - - async def process_frame(self, frame: Frame, direction: FrameDirection): - """Process incoming frames and handle speech events. - - Args: - frame: The frame to process. - direction: Direction of frame flow in the pipeline. - """ - await super().process_frame(frame, direction) - - if isinstance(frame, UserStartedSpeakingFrame): - await self.start_metrics() - elif isinstance(frame, UserStoppedSpeakingFrame): - # Send finalize command to flush the transcription session - if self._connection and self._connection.state is State.OPEN: - await self._connection.send("finalize") diff --git a/src/pipecat/services/cartesia/tts.py b/src/pipecat/services/cartesia/tts.py index 3b81da5d4..9b2475cb6 100644 --- a/src/pipecat/services/cartesia/tts.py +++ b/src/pipecat/services/cartesia/tts.py @@ -344,7 +344,7 @@ class CartesiaTTSService(AudioContextWordTTSService): try: if self._websocket and self._websocket.state is State.OPEN: return - logger.debug("Connecting to Cartesia") + logger.debug("Connecting to Cartesia TTS") self._websocket = await websocket_connect( f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}" )