[Add] Support for Cartesia AI STT (#1982)
This commit is contained in:
@@ -40,6 +40,8 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
### Added
|
||||
|
||||
- Added `CartesiaSTTService` which is a websocket based implementation to transcribe audio. Added a foundational example in `13f-cartesia-transcription.py`
|
||||
|
||||
- Added an `websocket` example, showing how to use the new Pipecat client
|
||||
`WebsocketTransport` to connect with Pipecat `FastAPIWebsocketTransport` or
|
||||
`WebsocketServerTransport`.
|
||||
|
||||
@@ -53,7 +53,7 @@ You can connect to Pipecat from any platform using our official SDKs:
|
||||
|
||||
| Category | Services |
|
||||
| ------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), Cartesia, [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
|
||||
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Speech-to-Speech | [AWS Nova Sonic](https://docs.pipecat.ai/server/services/s2s/aws), [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
|
||||
|
||||
71
examples/foundational/13f-cartesia-transcription.py
Normal file
71
examples/foundational/13f-cartesia-transcription.py
Normal file
@@ -0,0 +1,71 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import Frame, TranscriptionFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
|
||||
from pipecat.transports.services.daily import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
class TranscriptionLogger(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
print(f"Transcription: {frame.text}")
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(audio_in_enabled=True),
|
||||
"twilio": lambda: FastAPIWebsocketParams(audio_in_enabled=True),
|
||||
"webrtc": lambda: TransportParams(audio_in_enabled=True),
|
||||
}
|
||||
|
||||
|
||||
async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_sigint: bool):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = CartesiaSTTService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
base_url=os.getenv("CARTESIA_BASE_URL"),
|
||||
)
|
||||
|
||||
tl = TranscriptionLogger()
|
||||
|
||||
pipeline = Pipeline([transport.input(), stt, tl])
|
||||
|
||||
task = PipelineTask(pipeline)
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.examples.run import main
|
||||
|
||||
main(run_example, transport_params=transport_params)
|
||||
@@ -8,6 +8,7 @@ import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "cartesia", "cartesia.tts")
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "cartesia", "cartesia.[stt,tts]")
|
||||
|
||||
238
src/pipecat/services/cartesia/stt.py
Normal file
238
src/pipecat/services/cartesia/stt.py
Normal file
@@ -0,0 +1,238 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import urllib.parse
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import websockets
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
|
||||
class CartesiaLiveOptions:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str = "ink-whisper",
|
||||
language: str = Language.EN.value,
|
||||
encoding: str = "pcm_s16le",
|
||||
sample_rate: int = 16000,
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
self.language = language
|
||||
self.encoding = encoding
|
||||
self.sample_rate = sample_rate
|
||||
self.additional_params = kwargs
|
||||
|
||||
def to_dict(self):
|
||||
params = {
|
||||
"model": self.model,
|
||||
"language": self.language if isinstance(self.language, str) else self.language.value,
|
||||
"encoding": self.encoding,
|
||||
"sample_rate": str(self.sample_rate),
|
||||
}
|
||||
|
||||
return params
|
||||
|
||||
def items(self):
|
||||
return self.to_dict().items()
|
||||
|
||||
def get(self, key, default=None):
|
||||
if hasattr(self, key):
|
||||
return getattr(self, key)
|
||||
return self.additional_params.get(key, default)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_str: str) -> "CartesiaLiveOptions":
|
||||
return cls(**json.loads(json_str))
|
||||
|
||||
|
||||
class CartesiaSTTService(STTService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = None,
|
||||
sample_rate: int = 16000,
|
||||
live_options: Optional[CartesiaLiveOptions] = None,
|
||||
**kwargs,
|
||||
):
|
||||
sample_rate = sample_rate or (live_options.sample_rate if live_options else None)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
default_options = CartesiaLiveOptions(
|
||||
model="ink-whisper",
|
||||
language=Language.EN.value,
|
||||
encoding="pcm_s16le",
|
||||
sample_rate=sample_rate,
|
||||
)
|
||||
|
||||
merged_options = default_options
|
||||
if live_options:
|
||||
merged_options_dict = default_options.to_dict()
|
||||
merged_options_dict.update(live_options.to_dict())
|
||||
merged_options = CartesiaLiveOptions(
|
||||
**{
|
||||
k: v
|
||||
for k, v in merged_options_dict.items()
|
||||
if not isinstance(v, str) or v != "None"
|
||||
}
|
||||
)
|
||||
|
||||
self._settings = merged_options
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url or "api.cartesia.ai"
|
||||
self._connection = None
|
||||
self._receiver_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
# If the connection is closed, due to timeout, we need to reconnect when the user starts speaking again
|
||||
if not self._connection or self._connection.closed:
|
||||
await self._connect()
|
||||
|
||||
await self._connection.send(audio)
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
params = self._settings.to_dict()
|
||||
ws_url = f"wss://{self._base_url}/stt/websocket?{urllib.parse.urlencode(params)}"
|
||||
logger.debug(f"Connecting to Cartesia: {ws_url}")
|
||||
headers = {"Cartesia-Version": "2025-04-16", "X-API-Key": self._api_key}
|
||||
|
||||
try:
|
||||
self._connection = await websockets.connect(ws_url, extra_headers=headers)
|
||||
# Setup the receiver task to handle the incoming messages from the Cartesia server
|
||||
if self._receiver_task is None or self._receiver_task.done():
|
||||
self._receiver_task = asyncio.create_task(self._receive_messages())
|
||||
logger.debug(f"Connected to Cartesia")
|
||||
except Exception as e:
|
||||
logger.error(f"{self}: unable to connect to Cartesia: {e}")
|
||||
|
||||
async def _receive_messages(self):
|
||||
try:
|
||||
while True:
|
||||
if not self._connection or self._connection.closed:
|
||||
break
|
||||
|
||||
message = await self._connection.recv()
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._process_response(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Received non-JSON message: {message}")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.debug(f"WebSocket connection closed: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in message receiver: {e}")
|
||||
|
||||
async def _process_response(self, data):
|
||||
if "type" in data:
|
||||
if data["type"] == "transcript":
|
||||
await self._on_transcript(data)
|
||||
|
||||
elif data["type"] == "error":
|
||||
logger.error(f"Cartesia error: {data.get('message', 'Unknown error')}")
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def _on_transcript(self, data):
|
||||
if "text" not in data:
|
||||
return
|
||||
|
||||
transcript = data.get("text", "")
|
||||
is_final = data.get("is_final", False)
|
||||
language = None
|
||||
|
||||
if "language" in data:
|
||||
try:
|
||||
language = Language(data["language"])
|
||||
except (ValueError, KeyError):
|
||||
pass
|
||||
|
||||
if len(transcript) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
if is_final:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(transcript, "", time_now_iso8601(), language)
|
||||
)
|
||||
await self._handle_transcription(transcript, is_final, language)
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
# For interim transcriptions, just push the frame without tracing
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(transcript, "", time_now_iso8601(), language)
|
||||
)
|
||||
|
||||
async def _disconnect(self):
|
||||
if self._receiver_task:
|
||||
self._receiver_task.cancel()
|
||||
try:
|
||||
await self._receiver_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected exception while cancelling task: {e}")
|
||||
self._receiver_task = None
|
||||
|
||||
if self._connection and self._connection.open:
|
||||
logger.debug("Disconnecting from Cartesia")
|
||||
|
||||
await self._connection.close()
|
||||
self._connection = None
|
||||
|
||||
async def start_metrics(self):
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self.start_metrics()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
# Send finalize command to flush the transcription session
|
||||
if self._connection and self._connection.open:
|
||||
await self._connection.send("finalize")
|
||||
Reference in New Issue
Block a user